-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathget_xx.py
110 lines (85 loc) · 3.08 KB
/
get_xx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-name, no-member, too-many-locals
import importlib
import os
import numpy as np
import tensorflow as tf
import texar as tx
import pickle
from utils import *
from munkres import *
# Sentinel cost standing in for +infinity: returned by calc_cost for pairs
# that must never be matched (attribute mismatch, numeric vs. non-numeric).
inf = int(1e9)
def calc_cost(a, b):
    """Return the alignment cost between two record entries.

    Both arguments are expected to expose ``.attribute`` and ``.entry``
    (string) fields — presumably the packed structured-data items produced
    by ``pack_sd``; verify against utils.

    Cost rules:
      * different attributes, or one numeric entry vs. one non-numeric
        entry -> ``inf`` (never match);
      * both numeric -> absolute difference of the integer values;
      * both non-numeric -> 0 on exact string equality, else 1.
    """
    if a.attribute != b.attribute:
        return inf
    a_numeric = a.entry.isdigit()
    b_numeric = b.entry.isdigit()
    # Mixed numeric/non-numeric pairs are incomparable.
    if a_numeric != b_numeric:
        return inf
    if a_numeric:
        return abs(int(a.entry) - int(b.entry))
    return 0 if a.entry == b.entry else 1
def get_match(text00, text01, text02, text10, text11, text12):
    """Greedily match entries of one structured record against another.

    The six token lists form two records (three fields each). After
    stripping special tokens, each record is packed via ``pack_sd`` and a
    full pairwise cost matrix is built with ``calc_cost``. Each entry of
    the first record is then matched to its cheapest counterpart in the
    second; entries with no feasible counterpart (all costs == ``inf``)
    are paired with -1.

    Returns a list of ``(left_index, right_index_or_minus_1)`` tuples.
    """
    stripped = tuple(map(
        strip_special_tokens_of_list,
        (text00, text01, text02, text10, text11, text12)))
    records = [DataItem(*stripped[:3]), DataItem(*stripped[3:])]
    packed = [pack_sd(record) for record in records]
    cost = [[calc_cost(left, right) for right in packed[1]]
            for left in packed[0]]
    matches = []
    for row_idx, row in enumerate(cost):
        best = min(row)
        # inf on the whole row means no feasible partner for this entry.
        matches.append((row_idx, -1 if best == inf else row.index(best)))
    return matches
    # NOTE: an exact assignment via Munkres().compute(cost) was disabled
    # in favor of this greedy per-row matching.
# Batched variant of get_match — `batchize` (from utils) presumably lifts
# the per-example function to operate over batches of inputs; verify its
# exact contract in utils.
batch_get_match = batchize(get_match)
def main():
    """Stream every dataset split, compute entry matches per batch, and
    pickle them to ``match.<mode>.pkl`` (one pickled match list per
    record, appended sequentially to the same file)."""
    # data batch
    # Build one dataset per mode (e.g. train/valid/test) from the data
    # config module loaded in __main__, sharing a single feedable iterator.
    datasets = {mode: tx.data.MultiAlignedData(hparams)
                for mode, hparams in config_data.datas.items()}
    data_iterator = tx.data.FeedableDataIterator(datasets)
    data_batch = data_iterator.get_next()

    def _get_match(sess, mode):
        """Exhaust the `mode` split once, dumping matches as we go."""
        print('in _get_match')
        # Rewind the split so a full pass is guaranteed.
        data_iterator.restart_dataset(sess, mode)
        feed_dict = {
            tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            data_iterator.handle: data_iterator.get_handle(sess, mode),
        }
        with open('match.{}.pkl'.format(mode), 'wb') as out_file:
            while True:
                try:
                    batch = sess.run(data_batch, feed_dict)
                    # Gather the structured-data text fields for both
                    # record versions; sd_fields / ref_strs come from
                    # utils (star-imported above).
                    texts = [[batch['{}{}_text'.format(field, ref_str)]
                              for field in sd_fields]
                             for ref_str in ref_strs]
                    matches = batch_get_match(*(texts[0] + texts[1]))
                    # One pickle frame per example; readers must loop
                    # pickle.load until EOF.
                    for match in matches:
                        pickle.dump(match, out_file)
                except tf.errors.OutOfRangeError:
                    # Split exhausted — normal termination of the pass.
                    break
        print('end _get_match')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # Required for the vocab lookup tables used by the datasets.
        sess.run(tf.tables_initializer())
        _get_match(sess, 'test')
        _get_match(sess, 'valid')
        _get_match(sess, 'train')
if __name__ == '__main__':
    # Command-line flags (TF1-style). `config_data` names the Python module
    # holding dataset hparams; it is imported dynamically below and read as
    # a module-level global by main().
    flags = tf.flags
    flags.DEFINE_string("config_data", "config_data_nba_stable",
                        "The data config.")
    # NOTE(review): `verbose` is defined but never read in this file —
    # possibly consumed by a star-imported helper; confirm before removing.
    flags.DEFINE_boolean("verbose", False, "verbose.")
    FLAGS = flags.FLAGS
    config_data = importlib.import_module(FLAGS.config_data)
    main()