Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relation extraction llama #522

Open
wants to merge 170 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
170 commits
Select commit Hold shift + click to select a range
b20e7c8
Added files.
vladd-bit Aug 24, 2021
eec6c59
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Aug 31, 2021
56220aa
More additions to rel extraction.
vladd-bit Sep 1, 2021
7ad88f5
Rel base.
vladd-bit Sep 3, 2021
233ce36
Update.
vladd-bit Sep 6, 2021
85a7015
Updates.
vladd-bit Sep 10, 2021
5003548
Dependency parsing.
vladd-bit Oct 1, 2021
541b47d
Updates.
vladd-bit Oct 13, 2021
c042b0d
Added pre-training steps.
vladd-bit Oct 15, 2021
87d0c0c
Added training & model utils.
vladd-bit Oct 18, 2021
4f42696
Cleanup & fixes.
vladd-bit Oct 19, 2021
018d811
Update.
vladd-bit Oct 21, 2021
f3d3f44
Evaluation updates for pretraining.
vladd-bit Oct 27, 2021
e5f354e
Removed duplicate relation storage.
vladd-bit Nov 9, 2021
c69de67
Merged master.
vladd-bit Nov 9, 2021
031d256
Moved RE model file location.
vladd-bit Nov 12, 2021
2259a6b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Nov 16, 2021
1c469e9
Structure revisions.
vladd-bit Nov 22, 2021
423b4e1
Added custom config for RE.
vladd-bit Dec 13, 2021
8ae9abb
Implemented custom dataset loader for RE.
vladd-bit Dec 13, 2021
186416c
More changes.
vladd-bit Dec 13, 2021
451e33f
Small fix.
vladd-bit Dec 13, 2021
8b36413
Latest additions to RelCAT (pipe + predictions)
vladd-bit Jan 19, 2022
2fb8fc9
Setup.py fix.
vladd-bit Jan 19, 2022
930dd11
RE utils update.
vladd-bit Jan 19, 2022
24b2841
rel model update.
vladd-bit Jan 19, 2022
193ecb1
rel dataset + tokenizer improvements.
vladd-bit Jan 19, 2022
03111a7
RelCAT updates.
vladd-bit Jan 19, 2022
7ab60f4
RelCAT saving/loading improvements.
vladd-bit Jan 21, 2022
40875f3
RelCAT saving/loading improvements.
vladd-bit Jan 21, 2022
810d1dc
RelCAT model fixes.
vladd-bit Jan 21, 2022
11dcb32
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 21, 2022
72187f6
Attempted gpu learning fix. Dataset label generation fixes.
vladd-bit Jan 24, 2022
5f67a4c
Minor train dataset gen fix.
vladd-bit Jan 24, 2022
cfc0e91
Minor train dataset gen fix No.2.
vladd-bit Jan 24, 2022
9f4b220
Config updates.
vladd-bit Jan 25, 2022
19afa81
Gpu support fixes. Added label stats.
vladd-bit Jan 25, 2022
8eb1665
Evaluation stat fixes.
vladd-bit Jan 26, 2022
6e86fa2
Cleaned stat output mode during training.
vladd-bit Jan 26, 2022
5cee8cf
Build fix.
vladd-bit Jan 26, 2022
223ac9a
removed unused dependencies and fixed code formatting
vladd-bit Jan 26, 2022
ea7d68c
Mypy compliance.
vladd-bit Jan 26, 2022
1ea9738
Fixed linting.
vladd-bit Jan 27, 2022
9f6609e
More Gpu mode train fixes.
vladd-bit Jan 28, 2022
1782c0b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 28, 2022
fb86869
Fixed model saving/loading issues when using other baes models.
vladd-bit Jan 31, 2022
df21543
More fixes to stat evaluation. Added proper CAT integration of RelCAT.
vladd-bit Feb 3, 2022
92a5e08
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Feb 3, 2022
87d1a9c
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 11, 2022
ced1627
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 14, 2022
7b69710
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 28, 2022
37fd212
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 4, 2022
f0eda2b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 8, 2022
10269b9
Setup.py typo fix.
vladd-bit Apr 8, 2022
b8a45b2
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 8, 2022
20203ac
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit May 10, 2022
f057139
RelCAT loading fix.
vladd-bit May 10, 2022
197a27a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jul 21, 2022
86fd509
RelCAT Config changes.
vladd-bit Aug 1, 2022
79dc069
Type fix. Minor additions to RelCAT model.
vladd-bit Aug 1, 2022
323c895
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Aug 1, 2022
f1c56bf
Type fixes.
vladd-bit Aug 1, 2022
a78ff86
Type corrections.
vladd-bit Aug 2, 2022
f09ceb2
RelCAT update.
vladd-bit Mar 21, 2023
32574f2
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 21, 2023
c081c3e
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit May 22, 2023
e2e48b5
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 11, 2023
4ce5ba3
Type fixes.
vladd-bit Dec 12, 2023
21c09ff
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Dec 13, 2023
8123689
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 13, 2023
57ab0c5
Fixed type issue.
vladd-bit Dec 13, 2023
9da5aa6
RelCATConfig: added seed param.
vladd-bit Dec 13, 2023
009e832
Adaptations to the new codebase + type fixes..
vladd-bit Dec 15, 2023
1a7d130
Doc/type fixes.
vladd-bit Dec 19, 2023
53dba6a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Dec 20, 2023
92613ed
Fixed input size issue for model.
vladd-bit Jan 8, 2024
a49a44a
Fixed issue(s) with model size and config.
vladd-bit Jan 16, 2024
6456e6e
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 16, 2024
5aac9ab
RelCAT: updated configs to new style.
vladd-bit Jan 19, 2024
9c50b30
RelCAT: removed old refs to logging.
vladd-bit Jan 19, 2024
b071607
Merge branches 'relation_extraction' and 'master' of https://github.c…
vladd-bit Jan 29, 2024
89d9128
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Feb 7, 2024
e6e99cb
Fixed GPU training + added extra stat print for train set.
vladd-bit Feb 7, 2024
307d194
Type fixes.
vladd-bit Feb 7, 2024
fb7efe3
Updated dev requirements.
vladd-bit Feb 7, 2024
c235daf
Linting.
vladd-bit Feb 7, 2024
fcdf2e3
Merge branches 'relation_extraction' and 'master' of https://github.c…
vladd-bit Feb 9, 2024
aad0a73
Fixed pin_memory issue when training on CPU.
vladd-bit Feb 9, 2024
8a9026b
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 8, 2024
f94e349
Updated RelCAT dataset get + default config.
vladd-bit Mar 21, 2024
0770356
Updated RelDS generator + default config
vladd-bit Mar 25, 2024
bdf20f5
Linting.
vladd-bit Mar 25, 2024
f7b5aaf
Updated RelDatset + config.
vladd-bit Apr 3, 2024
3e827cf
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 3, 2024
aaf6533
Pushing updates to model
shubham-s-agarwal Apr 8, 2024
18f9bb8
Fixing formatting
shubham-s-agarwal Apr 8, 2024
503513c
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
040821b
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
ed7c8d5
Update rel_dataset.py
shubham-s-agarwal Apr 8, 2024
8d0bfe4
RelCAT: added test resource files.
vladd-bit Apr 9, 2024
3f3a780
RelCAT: Fixed model load/checkpointing.
vladd-bit Apr 10, 2024
3f56824
RelCAT: updated to pipe spacy doc call.
vladd-bit Apr 12, 2024
b7a4987
RelCAT: added tests.
vladd-bit Apr 12, 2024
77d27b0
Merge branch 'relation_extraction' of https://github.com/CogStack/Med…
vladd-bit Apr 12, 2024
a9258a2
Fixed lint/type issues & added rel tag to test DS.
vladd-bit Apr 15, 2024
0ed70fb
Fixed ann id to token issue.
vladd-bit Apr 15, 2024
8db2e76
RelCAT: updated test dataset + tests.
vladd-bit Apr 18, 2024
6eea6b7
RelCAT: updates to requested changes + dataset improvements.
vladd-bit Apr 18, 2024
6972310
RelCAT: updated docs/logs according to commends.
vladd-bit Apr 18, 2024
d03316c
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 18, 2024
8cb12a4
RelCAT: type fix.
vladd-bit Apr 18, 2024
d10318a
RelCAT: mct export dataset updates.
vladd-bit Apr 19, 2024
12acaeb
RelCAT: test updates + requested changes p2.
vladd-bit Apr 19, 2024
4c14a3a
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 19, 2024
382cefc
RelCAT: log for MCT export train.
vladd-bit Apr 19, 2024
35b0913
Updated docs + split train_test & dataset for benchmarks.
vladd-bit Apr 26, 2024
d48bc41
type fixes.
vladd-bit Apr 26, 2024
3068516
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Apr 26, 2024
61f319c
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit May 10, 2024
8c26b1a
RelCAT: Initial Llama integration.
vladd-bit Jun 23, 2024
553f94e
RelCAT: updates to Llama impl.
vladd-bit Jun 24, 2024
4fc0eae
RelCAT: model typo fix.
vladd-bit Jun 25, 2024
8b3052d
RelCAT: label_id /sample no. mixup fix.
vladd-bit Jun 28, 2024
ede9289
Updated cleaned up Relataset, added new ways to create relations via …
vladd-bit Jul 16, 2024
8119973
Added option to predict any text /w annotations via RelCAT. MCT expor…
vladd-bit Jul 26, 2024
b593c76
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jul 26, 2024
7a56082
RelCAT: added sample limiter / class, more logging info.
vladd-bit Jul 29, 2024
387f70a
RelCAT: test/train ds shuffle update.
vladd-bit Jul 30, 2024
65a789e
RelCAT: added option to keep original text when using reldataset class.
vladd-bit Aug 1, 2024
bce266b
Pushing change for stratified batching
shubham-s-agarwal Aug 1, 2024
d9efe18
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Aug 4, 2024
7942a80
Merge branch 'relation_extraction_llama' of https://github.com/CogSta…
vladd-bit Aug 4, 2024
a03c21d
RelCAT: fixed doc processing issue + class weights.
vladd-bit Aug 4, 2024
d3c6a42
Merge branch 'relation_extraction_llama' of https://github.com/CogSta…
vladd-bit Aug 4, 2024
8ad07d4
Merge branch 'relation_extraction_llama' of https://github.com/CogSta…
vladd-bit Aug 4, 2024
f4bc853
RelCAT: class weights addtions to cfg + param.
vladd-bit Aug 7, 2024
fdd31a8
RelCAT: added config params for Adam optimizer.
vladd-bit Aug 15, 2024
1246995
RelCAT updated default config.
vladd-bit Aug 16, 2024
66e8069
RelCAT: config update + optimizer change.
vladd-bit Aug 19, 2024
9238eeb
RelCAT: fixed model freeze flags.
vladd-bit Sep 3, 2024
94097f2
RelCAT: model optimizer save/load fix.
vladd-bit Sep 5, 2024
ef32757
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Sep 7, 2024
576839a
RelCAT: added export ent tag check.
vladd-bit Sep 12, 2024
0269ee4
Fixed issues when saving/loading model for class weights + inference …
vladd-bit Oct 18, 2024
1a4727b
RelCAT: bug fix for ents that are @ EoS.
vladd-bit Nov 18, 2024
c93afd8
Rel Dataset updates.
vladd-bit Jan 6, 2025
8475011
Rel Dataset updates.
vladd-bit Jan 6, 2025
c33cf95
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 30, 2025
63662c2
Pushing change for ModernBERT
shubham-s-agarwal Jan 31, 2025
ef2f646
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Jan 31, 2025
e40dca2
Merge branch 'relation_extraction_llama' of https://github.com/CogSta…
vladd-bit Jan 31, 2025
bebe5bd
Bumped transformers version.
vladd-bit Feb 10, 2025
6c1a9af
Updated rel dataset generation from fake Spacy Docs.
vladd-bit Feb 10, 2025
d71c624
ModernBert updates.
vladd-bit Feb 16, 2025
fe8737e
Updated RelCAT model-load/save.
vladd-bit Feb 22, 2025
9268570
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Feb 22, 2025
3dad478
Minor relCAT updates, code format.
vladd-bit Mar 10, 2025
ff0a9ca
Merge branch 'master' of https://github.com/CogStack/MedCAT into rela…
vladd-bit Mar 10, 2025
ed5924a
Type check updates.
vladd-bit Mar 10, 2025
c8d4aaa
Fixed inference issue.
vladd-bit Mar 10, 2025
184feea
RelCAT: testing updates.
vladd-bit Mar 10, 2025
23dbf35
Type fixes.
vladd-bit Mar 10, 2025
8e787e6
Type fixes.
vladd-bit Mar 10, 2025
a0b4f0a
Type fixes.
vladd-bit Mar 10, 2025
fd9e2d6
Type fixes IV.
vladd-bit Mar 10, 2025
a7c0336
Type fixes python 3.9.
vladd-bit Mar 10, 2025
d58227b
RelCAT: flake8 fixes.
vladd-bit Mar 12, 2025
bec80eb
RelCAT: flake8 fixes.
vladd-bit Mar 12, 2025
c41e7e7
RelCAT: Updates (fixed model loading after save).
vladd-bit Mar 16, 2025
4101851
Fixed test.
vladd-bit Mar 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Evaluation updates for pretraining.
vladd-bit committed Oct 27, 2021
commit f3d3f44b6d9a3878dc27334fe97caac158f851ad
95 changes: 57 additions & 38 deletions medcat/relation_extraction.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import logging
import pandas
import torch

