test-lstm_v2.py
import numpy as np
import tensorflow as tf
from collections import defaultdict
import argparse
import pickle
from datetime import datetime
from itertools import islice
import tensor_utils as utils
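
# tensor_utils (imported above) is a project-local helper module; it is assumed
# here to expose load_tensors(sess), which returns the input, output and length
# tensors of the restored LSTM graph (see its use inside the session below).
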
parser = argparse.ArgumentParser(description='Trains meaning embeddings based on precomputed LSTM model')
parser.add_argument('-m', dest='model_path', required=True, help='path to trained LSTM model')
# model_path = '/var/scratch/mcpostma/wsd-dynamic-sense-vector/output/lstm-wsd-small'
parser.add_argument('-v', dest='vocab_path', required=True, help='path to LSTM vocabulary')
#vocab_path = '/var/scratch/mcpostma/wsd-dynamic-sense-vector/output/gigaword.1m-sents-lstm-wsd.index.pkl'
parser.add_argument('-i', dest='input_path', required=True, help='input path with sense annotated sentences')
parser.add_argument('-o', dest='output_path', required=True, help='path where sense embeddings will be stored')
parser.add_argument('-b', dest='batch_size', required=True, help='batch size')
parser.add_argument('-t', dest='max_lines', required=True, help='maximum number of lines you want to train on')
parser.add_argument('-s', dest='setting', required=True, help='sensekey | synset | hdn')
args = parser.parse_args()
print('loaded arguments for training meaning embeddings')
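
# Example invocation (illustrative only; the paths, batch size and line limit
# below are placeholders, not files shipped with this script):
#   python test-lstm_v2.py -m output/lstm-wsd-small \
#       -v output/gigaword.1m-sents-lstm-wsd.index.pkl \
#       -i annotated_sentences.txt -o sense_embeddings.pkl \
#       -b 100 -t 1000000 -s synset
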
def ctx_embd_input(sentence):
    """
    Given an annotated sentence, return its tokens together with the
    positions of the annotated tokens.

    :param str sentence: a sentence with annotations
        (lemma---annotation)

    :rtype: tuple
    :return: (tokens, annotation_indices), where annotation_indices is a
        list of (token_index, annotation) pairs
    """
    sent_split = sentence.split()

    annotation_indices = []
    tokens = []
    for index, token in enumerate(sent_split):
        token, *annotation = token.split('---')
        tokens.append(token)

        if annotation:
            annotation_indices.append((index, annotation[0]))

    return tokens, annotation_indices
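
# Illustrative example (hypothetical sentence and annotation label):
#   ctx_embd_input('the cat---animal.n.01 sat')
#   -> (['the', 'cat', 'sat'], [(1, 'animal.n.01')])
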
vocab = np.load(args.vocab_path)
print('loaded vocab')
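
# per-meaning accumulators: context embeddings, instances and frequency counts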
synset2context_embds = defaultdict(list)
synset2instances = dict()
meaning_freqs = defaultdict(int)
batch_size = int(args.batch_size)
counter = 0

with tf.Session() as sess:  # your session object
    saver = tf.train.import_meta_graph(args.model_path + '.meta', clear_devices=True)
    saver.restore(sess, args.model_path)
    x, predicted_context_embs, lens = utils.load_tensors(sess)
    #x = sess.graph.get_tensor_by_name('Model_1/x:0')
    #predicted_context_embs = sess.graph.get_tensor_by_name('Model_1/predicted_context_embs:0')
    #lens = sess.graph.get_tensor_by_name('Model_1/lens:0')
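
    # x: padded batch of word ids; lens: true sentence lengths;
    # predicted_context_embs: the LSTM's context embedding at the <target> position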
    with open(args.input_path) as infile:
        for n_lines in iter(lambda: tuple(islice(infile, batch_size)), ()):
            counter += len(n_lines)
            if counter >= int(args.max_lines):
                break
            print(counter, datetime.now())

            identifiers = []  # list of synset_ids
            annotated_sentences = []
            sentence_lens = []  # list of ints

            for line in n_lines:
                sentence = line.strip()
                tokens, annotation_indices = ctx_embd_input(sentence)

                # one training instance per annotated token in the sentence
                for index, synset_id in annotation_indices:
                    if args.setting == 'hdn':
                        base_synset, synset_id = synset_id.split('_')

                    sentence_as_ids = [vocab.get(w, vocab['<unkn>']) for w in tokens]
                    # replace the annotated token by the <target> marker
                    target_id = vocab['<target>']
                    sentence_as_ids[index] = target_id

                    meaning_freqs[synset_id] += 1

                    # update batch information
                    identifiers.append(synset_id)
                    annotated_sentences.append(sentence_as_ids)
                    sentence_lens.append(len(sentence_as_ids))

            if not annotated_sentences:  # skip batches without any annotation
                continue

            # compute embeddings for batch: pad every sentence to the batch maximum
            max_length = max(len(_list) for _list in annotated_sentences)
            for _list in annotated_sentences:
                length_diff = max_length - len(_list)
                _list.extend([vocab['<unkn>']] * length_diff)

            target_embeddings = sess.run(predicted_context_embs,
                                         {x: annotated_sentences,
                                          lens: sentence_lens})

            for synset_id, target_embedding in zip(identifiers, target_embeddings):
                synset2context_embds[synset_id].append(target_embedding)
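
# average the collected context embeddings per meaning; np.std is taken over
# all components of all instance embeddings, i.e. a single scalar per meaning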
synset2avg_embedding = dict()
for synset, embeddings in synset2context_embds.items():
    average = sum(embeddings) / len(embeddings)
    std = np.std(embeddings)
    if len(embeddings) == 1:
        assert all(average == embeddings[0])
    synset2avg_embedding[synset] = average, std
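
# three pickled outputs: per-meaning (average, std) pairs, all per-instance
# context embeddings, and per-meaning frequency counts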
with open(args.output_path, 'wb') as outfile:
    pickle.dump(synset2avg_embedding, outfile)

with open(args.output_path + '.instances', 'wb') as outfile:
    pickle.dump(synset2context_embds, outfile)

with open(args.output_path + '.freq', 'wb') as outfile:
    pickle.dump(meaning_freqs, outfile)