-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaddlink-query_links.py
More file actions
118 lines (101 loc) · 4.47 KB
/
addlink-query_links.py
File metadata and controls
118 lines (101 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from sqlitedict import SqliteDict
import argparse
import xgboost as xgb
import sys,os
import json
from scripts.utils import normalise_title
from scripts.utils import getPageDict,process_page
import multiprocessing
## Logging setup: records are formatted as JSON by the json_logging package.
# Library documentation: https://github.com/bobbui/json-logging-python
import json_logging, logging, sys
# Level applied to this script's logger below.
LOG_LEVEL = logging.DEBUG
# log is initialized without a web framework name
json_logging.init_non_web(enable_json=True)
logger = logging.getLogger("logger")
logger.setLevel(LOG_LEVEL)
# Send log records to standard output.
logger.addHandler(logging.StreamHandler(sys.stdout))
def _close_stores(stores, lang):
    """Close every opened store, logging a warning instead of raising.

    Args:
        stores: mapping of store name to an opened object exposing ``close()``.
        lang: language code; used only in the warning message.
    """
    for store in stores.values():
        try:
            store.close()
        except Exception:
            # Keep going so one failing store does not prevent closing the rest.
            logger.warning('Could not close model in %swiki.', lang)


def main():
    """Print link recommendations for one wiki page as JSON.

    Command-line options:
        --page / -p       page title to get recommendations for (required)
        --lang / -l       wiki language, e.g. ``enwiki`` or ``en`` (required)
        --threshold / -t  threshold value for links to be recommended
                          (default 0.5)
        --source / -s     location of the trained model

    The function opens the per-language SQLite stores and the XGBoost model
    under ``<source>/data/<lang>/``, fetches the page via ``getPageDict``,
    runs ``process_page`` on its wikitext, and prints the result as indented
    JSON on stdout.

    Failure handling:
        * Model cannot be loaded: logged, then the process exits with
          status 1, because nothing useful can be computed without it.
        * Page cannot be retrieved: logged; processing continues with empty
          wikitext and ``pageid``/``revid`` reported as ``null``.
        * Processing fails: logged; an empty set of links is reported.

    Raises:
        SystemExit: on invalid arguments (raised by argparse) or when the
            trained model cannot be loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--page", "-p",
                        default=None,
                        type=str,
                        required=True,
                        help="page-title to get recommendations for")
    parser.add_argument("--lang", "-l",
                        default=None,
                        type=str,
                        required=True,
                        help="language (wiki) for which to get recommendations (e.g. enwiki or en)")
    parser.add_argument("--threshold", "-t",
                        default=0.5,
                        type=float,
                        help="threshold value for links to be recommended")
    parser.add_argument("--source", "-s",
                        default="/home/mgerlach/REPOS/mwaddlink/",
                        type=str,
                        help="Location of the trained model")
    args = parser.parse_args()

    lang = args.lang.replace('wiki', '')
    page_title = normalise_title(args.page)
    threshold = args.threshold
    source_dir = args.source

    logger.info('Getting link recommendations for article %s in %swiki with link-threshold %s',
                page_title, lang, threshold)

    def model_path(suffix):
        # All model artefacts live at <source>/data/<lang>/<lang>.<suffix>
        return os.path.join(source_dir, "data/{0}/{0}.{1}".format(lang, suffix))

    # Store name -> file suffix of the SQLite-backed lookup tables.
    store_suffixes = {
        'anchors': 'anchors.sqlite',
        'pageids': 'pageids.sqlite',
        'redirects': 'redirects.sqlite',
        'word2vec': 'w2v.filtered.sqlite',
        'nav2vec': 'nav.filtered.sqlite',
    }

    # Defaults so the report below is always well-defined, even after a failure.
    stores = {}
    wikitext = ""
    pageid = None
    revid = None
    # NOTE(review): assumes process_page returns a sized, JSON-serialisable
    # collection; an empty list is used as the "no links" fallback -- confirm.
    added_links = []

    try:
        ## open the trained model
        logger.info('Loading the trained model')
        try:
            # Register each store as soon as it is opened so that the
            # ``finally`` below closes it even if a later open fails.
            for name, suffix in store_suffixes.items():
                stores[name] = SqliteDict(model_path(suffix))
            ## load trained model
            n_cpus_max = min(multiprocessing.cpu_count() // 4, 8)
            model = xgb.XGBClassifier(n_jobs=n_cpus_max)  # init model
            model.load_model(model_path("linkmodel_v2.bin"))  # load data
        except Exception:
            logger.exception('Could not open trained model in %swiki. try another language.', lang)
            # Without the model there is nothing to compute; stop here
            # rather than failing later on undefined names.
            sys.exit(1)

        ## querying the API to get the wikitext for the page
        logger.info('Getting the wikitext of the article')
        try:
            page_dict = getPageDict(page_title, lang)
            wikitext = page_dict['wikitext']
            pageid = page_dict['pageid']
            revid = page_dict['revid']
        except Exception:
            # Best effort: continue with empty wikitext and null ids.
            wikitext = ""
            logger.exception("""Not able to retrieve article '%s' in %swiki. try another article.""",
                             page_title, lang)

        ## running the model on the wikitext
        logger.info('Processing wikitext to get link recommendations')
        try:
            added_links = process_page(
                wikitext, page_title,
                stores['anchors'], stores['pageids'], stores['redirects'],
                stores['word2vec'], stores['nav2vec'],
                model, threshold=threshold, return_wikitext=False)
        except Exception:
            logger.exception("""Not able to process article '%s' in %swiki. try another article.""",
                             page_title, lang)
    finally:
        ## closing model
        _close_stores(stores, lang)

    ## reporting the recommendations
    n_links = len(added_links)
    logger.info('Number of links from recommendation model: %s', n_links)
    if n_links == 0:
        logger.info('Model did not yield any links to recommend. Try a lower link-threshold (e.g. -t 0.2)')

    dict_return = {
        'page_title': page_title,
        'lang': lang,
        'pageid': pageid,
        'revid': revid,
        'no_added_links': n_links,
        'added_links': added_links,
    }
    json_out = json.dumps(dict_return, indent=4)
    logger.info('Recommended links: %s', dict_return)
    print('--- Recommended links ---')
    print(json_out)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()