Skip to content

Commit e55d712

Browse files
authored
Merge pull request #4 from midas-research/ad_dev
KP extraction
2 parents 1c7df5a + 7eda301 commit e55d712

20 files changed

+6490
-146
lines changed

dlkp/datasets/pre_process.py

Whitespace-only changes.

dlkp/kp_metrics/__init__.py

Whitespace-only changes.

dlkp/kp_metrics/metrics.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
2+
from seqeval.scheme import IOB2, IOB1
3+
import numpy as np
4+
5+
6+
def compute_metrics(p):
7+
return_entity_level_metrics = False
8+
ignore_value = -100
9+
predictions, labels = p
10+
label_to_id = {"B": 0, "I": 1, "O": 2}
11+
id_to_label = ["B", "I", "O"]
12+
# if model_args.use_CRF is False:
13+
predictions = np.argmax(predictions, axis=2)
14+
# print(predictions.shape, labels.shape)
15+
16+
# Remove ignored index (special tokens)
17+
true_predictions = [
18+
[id_to_label[p] for (p, l) in zip(prediction, label) if l != ignore_value]
19+
for prediction, label in zip(predictions, labels)
20+
]
21+
true_labels = [
22+
[id_to_label[l] for (p, l) in zip(prediction, label) if l != ignore_value]
23+
for prediction, label in zip(predictions, labels)
24+
]
25+
26+
# results = metric.compute(predictions=true_predictions, references=true_labels)
27+
results = {}
28+
# print("cal precisi")
29+
# mode="strict"
30+
results["overall_precision"] = precision_score(
31+
true_labels, true_predictions, scheme=IOB2
32+
)
33+
results["overall_recall"] = recall_score(true_labels, true_predictions, scheme=IOB2)
34+
# print("cal f1")
35+
results["overall_f1"] = f1_score(true_labels, true_predictions, scheme=IOB2)
36+
results["overall_accuracy"] = accuracy_score(true_labels, true_predictions)
37+
if return_entity_level_metrics:
38+
# Unpack nested dictionaries
39+
final_results = {}
40+
# print("cal entity level mat")
41+
for key, value in results.items():
42+
if isinstance(value, dict):
43+
for n, v in value.items():
44+
final_results[f"{key}_{n}"] = v
45+
else:
46+
final_results[key] = value
47+
return final_results
48+
else:
49+
return {
50+
"precision": results["overall_precision"],
51+
"recall": results["overall_recall"],
52+
"f1": results["overall_f1"],
53+
"accuracy": results["overall_accuracy"],
54+
}

dlkp/models/ke/crf/__init__.py

Whitespace-only changes.

dlkp/models/ke/crf/crf.py