import torch.nn
import pickle
import dill
@@ -13,7 +14,7 @@
import torch.optim
import torch
import torch.nn as nn
from datetime import datetime
from datetime import date, datetime
from itertools import combinations
from torch.nn.modules.module import T
from torch.nn.utils.rnn import pad_sequence
@@ -172,7 +173,7 @@ def __init__(self, docs, config: PretrainedConfig = None, spacy_model : str = No
embeddings = numpy.load(os.path.join("./", "embeddings.npy"), allow_pickle=False)
self.embeddings = torch.tensor(embeddings, dtype=torch.float32)

self.spacy_nlp = spacy.load("en_core_sci_lg") if spacy_model is None else spacy.load(spacy_model)
self.spacy_nlp = spacy.load("en_core_sci_md") if spacy_model is None else spacy.load(spacy_model)

self.rel_data = RelData(docs)

@@ -230,13 +231,6 @@ def train(self, num_epoch=10, gradient_acc_steps=10, multistep_lr_gamma=0.8):
e1_e2_start=e1_e2_start, pooled_output=None)

token_mask_matches = (tokens == mask_id).to(dtype=torch.bool, device=self.device)

#lm_logits = lm_logits[token_mask_matches]
#print(token_mask_matches)


#lm_logits[0] = lm_logits[0][token_mask_matches]


