diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 05d9b90..8792ae5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -37,21 +37,21 @@ def test_load_snli(): train_data, classes = mz.datasets.snli.load_data('train', 'classification', return_classes=True) - num_samples = 550146 + num_samples = 549361 assert len(train_data) == num_samples x, y = train_data.unpack() assert len(x['text_left']) == num_samples assert len(x['text_right']) == num_samples assert y.shape == (num_samples, 1) - assert classes == ['entailment', 'contradiction', 'neutral', '-'] + assert classes == ['entailment', 'contradiction', 'neutral'] dev_data, classes = mz.datasets.snli.load_data('dev', 'classification', return_classes=True) - assert len(dev_data) == 10000 - assert classes == ['entailment', 'contradiction', 'neutral', '-'] + assert len(dev_data) == 9842 + assert classes == ['entailment', 'contradiction', 'neutral'] test_data, classes = mz.datasets.snli.load_data('test', 'classification', return_classes=True) - assert len(test_data) == 10000 - assert classes == ['entailment', 'contradiction', 'neutral', '-'] + assert len(test_data) == 9824 + assert classes == ['entailment', 'contradiction', 'neutral'] train_data = mz.datasets.snli.load_data('train', 'ranking') x, y = train_data.unpack()