+291
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# add models having crf classification layer with option of bilstm layers
2+
3+
from .crf_utils import *
4+
from typing import List, Tuple, Dict, Union
5+
6+
import torch
7+
8+
VITERBI_DECODING = Tuple[List[int], float]
9+
10+
11+
class ConditionalRandomField(torch.nn.Module):
12+
"""
13+
This module uses the "forward-backward" algorithm to compute
14+
the log-likelihood of its inputs assuming a conditional random field model.
15+
See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf
16+
# Parameters
17+
num_tags : `int`, required
18+
The number of tags.
19+
constraints : `List[Tuple[int, int]]`, optional (default = `None`)
20+
An optional list of allowed transitions (from_tag_id, to_tag_id).
21+
These are applied to `viterbi_tags()` but do not affect `forward()`.
22+
These should be derived from `allowed_transitions` so that the
23+
start and end transitions are handled correctly for your tag type.
24+
include_start_end_transitions : `bool`, optional (default = `True`)
25+
Whether to include the start and end transition parameters.
26+
"""
27+
28+
def __init__(
29+
self,
30+
num_tags: int,
31+
label_encoding,
32+
idx2tag,
33+
include_start_end_transitions: bool = True,
34+
) -> None:
35+
super().__init__()
36+
self.num_tags = num_tags
37+
constraints = allowed_transitions(label_encoding, idx2tag)
38+
# transitions[i, j] is the logit for transitioning from state i to state j.
39+
self.transitions = torch.nn.Parameter(torch.Tensor(num_tags, num_tags))
40+
41+
# _constraint_mask indicates valid transitions (based on supplied constraints).
42+
# Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2)
43+
if constraints is None:
44+
# All transitions are valid.
45+
constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.0)
46+
else:
47+
constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(0.0)
48+
for i, j in constraints:
49+
constraint_mask[i, j] = 1.0
50+
51+
self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False)
52+
53+
# Also need logits for transitioning from "start" state and to "end" state.
54+
self.include_start_end_transitions = include_start_end_transitions
55+
if include_start_end_transitions:
56+
self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags))
57+
self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags))
58+
59+
self.reset_parameters()
60+
61+
def reset_parameters(self):
62+
torch.nn.init.xavier_normal_(self.transitions)
63+
if self.include_start_end_transitions:
64+
torch.nn.init.normal_(self.start_transitions)
65+
torch.nn.init.normal_(self.end_transitions)
66+
67+
def _input_likelihood(
68+
self, logits: torch.Tensor, mask: torch.BoolTensor
69+
) -> torch.Tensor:
70+
"""
71+
Computes the (batch_size,) denominator term for the log-likelihood, which is the
72+
sum of the likelihoods across all possible state sequences.
73+
"""
74+
batch_size, sequence_length, num_tags = logits.size()
75+
76+
# Transpose batch size and sequence dimensions
77+
mask = mask.transpose(0, 1).contiguous()
78+
logits = logits.transpose(0, 1).contiguous()
79+
80+
# Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
81+
# transitions to the initial states and the logits for the first timestep.
82+
if self.include_start_end_transitions:
83+
alpha = self.start_transitions.view(1, num_tags) + logits[0]
84+
else:
85+
alpha = logits[0]
86+
87+
# For each i we compute logits for the transitions from timestep i-1 to timestep i.
88+
# We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
89+
# (instance, current_tag, next_tag)
90+
for i in range(1, sequence_length):
91+
# The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
92+
emit_scores = logits[i].view(batch_size, 1, num_tags)
93+
# Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
94+
transition_scores = self.transitions.view(1, num_tags, num_tags)
95+
# Alpha is for the current_tag, so we broadcast along the next_tag axis.
96+
broadcast_alpha = alpha.view(batch_size, num_tags, 1)
97+
98+
# Add all the scores together and logexp over the current_tag axis.
99+
inner = broadcast_alpha + emit_scores + transition_scores
100+
101+
# In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension
102+
# of `inner`. Otherwise (mask == False) we want to retain the previous alpha.
103+
alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
104+
~mask[i]
105+
).view(batch_size, 1)
106+
107+
# Every sequence needs to end with a transition to the stop_tag.
108+
if self.include_start_end_transitions:
109+
stops = alpha + self.end_transitions.view(1, num_tags)
110+
else:
111+
stops = alpha
112+
113+
# Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
114+
return logsumexp(stops)
115+
116+
def _joint_likelihood(
117+
self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
118+
) -> torch.Tensor:
119+
"""
120+
Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
121+
"""
122+
batch_size, sequence_length, _ = logits.data.shape
123+
124+
# Transpose batch size and sequence dimensions:
125+
logits = logits.transpose(0, 1).contiguous()
126+
mask = mask.transpose(0, 1).contiguous()
127+
tags = tags.transpose(0, 1).contiguous()
128+
129+
# Start with the transition scores from start_tag to the first tag in each input
130+
if self.include_start_end_transitions:
131+
score = self.start_transitions.index_select(0, tags[0])
132+
else:
133+
score = 0.0
134+
135+
# Add up the scores for the observed transitions and all the inputs but the last
136+
# print(mask.shape, tags.shape, logits.shape, sequence_length)
137+
for i in range(sequence_length - 1):
138+
# Each is shape (batch_size,)
139+
current_tag, next_tag = tags[i], tags[i + 1]
140+
# print(current_tag, next_tag)
141+
# print("tags printiiinggggg")
142+
# print(current_tag, next_tag)
143+
# The scores for transitioning from current_tag to next_tag
144+
transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]
145+
146+
# The score for using current_tag
147+
emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)
148+
# emit_score= 0
149+
# Include transition score if next element is unmasked,
150+
# input_score if this element is unmasked.
151+
score = score + transition_score * mask[i + 1] + emit_score * mask[i]
152+
153+
# Transition from last state to "stop" state. To start with, we need to find the last tag
154+
# for each instance.
155+
last_tag_index = mask.sum(0).long() - 1
156+
last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)
157+
158+
# Compute score of transitioning to `stop_tag` from each "last tag".
159+
if self.include_start_end_transitions:
160+
last_transition_score = self.end_transitions.index_select(0, last_tags)
161+
else:
162+
last_transition_score = 0.0
163+
164+
# Add the last input if it's not masked.
165+
last_inputs = logits[-1] # (batch_size, num_tags)
166+
last_input_score = last_inputs.gather(
167+
1, last_tags.view(-1, 1)
168+
) # (batch_size, 1)
169+
last_input_score = last_input_score.squeeze() # (batch_size,)
170+
171+
score = score + last_transition_score + last_input_score * mask[-1]
172+
173+
return score
174+
175+
def forward(
176+
self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
177+
) -> torch.Tensor:
178+
"""
179+
Computes the log likelihood.
180+
"""
181+
# mask[tags==-100]=0
182+
if mask is None:
183+
mask = torch.ones(*tags.size(), dtype=torch.bool)
184+
else:
185+
# The code below fails in weird ways if this isn't a bool tensor, so we make sure.
186+
mask = mask.to(torch.bool)
187+
# print("forward",inputs.shape, tags.shape, mask.shape)
188+
189+
log_denominator = self._input_likelihood(inputs, mask)
190+
# temp_tags= tags
191+
# tags[tags==-100]=2
192+
# print(tags[0])
193+
log_numerator = self._joint_likelihood(inputs, tags, mask)
194+
# tags[mask==0]=-100
195+
return torch.sum(log_numerator - log_denominator)
196+
197+
def viterbi_tags(
198+
self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None
199+
) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]:
200+
"""
201+
Uses viterbi algorithm to find most likely tags for the given inputs.
202+
If constraints are applied, disallows all other transitions.
203+
Returns a list of results, of the same size as the batch (one result per batch member)
204+
Each result is a List of length top_k, containing the top K viterbi decodings
205+
Each decoding is a tuple (tag_sequence, viterbi_score)
206+
For backwards compatibility, if top_k is None, then instead returns a flat list of
207+
tag sequences (the top tag sequence for each batch item).
208+
"""
209+
if mask is None:
210+
mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device)
211+
212+
if top_k is None:
213+
top_k = 1
214+
flatten_output = True
215+
else:
216+
flatten_output = False
217+
218+
_, max_seq_length, num_tags = logits.size()
219+
220+
# Get the tensors out of the variables
221+
logits, mask = logits.data, mask.data
222+
223+
# Augment transitions matrix with start and end transitions
224+
start_tag = num_tags
225+
end_tag = num_tags + 1
226+
transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.0)
227+
228+
# Apply transition constraints
229+
constrained_transitions = self.transitions * self._constraint_mask[
230+
:num_tags, :num_tags
231+
] + -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags])
232+
transitions[:num_tags, :num_tags] = constrained_transitions.data
233+
234+
if self.include_start_end_transitions:
235+
transitions[
236+
start_tag, :num_tags
237+
] = self.start_transitions.detach() * self._constraint_mask[
238+
start_tag, :num_tags
239+
].data + -10000.0 * (
240+
1 - self._constraint_mask[start_tag, :num_tags].detach()
241+
)
242+
transitions[
243+
:num_tags, end_tag
244+
] = self.end_transitions.detach() * self._constraint_mask[
245+
:num_tags, end_tag
246+
].data + -10000.0 * (
247+
1 - self._constraint_mask[:num_tags, end_tag].detach()
248+
)
249+
else:
250+
transitions[start_tag, :num_tags] = -10000.0 * (
251+
1 - self._constraint_mask[start_tag, :num_tags].detach()
252+
)
253+
transitions[:num_tags, end_tag] = -10000.0 * (
254+
1 - self._constraint_mask[:num_tags, end_tag].detach()
255+
)
256+
257+
best_paths = []
258+
# Pad the max sequence length by 2 to account for start_tag + end_tag.
259+
tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)
260+
261+
for prediction, prediction_mask in zip(logits, mask):
262+
mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze()
263+
masked_prediction = torch.index_select(prediction, 0, mask_indices)
264+
sequence_length = masked_prediction.shape[0]
265+
266+
# Start with everything totally unlikely
267+
tag_sequence.fill_(-10000.0)
268+
# At timestep 0 we must have the START_TAG
269+
tag_sequence[0, start_tag] = 0.0
270+
# At steps 1, ..., sequence_length we just use the incoming prediction
271+
tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction
272+
# And at the last timestep we must have the END_TAG
273+
tag_sequence[sequence_length + 1, end_tag] = 0.0
274+
275+
# We pass the tags and the transitions to `viterbi_decode`.
276+
viterbi_paths, viterbi_scores = viterbi_decode(
277+
tag_sequence=tag_sequence[: (sequence_length + 2)],
278+
transition_matrix=transitions,
279+
top_k=top_k,
280+
)
281+
top_k_paths = []
282+
for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores):
283+
# Get rid of START and END sentinels and append.
284+
viterbi_path = viterbi_path[1:-1]
285+
top_k_paths.append((viterbi_path, viterbi_score.item()))
286+
best_paths.append(top_k_paths)
287+
288+
if flatten_output:
289+
return [top_k_paths[0] for top_k_paths in best_paths]
290+
291+
return best_paths

0 commit comments

Comments
 (0)