-
Notifications
You must be signed in to change notification settings - Fork 19
/
util.py
59 lines (48 loc) · 1.82 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""util.py: polishing trigger set to use.
1. same trigger in different entity type -> delete
2. multiple same triggers -> merge it using mean pooling (temporary)
Written in 2020 by Dong-Ho Lee.
"""
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import Counter
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import BertModel
import os.path
import pickle
from tqdm import tqdm
def remove_duplicates(features, labels, triggers, dataset):
    """Polish the trigger set before use.

    1. Triggers that appear with more than one entity-type label are
       deleted (ambiguous triggers).
    2. Multiple occurrences of the same trigger are merged into one
       vector by mean pooling their feature embeddings.

    Side effect: every item in ``dataset`` whose ``trigger_key`` matches a
    trigger gets its ``trigger_vec`` set to the mean-pooled embedding.
    NOTE: that assignment is computed over ALL occurrences, before the
    ambiguous triggers are dropped (preserved from the original code).

    Args:
        features: iterable of per-occurrence embeddings (torch tensors),
            aligned with ``labels`` and ``triggers``.
        labels: iterable of entity-type labels, one per occurrence.
        triggers: iterable of trigger keys, one per occurrence.
        dataset: objects exposing ``trigger_key`` and ``trigger_vec``
            attributes — presumably trigger examples; confirm with caller.

    Returns:
        Tuple ``(final_trigger, trigger_key)``: a stacked tensor of the
        surviving mean-pooled embeddings and the parallel list of keys.
        Raises if every trigger is ambiguous (``torch.stack`` on empty),
        matching the original behavior.
    """
    # Group every (feature, label) occurrence by its trigger key.
    # setdefault replaces the original redundant if/else (both branches
    # performed the same append).
    feature_dict = {}
    for feature, label, trigger in zip(features, labels, triggers):
        feature_dict.setdefault(trigger, []).append((feature, label))

    # Write each trigger's mean-pooled embedding back into the dataset.
    for key, pairs in feature_dict.items():
        embedding = torch.mean(torch.stack([f for f, _ in pairs]), dim=0)
        for data in dataset:
            if key == data.trigger_key:
                data.trigger_vec = embedding

    # Rule 1: drop triggers observed with more than one entity-type label.
    # (Original shadowed the `labels` parameter here; renamed to avoid it.)
    ambiguous_keys = [key for key, pairs in feature_dict.items()
                      if len({lab for _, lab in pairs}) > 1]
    for key in ambiguous_keys:
        del feature_dict[key]

    # Rule 2: mean-pool the surviving triggers and return them in
    # insertion order alongside their keys.
    trigger_key = []
    final_trigger = []
    for key, pairs in feature_dict.items():
        final_trigger.append(torch.mean(torch.stack([f for f, _ in pairs]), dim=0))
        trigger_key.append(key)
    return torch.stack(final_trigger), trigger_key