1
+ #!/usr/bin/env python
2
+ import plac
3
+ import random
4
+ from pathlib import Path
5
+ from collections import Counter
6
+ import spacy
7
+ import torch
8
+ from spacy .util import minibatch
9
+ import tqdm
10
+ import wasabi
11
+ from AtticusUtils .Loaders import load_atticus_data , create_training_set
12
+ from spacy_transformers .util import cyclic_triangular_rate
13
+
14
# Based on the spaCy sample code at https://github.com/explosion/spacy-transformers/blob/v0.6.x/examples/train_textcat.py
15
+
16
@plac.annotations(
    model=("Model name", "positional", None, str),
    output_dir=("Optional output directory (you'd be stupid not to save this, takes forever to run)", "option", "o", Path),
    use_test=("Whether to use the actual test set", "flag", "E"),
    batch_size=("Number of docs per batch", "option", "bs", int),
    learn_rate=("Learning rate", "option", "lr", float),
    max_wpb=("Max words per sub-batch", "option", "wpb", int),
    n_texts=("Number of texts to train from (0 uses al of them)", "option", "t", int),
    n_iter=("Number of training epochs (0 to autodetect)", "option", "n", int),
    pos_label=("Positive label for evaluation", "option", "pl", str),
)
def main(
    model='en_trf_bertbaseuncased_lg',
    output_dir='/models/BertClassifier',
    n_iter=0,
    n_texts=0,
    batch_size=8,
    learn_rate=2e-5,
    max_wpb=1000,
    use_test=False,
    pos_label=None,
):
    """Train a transformer text classifier on the Atticus data and save it.

    Loads the spaCy model `model`, adds a "trf_textcat" pipe with one label
    per Atticus data header, trains with a cyclic learning rate, evaluates
    every 100 steps, and stops after `n_iter` epochs (when `n_iter` > 0) or
    once the best F-score is at least 3 evaluation checkpoints old. The
    pipeline is written to `output_dir` (unless it is None) and reloaded
    once as a smoke test.
    """
    # NOTE(review): `use_test` is accepted from the CLI but is not consulted
    # anywhere in this function.
    spacy.util.fix_random_seed(0)
    is_using_gpu = spacy.prefer_gpu()
    if is_using_gpu:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    if output_dir is not None:
        output_dir = Path(output_dir)
        # The default path is nested, so create missing parents as well;
        # exist_ok makes a separate existence check unnecessary.
        output_dir.mkdir(parents=True, exist_ok=True)

    nlp = spacy.load(model)
    print(nlp.pipe_names)
    print(f"Loaded model '{model} '")
    textcat = nlp.create_pipe(
        "trf_textcat",
        config={"architecture": "softmax_last_hidden", "words_per_batch": max_wpb},
    )

    # load the Atticus dataset
    print("Loading Atticus Project training data...")
    train_data, data_headers = load_atticus_data()
    (train_texts, train_cats), (eval_texts, eval_cats) = create_training_set(train_data=train_data, limit=n_texts)
    # Unwrap the {"cats": {...}} annotations into bare label mappings.
    train_cats = [i['cats'] for i in train_cats]
    eval_cats = [i['cats'] for i in eval_cats]

    # add label to text classifier
    print("Add labels to text classifier")
    for label in data_headers:
        print(label)
        textcat.add_label(label)

    print("Labels:", textcat.labels)
    print("Positive label for evaluation:", pos_label)
    nlp.add_pipe(textcat, last=True)
    print(f"Using {len(train_texts)} training docs, {len(eval_texts)} evaluation")
    split_training_by_sentence = False
    if split_training_by_sentence:
        # If we're using a model that averages over sentence predictions (we are),
        # there are some advantages to just labelling each sentence as an example.
        # It means we can mix the sentences into different batches, so we can make
        # more frequent updates. It also changes the loss somewhat, in a way that's
        # not obviously better -- but it does seem to work well.
        train_texts, train_cats = make_sentence_examples(nlp, train_texts, train_cats)
        print(f"Extracted {len(train_texts)} training sents")
    # total_words = sum(len(text.split()) for text in train_texts)
    train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
    # Initialize the TextCategorizer, and create an optimizer.
    optimizer = nlp.resume_training()
    optimizer.alpha = 0.001
    optimizer.trf_weight_decay = 0.005
    optimizer.L2 = 0.0
    # Keep the cycle period at least 1 so that a training set smaller than
    # half a batch cannot floor the period to zero.
    learn_rates = cyclic_triangular_rate(
        learn_rate / 3, learn_rate * 3, max(1, 2 * len(train_data) // batch_size)
    )
    print("Training the model...")
    print("{:^5}\t {:^5}\t {:^5}\t {:^5}".format("LOSS", "P", "R", "F"))

    results = []
    epoch = 0
    step = 0
    eval_every = 100
    patience = 3
    pbar = tqdm.tqdm(total=eval_every, leave=False)

    while True:
        # Train and evaluate
        losses = Counter()
        random.shuffle(train_data)
        batches = minibatch(train_data, size=batch_size)
        for batch in batches:
            optimizer.trf_lr = next(learn_rates)
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=0.1, losses=losses)
            pbar.update(1)
            if step and (step % eval_every) == 0:
                pbar.close()
                # Evaluate with the optimizer's averaged parameters.
                with nlp.use_params(optimizer.averages):
                    scores = evaluate(nlp, eval_texts, eval_cats, pos_label)
                results.append((scores["textcat_f"], step, epoch))
                print(
                    "{0:.3f}\t {1:.3f}\t {2:.3f}\t {3:.3f}".format(
                        losses["trf_textcat"],
                        scores["textcat_p"],
                        scores["textcat_r"],
                        scores["textcat_f"],
                    )
                )
                pbar = tqdm.tqdm(total=eval_every, leave=False)
            step += 1
        epoch += 1

        # Stop once an explicit epoch limit (n_iter > 0) has been reached.
        if 0 < n_iter <= epoch:
            break

        # Stop if the best score is at least `patience` checkpoints old.
        if results:
            best_score, best_step, best_epoch = max(results)
            if ((step - best_step) // eval_every) >= patience:
                break

    # The bar opened after the last evaluation is still live here.
    pbar.close()

    msg = wasabi.Printer()
    table_widths = [2, 4, 6]
    msg.info("Best scoring checkpoints")
    msg.row(["Epoch", "Step", "Score"], widths=table_widths)
    msg.row(["-" * width for width in table_widths])
    # Distinct names so the training counters above are not rebound.
    for score, ckpt_step, ckpt_epoch in sorted(results, reverse=True)[:10]:
        msg.row([ckpt_epoch, ckpt_step, "%.2f" % (score * 100)], widths=table_widths)

    # Test the trained model
    test_text = eval_texts[0]
    doc = nlp(test_text)
    print(test_text, doc.cats)

    if output_dir is not None:
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)
        # test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        doc2 = nlp2(test_text)
        print(test_text, doc2.cats)
