-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils_e2e_clean.py
97 lines (79 loc) · 2.85 KB
/
utils_e2e_clean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import random
import texar as tx
from tensorflow.contrib.seq2seq import tile_batch
from data2text.data_utils import get_train_ents, extract_entities, extract_numbers
# load all entities
#all_ents, players, teams, cities = get_train_ents(path=os.path.join("data2text", "rotowire"), connect_multiwords=True)
# Entity vocabulary for the E2E dataset: every known slot *value* token,
# loaded once at import time from the x_value vocab file.
e2e_ents = set()
with open('e2e_data/x_value.vocab.txt', 'r') as f:
    all_vocb = f.readlines()
    for vocab in all_vocb:
        # Strip only the trailing newline; tokens may contain other whitespace.
        e2e_ents.add(vocab.strip('\n'))
# Callable templates for per-component TF scope names,
# e.g. get_scope_name_of_train_op('xe') -> 'train_xe'.
get_scope_name_of_train_op = 'train_{}'.format
get_scope_name_of_summary_op = 'summary_{}'.format
# Filename suffix convention: '' = model side, '_ref' = reference side.
ref_strs = ['', '_ref']
# Sentence (y) data fields, indexed by ref flag.
sent_fields = ['y_aux', 'y_ref']
# Each structured-data record x consists of three aligned token streams.
x_fields = ['value', 'type', 'associated']
x_strs = ['x', 'x_ref']
y_strs = ['y_aux', 'y_ref']
y_tgt_strs = ['y_ref']
class DataItem(collections.namedtuple('DataItem', x_fields)):
    """One structured-data record: a (value, type, associated) token triple."""

    def __str__(self):
        # Render the record as 'value|type|associated'.
        return '|'.join(str(field) for field in self)
def pack_sd(paired_texts):
    """Zip parallel per-field token sequences into a list of DataItem records."""
    records = zip(*paired_texts)
    return [DataItem(*record) for record in records]
def batchize(func):
    """Lift ``func`` to operate on batches.

    The returned function takes batched arguments (one sequence per
    parameter of ``func``) and applies ``func`` element-wise, returning
    the list of per-example results.
    """
    def wrapper(*batches):
        return [func(*example) for example in zip(*batches)]
    return wrapper
def strip_special_tokens_of_list(text):
    """Remove special tokens (BOS/EOS/PAD etc.) from a single token list via texar."""
    return tx.utils.strip_special_tokens(text, is_token_list=True)
# Batched variant: strips special tokens from every token list in a batch.
batch_strip_special_tokens_of_list = batchize(strip_special_tokens_of_list)
def replace_data_in_sent(sent, token="<UNK>"):
    """Mask E2E entity tokens in ``sent`` with ``token``.

    The token list is modified in place (and also returned), replacing
    the start position of each detected entity with ``token``.
    """
    data_type = 'e2e'
    if data_type == 'e2e':
        matches = extract_entities(sent, e2e_ents)
        # Walk right-to-left so earlier start indices remain valid.
        matches.sort(key=lambda match: match.start, reverse=True)
        for match in matches:
            sent[match.start] = token
    return sent
def corpus_bleu(list_of_references, hypotheses, **kwargs):
    """Moses corpus BLEU (lowercased, single score) with entities masked.

    Entity tokens in both references and hypotheses are replaced by a
    placeholder before scoring, so BLEU reflects the surface text rather
    than copied data values.
    """
    masked_refs = [[replace_data_in_sent(ref) for ref in refs]
                   for refs in list_of_references]
    masked_hyps = [replace_data_in_sent(hyp) for hyp in hypotheses]
    return tx.evals.corpus_bleu_moses(
        masked_refs, masked_hyps,
        lowercase=True, return_all=False,
        **kwargs)
def read_sents_from_file(file_name):
    """Read one whitespace-tokenized sentence per line from ``file_name``."""
    sents = []
    with open(file_name, 'r') as fp:
        for line in fp:
            sents.append(line.split())
    return sents
def read_x(data_prefix, ref_flag, stage):
    """Load structured-data (x) records for a data split.

    Reads one token file per x field ('value', 'type', 'associated'),
    named '{data_prefix}{field}{ref_str}.{stage}.txt', and zips the
    aligned tokens into DataItem records: the result is a list of
    examples, each a list of DataItem triples.

    Bug fix: the original body referenced the undefined name
    ``sd_fields`` (a NameError at call time); the intended field list is
    ``x_fields`` — the DataItem fields being zipped together here.
    """
    ref_str = ref_strs[ref_flag]
    # One sentence list per field, all aligned example-by-example.
    per_field_sents = [
        read_sents_from_file(
            '{}{}{}.{}.txt'.format(data_prefix, field, ref_str, stage))
        for field in x_fields]
    return [
        [DataItem(*tokens) for tokens in zip(*paired_sents)]
        for paired_sents in zip(*per_field_sents)]
def read_y(data_prefix, ref_flag, stage):
    """Load target (y) sentences for a data split.

    Always reads the first sentence field ('y_aux') with the suffix
    selected by ``ref_flag``, from '{data_prefix}y_aux{ref_str}.{stage}.txt'.
    """
    ref_str = ref_strs[ref_flag]
    path = '{}{}{}.{}.txt'.format(
        data_prefix, sent_fields[0], ref_str, stage)
    return read_sents_from_file(path)
def divide_or_const(a, b, c=0.):
    """Return ``a / b``, falling back to the constant ``c`` on division by zero."""
    try:
        result = a / b
    except ZeroDivisionError:
        result = c
    return result