cd data
sh 01_prepare.sh
cd ..
The script 01_prepare.sh downloads Schuster et al. (2021)'s preprocessing of the FEVER and MultiNLI datasets. After the files are downloaded and unzipped, you should see something like
+ wc -l fever/train.jsonl fever/dev.jsonl fever/test.jsonl
178059 fever/train.jsonl
11620 fever/dev.jsonl
11710 fever/test.jsonl
201389 total
+ python ../../group_stats.py fever/train.jsonl
S & 99303 (61.5) & 1267 (7.7) & 100570 (56.5)
R & 27575 (17.1) & 14275 (86.3) & 41850 (23.5)
N & 34633 (21.4) & 1006 (6.1) & 35639 (20.0)
for FEVER. For MultiNLI, you will see something like
+ wc -l mnli/train.jsonl mnli/dev.jsonl mnli/test.jsonl
392702 mnli/train.jsonl
9832 mnli/dev.jsonl
9832 mnli/test.jsonl
412366 total
+ python ../../group_stats.py mnli/train.jsonl
S & 118554 (36.7) & 12345 (17.7) & 130899 (33.3)
R & 88180 (27.3) & 42723 (61.2) & 130903 (33.3)
N & 116185 (36.0) & 14715 (21.1) & 130900 (33.3)
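The group_stats.py output appears to be LaTeX table rows: one row per label (S, R, N), each count followed by its percentage within its column. Judging from the group names used in the evaluation files, we assume the columns are examples without negation, examples with negation, and the total. The minimal sketch below reproduces that format from a dictionary of counts. It is not the actual group_stats.py, and the function name format_group_rows is hypothetical.

```python
from typing import Dict, List, Tuple

LABELS = ("S", "R", "N")


def format_group_rows(counts: Dict[Tuple[str, bool], int]) -> List[str]:
    """Format label x negation counts as LaTeX-style table rows.

    counts maps (label, has_negation) to a number of examples.
    Assumed column order: no negation, negation, total. Each count is
    followed by its percentage within that column.
    """
    col_totals = {
        neg: sum(counts[(label, neg)] for label in LABELS) for neg in (False, True)
    }
    grand_total = col_totals[False] + col_totals[True]
    rows = []
    for label in LABELS:
        no_neg = counts[(label, False)]
        neg = counts[(label, True)]
        total = no_neg + neg
        cells = [
            f"{no_neg} ({100 * no_neg / col_totals[False]:.1f})",
            f"{neg} ({100 * neg / col_totals[True]:.1f})",
            f"{total} ({100 * total / grand_total:.1f})",
        ]
        rows.append(" & ".join([label] + cells))
    return rows


# FEVER training-set counts taken from the output above.
fever_train = {
    ("S", False): 99303, ("S", True): 1267,
    ("R", False): 27575, ("R", True): 14275,
    ("N", False): 34633, ("N", True): 1006,
}
for row in format_group_rows(fever_train):
    print(row)
```

Under this reading, the FEVER table shows that 86.3% of training examples containing negation are labeled R (Refutes), which is the label/negation correlation that the group accuracies below are designed to expose.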
Below we go through the procedure for the FEVER experiments. The process for the MNLI experiments is the same, except that we only evaluate on MultiNLI's test set, since its development set and test set are identical in this preprocessing version.
To train the model:
cd fever+sgd
sh 01_train.sh
Note that the script contains SLURM directives that specify our GPU resource requirements for the sbatch command, but it can also be run with the sh command.
After training is finished, run
sh 02_predict.sh
This step
(a) makes predictions on the training set and saves the penultimate-layer embeddings of the training set for outlier removal later, and
(b) makes predictions on the development and test sets (if the test set is available).
You can view the prediction results in bert-base-uncased-128-out/eval.{train,dev,test}.txt, and the group accuracies in bert-base-uncased-128-out/eval.groups.{train,dev,test}.txt.
The results should look something like this (showing the dev set results):
tail -n 11 bert-base-uncased-128-out/eval.dev.txt
S R N
S 3788 118 58
R 705 2574 1044
N 530 494 2309
S R N
Prec: 75.4 80.8 67.7
Rec: 95.6 59.5 69.3
F1: 84.3 68.6 68.5
Acc: 74.6
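As a sanity check, the Prec, Rec, F1, and Acc lines can be recomputed from the confusion matrix. This sketch assumes that rows are gold labels and columns are predicted labels, which is consistent with the per-label group totals reported in eval.groups.dev.txt; the function name metrics_from_confusion is hypothetical.

```python
from typing import Dict, List

LABELS = ["S", "R", "N"]


def metrics_from_confusion(matrix: List[List[int]]) -> Dict[str, object]:
    """Per-class precision/recall/F1 and overall accuracy, all in percent.

    matrix[i][j] counts examples with gold label LABELS[i] that were
    predicted as LABELS[j] (assumed orientation).
    """
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(n))
    prec, rec, f1 = [], [], []
    for k in range(n):
        tp = matrix[k][k]
        predicted = sum(matrix[i][k] for i in range(n))  # column sum
        gold = sum(matrix[k])  # row sum
        p = tp / predicted if predicted else 0.0
        r = tp / gold if gold else 0.0
        prec.append(100 * p)
        rec.append(100 * r)
        f1.append(100 * 2 * p * r / (p + r) if p + r else 0.0)
    return {"prec": prec, "rec": rec, "f1": f1, "acc": 100 * correct / total}


# ERM (SGD) dev-set confusion matrix from eval.dev.txt above.
erm_dev = [
    [3788, 118, 58],
    [705, 2574, 1044],
    [530, 494, 2309],
]
m = metrics_from_confusion(erm_dev)
for name in ("prec", "rec", "f1"):
    print(name, " ".join(f"{v:.1f}" for v in m[name]))
print("acc", f"{m['acc']:.1f}")
```

The same function applies unchanged to the confusion matrices of the later runs.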
tail bert-base-uncased-128-out/eval.groups.dev.txt
Total 11620, correct 8671, wrong 2949
Avg acc: 74.6 (8671/11620)
Worst group acc: 14.0
(S, no neg): 96.0 (3777/3934)
(S, neg): 36.7 (11/30)
(R, no neg): 43.7 (1339/3067)
(R, neg): 98.3 (1235/1256)
(N, no neg): 70.9 (2296/3240)
(N, neg): 14.0 (13/93)
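The summary lines of eval.groups.dev.txt follow from the per-group counts: average accuracy pools all examples, while worst-group accuracy is the minimum over the six (label, negation) groups. A minimal sketch, with a hypothetical function name, using the ERM dev-set counts shown above:

```python
from typing import Dict, Tuple

Group = Tuple[str, str]  # (label, "no neg" or "neg")


def summarize_groups(counts: Dict[Group, Tuple[int, int]]) -> Tuple[float, float, Group]:
    """Return (average accuracy %, worst-group accuracy %, worst group).

    counts maps each group to (number correct, number of examples).
    """
    correct = sum(c for c, _ in counts.values())
    total = sum(t for _, t in counts.values())
    per_group = {g: 100 * c / t for g, (c, t) in counts.items()}
    worst = min(per_group, key=per_group.get)
    return 100 * correct / total, per_group[worst], worst


erm_dev_groups = {
    ("S", "no neg"): (3777, 3934),
    ("S", "neg"): (11, 30),
    ("R", "no neg"): (1339, 3067),
    ("R", "neg"): (1235, 1256),
    ("N", "no neg"): (2296, 3240),
    ("N", "neg"): (13, 93),
}
avg, worst_acc, worst_group = summarize_groups(erm_dev_groups)
print(f"Avg acc: {avg:.1f}  Worst group acc: {worst_acc:.1f} {worst_group}")
```

This makes the gap visible: the ERM model reaches 74.6% on average but only 14.0% on the small (N, neg) group, and 36.7% on (S, neg).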
Once all the predictions are finished, run
sh 03_calc_mahal.sh
(This script does not require a GPU.)
This calculates the Mahalanobis distances of the penultimate-layer embeddings and saves the distances in bert-base-uncased-128-out/train.mahal.npy.
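For reference, the Mahalanobis distance of an embedding x from a distribution with mean mu and covariance Sigma is sqrt((x - mu)^T Sigma^{-1} (x - mu)). The sketch below computes it for every row of an embedding matrix, using a single mean and sample covariance estimated from the same rows. Whether 03_calc_mahal.sh uses one global estimate or per-class estimates, and whether it stores distances or squared distances, is not stated here, so treat both choices as assumptions; the function name is hypothetical.

```python
import numpy as np


def mahalanobis_distances(embeddings: np.ndarray) -> np.ndarray:
    """Mahalanobis distance of each row from the mean of all rows.

    embeddings has shape (n_examples, hidden_dim). The covariance is the
    sample covariance (ddof=1). A pseudo-inverse is used because
    high-dimensional embeddings can give a singular or ill-conditioned
    covariance matrix.
    """
    mu = embeddings.mean(axis=0)
    centered = embeddings - mu
    cov = np.cov(embeddings, rowvar=False)
    precision = np.linalg.pinv(cov)
    # Row-wise quadratic form (x - mu)^T Sigma^+ (x - mu).
    sq = np.einsum("ij,jk,ik->i", centered, precision, centered)
    # Clip tiny negative values caused by floating-point error.
    return np.sqrt(np.clip(sq, 0.0, None))
```

Given an (n, d) array of saved embeddings, np.save("train.mahal.npy", mahalanobis_distances(embeddings)) would store one distance per training example; large distances mark candidate outliers.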
Run
sh 04_augment.sh
(This script does not require a GPU.)
This script upweights the training data in two different ways and saves each upweighted training set in ../data, in a subfolder with the same name as the corresponding experiment folder:
- JTT-m: Upweights incorrectly predicted training set examples, with outliers (identified by the Mahalanobis distance method) removed from the error set. The folder is fever_sgd_df5_up3, meaning that it uses sgd in the ERM training, filters the error set from the ERM training with degrees of freedom (df) 5, and upweights (up) the filtered error set 3 times. The second training uses adamw as its optimizer.
- JTT: Upweights incorrectly predicted training set examples. Its folder is fever_sgd_thres1.0_up3. Here, thres1.0 sets the threshold to 1.0 for filtering out incorrect examples by their predicted probabilities. The default threshold is 1.0, so no examples are filtered out for JTT. The examples are upweighted 3 times.
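The upweighting step can be pictured as duplicating error-set examples. This minimal sketch assumes that upweighting by up means each retained error example appears up times in total (the original plus up - 1 copies), and that examples removed from the error set stay in the training set once. The JTT-m outlier filter and the JTT probability threshold are both represented by a generic keep mask; the function and argument names are hypothetical and not taken from 04_augment.sh.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def upweight_errors(
    examples: Sequence[T],
    is_error: Sequence[bool],
    keep: Sequence[bool],
    up: int,
) -> List[T]:
    """Build an upweighted training set by duplicating error examples.

    is_error[i]: the ERM model mispredicted examples[i].
    keep[i]: False if examples[i] is filtered out of the error set
        (e.g. flagged as an outlier); pass all True for no filtering.
    Each kept error appears `up` times in total (assumed semantics);
    every other example appears once.
    """
    if not len(examples) == len(is_error) == len(keep):
        raise ValueError("examples, is_error and keep must have equal length")
    if up < 1:
        raise ValueError("up must be >= 1")
    augmented: List[T] = []
    for ex, err, k in zip(examples, is_error, keep):
        copies = up if (err and k) else 1
        augmented.extend([ex] * copies)
    return augmented
```

For example, with examples ["a", "b", "c", "d"], errors on "b" and "c", "c" filtered out, and up=3, the result is ["a", "b", "b", "b", "c", "d"]. Duplicating examples is one way to realize upweighting; a per-example loss weight would be an equivalent alternative.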
We use fever_sgd_df5_up3+adamw (the JTT-m experiment) as an example. The procedure is the same for fever_sgd_thres1.0_up3+adamw (the JTT experiment).
cd ..
cd fever_sgd_df5_up3+adamw
sh 01_train.sh
Once the training is finished, run
sh 02_predict.sh
The prediction results will be saved in bert-base-uncased-128-out as eval.{dev,test}.txt and eval.groups.{dev,test}.txt. You should see something like this (using the dev set as an example):
==> bert-base-uncased-128-out/eval.dev.txt <==
S 3763 109 92
R 250 3748 325
N 257 359 2717
S R N
Prec: 88.1 88.9 86.7
Rec: 94.9 86.7 81.5
F1: 91.4 87.8 84.0
Acc: 88.0
==> bert-base-uncased-128-out/eval.groups.dev.txt <==
Total 11620, correct 10228, wrong 1392
Avg acc: 88.0 (10228/11620)
Worst group acc: 50.0
(S, no neg): 95.3 (3748/3934)
(S, neg): 50.0 (15/30)
(R, no neg): 83.3 (2554/3067)
(R, neg): 95.1 (1194/1256)
(N, no neg): 82.2 (2664/3240)
(N, neg): 57.0 (53/93)
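To compare the ERM and JTT-m runs group by group, the group lines of two eval.groups files can be parsed and subtracted. This sketch assumes the line format shown above, "(LABEL, no neg|neg): ACC (CORRECT/TOTAL)", and recomputes each accuracy from the counts rather than reading the rounded value; the function name is hypothetical.

```python
import re
from typing import Dict, Tuple

# Matches lines like "(S, no neg): 96.0 (3777/3934)".
GROUP_LINE = re.compile(r"^\((\w+), (no neg|neg)\): [\d.]+ \((\d+)/(\d+)\)$")


def parse_group_lines(text: str) -> Dict[Tuple[str, str], float]:
    """Map (label, negation) -> accuracy in % from eval.groups text."""
    accs = {}
    for line in text.splitlines():
        m = GROUP_LINE.match(line.strip())
        if m:
            label, neg, correct, total = m.groups()
            accs[(label, neg)] = 100 * int(correct) / int(total)
    return accs


erm = parse_group_lines("""Worst group acc: 14.0
(S, no neg): 96.0 (3777/3934)
(S, neg): 36.7 (11/30)
(R, no neg): 43.7 (1339/3067)
(R, neg): 98.3 (1235/1256)
(N, no neg): 70.9 (2296/3240)
(N, neg): 14.0 (13/93)""")
jttm = parse_group_lines("""Worst group acc: 50.0
(S, no neg): 95.3 (3748/3934)
(S, neg): 50.0 (15/30)
(R, no neg): 83.3 (2554/3067)
(R, neg): 95.1 (1194/1256)
(N, no neg): 82.2 (2664/3240)
(N, neg): 57.0 (53/93)""")
for group in erm:
    print(group, f"{jttm[group] - erm[group]:+.1f}")
```

In practice the two strings would be read from the eval.groups.dev.txt files of the two experiment folders.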
The ERM models are trained with a different optimizer from the first training in Step 1: AdamW instead of SGD. The training and prediction scripts are in fever+adamw and mnli+adamw.
To train, run
sh 01_train.sh
Run
sh 02_predict.sh
to obtain ERM model results for the test set (and dev set, if available).
We also release model checkpoints (as well as example outputs and preprocessed data) at: https://doi.org/10.5281/zenodo.7260028.