if (i % update_size) == (update_size - 1):
verbose = True
@@ -245,36 +239,42 @@ def train(self, num_epoch=10, gradient_acc_steps=10, multistep_lr_gamma=0.8):

#blank_labels = torch.zeros((blanks_logits.size()))

print("BLANK LABELS: ", blank_labels)

loss = criterion(lm_logits, blanks_logits, masked_for_pred, blank_labels, verbose=verbose)
input_matched_lm_logits = lm_logits[0][token_mask_matches]


loss = criterion(input_matched_lm_logits, blanks_logits, masked_for_pred, blank_labels, verbose=verbose)
loss = loss/gradient_acc_steps


loss.backward()

if (i % gradient_acc_steps) == 0:
optimizer.step()
optimizer.zero_grad()

total_loss += loss.item()
total_acc += Two_Headed_Loss.evaluate_(lm_logits, blanks_logits, masked_for_pred, blank_labels, \
self.tokenizer.hf_tokenizers, print_=False)[0]

total_acc += Two_Headed_Loss.evaluate_results(input_matched_lm_logits, blanks_logits,
masked_for_pred, blank_labels, \
self.tokenizer.hf_tokenizers, print_=verbose)

if (i % update_size) == (update_size - 1):
losses_per_batch.append(gradient_acc_steps*total_loss/update_size)
losses_per_batch.append(gradient_acc_steps * total_loss/ update_size)
lm_accuracy_per_batch.append(total_acc/update_size)
print('[Epoch: %d, %5d/ %d points] total loss, lm accuracy per batch: %.3f, %.3f' %
(epoch + 1, (i + 1), train_len, losses_per_batch[-1], lm_accuracy_per_batch[-1]))
total_loss = 0.0; total_acc = 0.0
logging.info("Last batch samples (pos, neg): %d, %d" % ((blank_labels.squeeze() == 1).sum().item(),\
(blank_labels.squeeze() == 0).sum().item()))


