diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 3a6ab650..5d4f4725 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -396,6 +396,10 @@ def from_dataset(self, *datasets, field_name:Union[str,List[str]], no_create_ent
         def construct_vocab(ins, no_create_entry=False):
             for fn in field_name:
                 field = ins[fn]
+                # If the field is None or empty, just skip it.
+                if field is None or (hasattr(field, "__len__") and len(field) == 0):
+                    logger.warning(f"instance: {ins} has an empty or None field; skipping it.")
+                    continue
                 if isinstance(field, str) or not _is_iterable(field):
                     self.add_word(field, no_create_entry=no_create_entry)
                 else:
diff --git a/tests/core/test_vocabulary.py b/tests/core/test_vocabulary.py
new file mode 100644
index 00000000..7787840a
--- /dev/null
+++ b/tests/core/test_vocabulary.py
@@ -0,0 +1,21 @@
+import pytest
+from collections import Counter
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP import logger
+
+
+class TestVocabulary:
+
+    def test_from_dataset(self):
+        ds = DataSet({"x": [[1, 2], [3, 4]], "y": ["apple", ""]})
+        vocab = Vocabulary()
+        vocab.from_dataset(ds, field_name="y")
+        assert vocab.word_count == Counter({'apple': 1})
+
+    def test_from_dataset1(self):
+        ds = DataSet({"x": [[1, 2], [3, 4], [5]], "y": [1, None, 2]})
+        vocab = Vocabulary()
+        vocab.from_dataset(ds, field_name="y")
+        assert vocab.word_count == Counter({1: 1, 2: 1})
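
For context, a minimal usage sketch of the patched behaviour (assuming the fastNLP DataSet/Vocabulary API used above; the variable names and sample data are illustrative, not part of this change):

from collections import Counter

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary

# With this patch, instances whose field value is "" (or None) are skipped
# with a warning instead of being added to the vocabulary.
ds = DataSet({"y": ["apple", "", "banana"]})
vocab = Vocabulary()
vocab.from_dataset(ds, field_name="y")
assert vocab.word_count == Counter({"apple": 1, "banana": 1})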