This repository was archived by the owner on Jun 28, 2023. It is now read-only.

Commit 706b9c2

Reworked BertModelBuilder and used it to build a decent classifier (F-score of about 0.75). Also zipped up the generated model and included it in the pre-trained folder as a tar archive.
1 parent 177cf41 commit 706b9c2

11 files changed: 8,197 lines added, 1 removed

AtticusUtils/Loaders.py

Lines changed: 56 additions & 0 deletions
import random

import pandas as pd


def create_training_set(train_data=None, limit=0, split=0.8):
    """Shuffle and split the Atticus training data, holding out an evaluation set."""
    train_data = train_data or []  # avoid a mutable default argument
    random.shuffle(train_data)
    train_data = train_data[-limit:]  # limit=0 keeps the full dataset

    texts, labels = zip(*train_data)
    split = int(len(train_data) * split)

    # Return data in a format that matches the example here:
    # https://github.com/explosion/spaCy/blob/master/examples/training/train_textcat.py
    return (texts[:split], labels[:split]), (texts[split:], labels[split:])


def load_atticus_data(filepath='./data/master_clauses.csv'):
    """
    Load data from the Atticus csv, omitting the answer columns (we want to train
    classifiers, not question answering).

    Data is returned in the spaCy training format:

        TRAIN_DATA = [
            ("text1", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}})
        ]

    A list of headers is also returned so you can add these labels to a classifier.
    Note that the Filename and Doc name columns are dropped as well.
    """
    # Load csv
    atticus_clauses_df = pd.read_csv(filepath)

    # Do a little post-processing
    data_headers = [h for h in list(atticus_clauses_df.columns) if "Answer" not in h]
    data_headers.pop(0)  # Drop the Filename column (index 0)
    data_headers.pop(0)  # Drop the Doc name column (originally index 1, now index 0)

    training_values = {i: 0 for i in data_headers}
    atticus_clauses_data_df = atticus_clauses_df.loc[:, data_headers]

    train_data = []

    # Iterate over the csv to build the training data dicts
    for header in atticus_clauses_data_df.columns:
        for row in atticus_clauses_data_df[[header]].iterrows():
            value = row[1][header]
            if not pd.isnull(value):
                train_data.append((value, {'cats': {**training_values, header: 1}}))

    return train_data, data_headers
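
For reference, a minimal sketch of how these two loaders fit together (assuming the Atticus csv sits at the default ./data/master_clauses.csv path):

    from AtticusUtils.Loaders import load_atticus_data, create_training_set

    # train_data is a list of (text, {'cats': {...}}) tuples; data_headers holds the label names
    train_data, data_headers = load_atticus_data()
    (train_texts, train_cats), (eval_texts, eval_cats) = create_training_set(train_data=train_data)
    print(f"{len(train_texts)} training texts, {len(eval_texts)} held out")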

AtticusUtils/__init__.py

Whitespace-only changes.

BertModelBuilder.py

Lines changed: 204 additions & 0 deletions
#!/usr/bin/env python
import random
from collections import Counter
from pathlib import Path

import plac
import spacy
import torch
import tqdm
import wasabi
from spacy.util import minibatch
from spacy_transformers.util import cyclic_triangular_rate

from AtticusUtils.Loaders import load_atticus_data, create_training_set

# Based on the sample spaCy code here:
# https://github.com/explosion/spacy-transformers/blob/v0.6.x/examples/train_textcat.py


@plac.annotations(
    model=("Model name", "positional", None, str),
    output_dir=("Optional output directory (strongly recommended; training takes forever to run)", "option", "o", Path),
    use_test=("Whether to use the actual test set", "flag", "E"),
    batch_size=("Number of docs per batch", "option", "bs", int),
    learn_rate=("Learning rate", "option", "lr", float),
    max_wpb=("Max words per sub-batch", "option", "wpb", int),
    n_texts=("Number of texts to train from (0 uses all of them)", "option", "t", int),
    n_iter=("Number of training epochs (0 to autodetect)", "option", "n", int),
    pos_label=("Positive label for evaluation", "option", "pl", str),
)
def main(
    model='en_trf_bertbaseuncased_lg',
    output_dir='/models/BertClassifier',
    n_iter=0,
    n_texts=0,
    batch_size=8,
    learn_rate=2e-5,
    max_wpb=1000,
    use_test=False,
    pos_label=None,
):
    spacy.util.fix_random_seed(0)
    is_using_gpu = spacy.prefer_gpu()
    if is_using_gpu:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()

    nlp = spacy.load(model)
    print(nlp.pipe_names)
    print(f"Loaded model '{model}'")
    textcat = nlp.create_pipe(
        "trf_textcat",
        config={"architecture": "softmax_last_hidden", "words_per_batch": max_wpb},
    )

    # Load the Atticus dataset
    print("Loading Atticus Project training data...")
    train_data, data_headers = load_atticus_data()
    (train_texts, train_cats), (eval_texts, eval_cats) = create_training_set(train_data=train_data, limit=n_texts)
    train_cats = [i['cats'] for i in train_cats]
    eval_cats = [i['cats'] for i in eval_cats]

    # Add labels to the text classifier
    print("Add labels to text classifier")
    for label in data_headers:
        print(label)
        textcat.add_label(label)

    print("Labels:", textcat.labels)
    print("Positive label for evaluation:", pos_label)
    nlp.add_pipe(textcat, last=True)
    print(f"Using {len(train_texts)} training docs, {len(eval_texts)} evaluation")
    split_training_by_sentence = False
    if split_training_by_sentence:
        # If we're using a model that averages over sentence predictions (we are),
        # there are some advantages to just labelling each sentence as an example.
        # It means we can mix the sentences into different batches, so we can make
        # more frequent updates. It also changes the loss somewhat, in a way that's
        # not obviously better -- but it does seem to work well.
        train_texts, train_cats = make_sentence_examples(nlp, train_texts, train_cats)
        print(f"Extracted {len(train_texts)} training sents")
    # total_words = sum(len(text.split()) for text in train_texts)
    train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
    # Initialize the TextCategorizer, and create an optimizer.
    optimizer = nlp.resume_training()
    optimizer.alpha = 0.001
    optimizer.trf_weight_decay = 0.005
    optimizer.L2 = 0.0
    learn_rates = cyclic_triangular_rate(
        learn_rate / 3, learn_rate * 3, 2 * len(train_data) // batch_size
    )
    print("Training the model...")
    print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))

    pbar = tqdm.tqdm(total=100, leave=False)
    results = []
    epoch = 0
    step = 0
    eval_every = 100
    patience = 3

    while True:
        # Train and evaluate
        losses = Counter()
        random.shuffle(train_data)
        batches = minibatch(train_data, size=batch_size)
        for batch in batches:
            optimizer.trf_lr = next(learn_rates)
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=0.1, losses=losses)
            pbar.update(1)
            if step and (step % eval_every) == 0:
                pbar.close()
                with nlp.use_params(optimizer.averages):
                    scores = evaluate(nlp, eval_texts, eval_cats, pos_label)
                results.append((scores["textcat_f"], step, epoch))
                print(
                    "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}".format(
                        losses["trf_textcat"],
                        scores["textcat_p"],
                        scores["textcat_r"],
                        scores["textcat_f"],
                    )
                )
                pbar = tqdm.tqdm(total=eval_every, leave=False)
            step += 1
        epoch += 1

        # Stop if n_iter is non-zero and we've reached the user's hard-coded epoch limit
        if 0 < n_iter <= epoch:
            break

        # Stop if there has been no improvement in the last `patience` checkpoints
        if results:
            best_score, best_step, best_epoch = max(results)
            if ((step - best_step) // eval_every) >= patience:
                break

    msg = wasabi.Printer()
    table_widths = [2, 4, 6]
    msg.info("Best scoring checkpoints")
    msg.row(["Epoch", "Step", "Score"], widths=table_widths)
    msg.row(["-" * width for width in table_widths])
    for score, step, epoch in sorted(results, reverse=True)[:10]:
        msg.row([epoch, step, "%.2f" % (score * 100)], widths=table_widths)

    # Test the trained model
    test_text = eval_texts[0]
    doc = nlp(test_text)
    print(test_text, doc.cats)

    if output_dir is not None:
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)
        # Test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        doc2 = nlp2(test_text)
        print(test_text, doc2.cats)


def make_sentence_examples(nlp, texts, labels):
    """Treat each sentence of the document as an instance, using the doc labels."""
    sents = []
    sent_cats = []
    for text, cats in zip(texts, labels):
        doc = nlp.make_doc(text)
        doc = nlp.get_pipe("sentencizer")(doc)
        for sent in doc.sents:
            sents.append(sent.text)
            sent_cats.append(cats)
    return sents, sent_cats


def evaluate(nlp, texts, cats, pos_label):
    """Micro-averaged precision/recall/F across all labels, thresholded at 0.5."""
    tp = 0.0  # True positives
    fp = 0.0  # False positives
    fn = 0.0  # False negatives
    tn = 0.0  # True negatives
    total_words = sum(len(text.split()) for text in texts)
    with tqdm.tqdm(total=total_words, leave=False) as pbar:
        for i, doc in enumerate(nlp.pipe(texts, batch_size=8)):
            gold = cats[i]
            for label, score in doc.cats.items():
                if label not in gold:
                    continue
                if score >= 0.5 and gold[label] >= 0.5:
                    tp += 1.0
                elif score >= 0.5 and gold[label] < 0.5:
                    fp += 1.0
                elif score < 0.5 and gold[label] < 0.5:
                    tn += 1.0
                elif score < 0.5 and gold[label] >= 0.5:
                    fn += 1.0
            pbar.update(len(doc.text.split()))
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    if (precision + recall) == 0:
        f_score = 0.0
    else:
        f_score = 2 * (precision * recall) / (precision + recall)
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}


if __name__ == "__main__":
    plac.call(main)
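
Since the script hands main() to plac, the annotations above double as a command-line interface; a typical invocation might look like this (the output path is just an illustration):

    python BertModelBuilder.py en_trf_bertbaseuncased_lg -o ./models/BertClassifier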

LICENSE

Lines changed: 1 addition & 1 deletion
MIT License

- Copyright (c) 2020 JSIV
+ Copyright (c) 2020 John Scrudato IVth

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

README.rst

Lines changed: 39 additions & 0 deletions
This is still a work in progress and is not meant for public use yet. I'm
leaving the repository public in case anyone stumbles across it and it saves
them some time in preparing their own Atticus classifiers. I am planning to
write a blog post with instructions once the performance is slightly better.

I am using LexNLP for its great sentence tokenization functionality. You
could probably get decent performance out of NLTK; spaCy's is ok. Currently,
I have been training the classifiers, then using LexNLP to clean and split
sentences and sections, and then running those chunks through spaCy and
checking the category labels; a sketch of that flow follows below.
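
As a rough sketch of that flow (the model path and contract file here are
placeholders, and LexNLP's get_sentence_list import path may vary by
version)::

    import spacy
    from lexnlp.nlp.en.segments.sentences import get_sentence_list

    # Load the classifier trained by BertModelBuilder.py
    nlp = spacy.load('./models/BertClassifier')
    contract_text = open('example_contract.txt').read()

    for sentence in get_sentence_list(contract_text):
        doc = nlp(sentence)
        # Take the highest-scoring clause category for each chunk
        label, score = max(doc.cats.items(), key=lambda kv: kv[1])
        print(f'{label} ({score:.2f}): {sentence[:60]}')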
To get the full (and excellent) Atticus dataset, go here:
https://www.atticusprojectai.org/

Using a BERT-based model, the beta release of the Atticus training set yields
an acceptable (but still not really production-ready) F-score of 0.735::

    LOSS   P      R      F
    1.093  0.739  0.472  0.576
    1.960  0.763  0.566  0.649
    0.290  0.756  0.661  0.706
    0.985  0.764  0.683  0.721
    1.616  0.770  0.681  0.723
    0.517  0.743  0.673  0.706
    1.044  0.754  0.697  0.724
    0.127  0.762  0.728  0.745
    0.542  0.748  0.722  0.735
    0.946  0.756  0.722  0.739
    0.219  0.751  0.720  0.735
    0.551  0.751  0.720  0.735

Training the BERT-based model takes a lot more computing power, and a
CUDA-compatible graphics card is strongly recommended. Using an Nvidia
1050 Ti, the training above took about three hours.

Older, Word2Vec-based models yield less promising results but are much, much
faster to train. spaCy's en_core_web_lg model yields an F-score of around 0.6.
Using Word2Vec embeddings trained on legal data (such as Law2Vec) yields a
slightly better F-score of around 0.635. Neither is as good as the BERT-based
approach, however.
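
If you want to try the Law2Vec variant, spaCy 2.x can build a base model
directly from pre-trained vectors with its init-model command; a sketch,
assuming a local copy of the Law2Vec vectors file (the file name is an
example)::

    python -m spacy init-model en ./models/law2vec_base --vectors-loc Law2Vec.200d.txt

A plain textcat pipe trained on top of a model like that is one way to get
that kind of Word2Vec baseline.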