print("")

scheduler.step()
losses_per_epoch.append(sum(losses_per_batch)/len(losses_per_batch))
accuracy_per_epoch.append(sum(lm_accuracy_per_batch)/len(lm_accuracy_per_batch))
end_time = datetime.now().time()

print("Epoch finished, took %.2f seconds." % (end_time - start_time))
print("Epoch finished, took " + str(datetime.combine(date.today(), end_time) - datetime.combine(date.today(), start_time) ) + " seconds")
print("Losses at Epoch %d: %.7f" % (epoch + 1, losses_per_epoch[-1]))
print("Accuracy at Epoch %d: %.7f" % (epoch + 1, accuracy_per_epoch[-1]))

@@ -354,28 +354,51 @@ def tokenize(self, relations_dataset: Series):

ent1_ent2_start = ([i for i, e in enumerate(tokens) if e == "[ENT1]"] [0] , [i for i, e in enumerate(tokens) if e == "[ENT2]"] [0])

tagged_tokens = self.tokenizer.hf_tokenizers.convert_tokens_to_ids(tokens)
token_ids = self.tokenizer.hf_tokenizers.convert_tokens_to_ids(tokens)
masked_for_pred = self.tokenizer.hf_tokenizers.convert_tokens_to_ids(masked_for_pred)