159
+
160
+
161
def make_sentence_examples(nlp, texts, labels):
    """Expand documents into per-sentence examples that reuse the doc labels.

    Each text is tokenized with ``nlp.make_doc`` and passed through the
    pipeline's "sentencizer" component; every resulting sentence becomes one
    example paired with the label object of the document it came from.

    Returns a ``(sentences, sentence_labels)`` pair of equal-length lists.
    """
    sentences = []
    sentence_labels = []
    for raw_text, doc_labels in zip(texts, labels):
        tokenized = nlp.make_doc(raw_text)
        segmented = nlp.get_pipe("sentencizer")(tokenized)
        pieces = [span.text for span in segmented.sents]
        sentences.extend(pieces)
        # Every sentence shares the label object of its parent document.
        sentence_labels.extend([doc_labels] * len(pieces))
    return sentences, sentence_labels
172
+
173
def evaluate(nlp, texts, cats, pos_label, batch_size=8, threshold=0.5):
    """Score the text classifier on `texts` with pooled precision/recall/F.

    Every (document, label) pair whose label appears in both the predicted
    ``doc.cats`` and the gold mapping is treated as one binary decision; the
    counts are pooled over all documents and labels before the metrics are
    computed.

    Args:
        nlp: Pipeline whose ``pipe(texts, batch_size=...)`` yields docs
            exposing ``.cats`` (label -> score) and ``.text``.
        texts: Evaluation texts.
        cats: Gold annotations, one ``{label: value}`` mapping per text, in
            the same order as ``texts``.
        pos_label: When not None, only this label is scored. When None,
            every label present in the gold mapping is scored.
        batch_size: Number of texts per ``nlp.pipe`` batch.
        threshold: Scores and gold values at or above this cut-off count as
            positive.

    Returns:
        dict with keys "textcat_p", "textcat_r" and "textcat_f".
    """
    tp = 0.0  # True positives
    fp = 0.0  # False positives
    fn = 0.0  # False negatives
    total_words = sum(len(text.split()) for text in texts)
    with tqdm.tqdm(total=total_words, leave=False) as pbar:
        for i, doc in enumerate(nlp.pipe(texts, batch_size=batch_size)):
            gold = cats[i]
            for label, score in doc.cats.items():
                if label not in gold:
                    continue
                # Honour the requested positive label instead of ignoring it.
                if pos_label is not None and label != pos_label:
                    continue
                predicted = score >= threshold
                expected = gold[label] >= threshold
                if predicted and expected:
                    tp += 1.0
                elif predicted:
                    fp += 1.0
                elif expected:
                    fn += 1.0
                # True negatives do not enter precision or recall.
            # Progress is tracked in whitespace-separated words.
            pbar.update(len(doc.text.split()))
    # The small epsilon avoids division by zero when nothing was counted.
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    if (precision + recall) == 0:
        f_score = 0.0
    else:
        f_score = 2 * (precision * recall) / (precision + recall)
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}
201
+
202
+
203
if __name__ == "__main__":
    # Hand control to plac, which calls main() with the command-line arguments.
    plac.call(main)
0 commit comments