return tagged_tokens, masked_for_pred, ent1_ent2_start
print(tokens)

return token_ids, masked_for_pred, ent1_ent2_start

def __len__(self):
return len(self.rel_data.relations_dataframe)

def __getitem__(self, index):
relation, ent1_text, ent2_text = self.rel_data.relations_dataframe.iloc[index] # positive sample
relation, ent1_text, ent2_text = self.rel_data.relations_dataframe.iloc[index]

print("relation : ", str(relation), " | ent1_text: " + ent1_text, "| ent2_text" + ent2_text)
print("\n")

pool = self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ent1"] == ent1_text) & (self.rel_data.relations_dataframe["ent2"] == ent2_text))].index

pool = pool.append(self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ent1"] == ent2_text) & (self.rel_data.relations_dataframe["ent2"] == ent1_text))].index)
pool.append(self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ent1"] == ent2_text) & (self.rel_data.relations_dataframe["ent2"] == ent1_text))].index)

pos_idxs = numpy.random.choice(pool, size=min(int(self.batch_size//2), len(pool)), replace=False)

# if numpy.random.uniform() > 0.5:
# pool = self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ENT1"] != ent1_text) | (self.rel_data.relations_dataframe["ENT2"] != ent2_text))].index
neg_idxs = []

if numpy.random.uniform() > 0.5:
pool = self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ent1"] != ent1_text) | \
(self.rel_data.relations_dataframe["ent2"] != ent2_text))].index
else:
if numpy.random.uniform() > 0.5: # share e1 but not e2
pool = self.rel_data.relations_dataframe[(( self.rel_data.relations_dataframe['ent1'] == ent1_text) & \
(self.rel_data.relations_dataframe['ent2'] != ent2_text))].index
else: # share e2 but not e1
pool = self.rel_data.relations_dataframe[(( self.rel_data.relations_dataframe['ent1'] != ent1_text) & \
(self.rel_data.relations_dataframe['ent2'] == ent2_text))].index

if len(pool) == 0:
pool = self.rel_data.relations_dataframe[((self.rel_data.relations_dataframe["ent1"] != ent1_text) | \
(self.rel_data.relations_dataframe["ent2"] != ent2_text))].index

neg_idxs = numpy.random.choice(pool, size=min(int(self.batch_size//2), len(pool)), replace=False)

Q = 1/len(pool)

print(" Pos Idx: " + str(pos_idxs))
print(" Neg Idx: " + str(neg_idxs))

batch = []
## process positive sample
pos_df = self.rel_data.relations_dataframe.loc[pos_idxs]
@@ -412,7 +435,7 @@ class Two_Headed_Loss(nn.Module):
def __init__(self, lm_ignore_idx, use_logits=False, normalize=False):
super(Two_Headed_Loss, self).__init__()
self.lm_ignore_idx = lm_ignore_idx
self.LM_criterion = nn.CrossEntropyLoss(ignore_index=self.lm_ignore_idx)
self.LM_criterion = nn.CrossEntropyLoss(ignore_index=lm_ignore_idx)
self.use_logits = use_logits
self.normalize = normalize

@@ -455,8 +478,8 @@ def forward(self, lm_logits, blank_logits, lm_labels, blank_labels, verbose=Fals
pos_labels = [1.0 for _ in range(pos_logits.shape[0])]
else:
pos_logits, pos_labels = torch.FloatTensor([]), []
if blank_logits.is_cuda:
pos_logits = pos_logits.cuda()
# if blank_logits.is_cuda:
# pos_logits = pos_logits.cuda()

# negatives
neg_logits = []
@@ -473,12 +496,8 @@ def forward(self, lm_logits, blank_logits, lm_labels, blank_labels, verbose=Fals

blank_labels_ = torch.FloatTensor(pos_labels + neg_labels)

lm_labels = torch.stack((torch.zeros(lm_logits[0].shape[2]),
torch.cat((lm_labels, torch.zeros(lm_logits[0].shape[2] - lm_labels.shape[0]))
))
)
lm_loss = self.LM_criterion(lm_logits, target=lm_labels.long())

lm_loss = self.LM_criterion(lm_logits[0].type(torch.LongTensor), lm_labels)

blank_loss = self.BCE_criterion(torch.cat([pos_logits, neg_logits], dim=0), \
blank_labels_)
@@ -516,7 +535,7 @@ def load_state(net, optimizer, scheduler, model_name="BERT", load_best=False):
return start_epoch, best_pred, amp_checkpoint

@classmethod
def load_results(model_name="BERT"):
def load_results(cls, model_name="BERT"):
""" Loads saved results if exists """
losses_path = "./data/test_losses_per_epoch_%s.pkl" % model_name
accuracy_path = "./data/test_accuracy_per_epoch_%s.pkl" % model_name
@@ -529,12 +548,12 @@ def load_results(model_name="BERT"):
return losses_per_epoch, accuracy_per_epoch

@classmethod
def evaluate_(lm_logits, blanks_logits, masked_for_pred, blank_labels, tokenizer, print_=True):
def evaluate_results(cls, lm_logits, blanks_logits, masked_for_pred, blank_labels, tokenizer, print_=False):
'''
evaluate must be called after loss.backward()
'''
# lm_logits
lm_logits_pred_ids = torch.softmax(lm_logits, dim=-1).max(1)[1]

lm_logits_pred_ids = torch.softmax(input=lm_logits, dim=-1).max(1)[1]
lm_accuracy = ((lm_logits_pred_ids == masked_for_pred).sum().float()/len(masked_for_pred)).item()

if print_:
@@ -544,8 +563,8 @@ def evaluate_(lm_logits, blanks_logits, masked_for_pred, blank_labels, tokenizer
print("\nMasked labels tokens: \n")
print(tokenizer.decode(masked_for_pred.cpu().numpy() if masked_for_pred.is_cuda else \
masked_for_pred.numpy()))
blanks_mse = 0
return lm_accuracy, blanks_mse

return lm_accuracy

class Pad_Sequence():
"""
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@
'elasticsearch~=7.10',
'dill~=0.3.3',
'datasets~=1.6.0',
'jsonpickle~=2.0.0',
'jsonpickle~=2.0.0'
],
classifiers=[
"Programming Language :: Python :: 3",