diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py
index e549810d1..8100e4821 100644
--- a/flair/datasets/__init__.py
+++ b/flair/datasets/__init__.py
@@ -173,6 +173,7 @@
     KEYPHRASE_INSPEC,
     KEYPHRASE_SEMEVAL2010,
     KEYPHRASE_SEMEVAL2017,
+    MASAKHA_POS,
     NER_ARABIC_ANER,
     NER_ARABIC_AQMAR,
     NER_BASQUE,
@@ -447,6 +448,7 @@
     "KEYPHRASE_INSPEC",
     "KEYPHRASE_SEMEVAL2010",
     "KEYPHRASE_SEMEVAL2017",
+    "MASAKHA_POS",
     "NER_ARABIC_ANER",
     "NER_ARABIC_AQMAR",
     "NER_BASQUE",
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index d214873b5..0a5bf1b58 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -4795,3 +4795,104 @@ def __init__(
             sample_missing_splits=False,
             name="nermud",
         )
+
+
+class MASAKHA_POS(MultiCorpus):
+    def __init__(
+        self,
+        languages: Union[str, List[str]] = "bam",
+        version: str = "v1",
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        **corpusargs,
+    ) -> None:
+        """Initialize the MasakhaPOS corpus available on https://github.com/masakhane-io/masakhane-pos.
+
+        It consists of 20 African languages (18 currently supported here). Pass a language code or a list of language
+        codes to initialize the corpus with the languages you require. If you pass "all", all of them are loaded.
+        :param version: Specifies the version of the dataset. Currently, only "v1" is supported.
+        :param base_path: Default is None, meaning that the corpus gets auto-downloaded and loaded. You can override
+            this to point to a different folder, but typically this should not be necessary.
+        :param in_memory: If True, keeps the dataset in memory, giving speedups in training.
+        """
+        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
+
+        # if only one language is given
+        if isinstance(languages, str):
+            languages = [languages]
+
+        # column format
+        columns = {0: "text", 1: "pos"}
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        supported_versions = ["v1"]
+
+        if version not in supported_versions:
+            log.error(f"The specified version '{version}' is not in the list of supported versions!")
+            log.error(f"Supported versions are '{supported_versions}'!")
+            raise Exception
+
+        data_folder = base_path / dataset_name / version
+
+        supported_languages = [
+            "bam",
+            "bbj",
+            "ewe",
+            "fon",
+            "hau",
+            "ibo",
+            "kin",
+            "lug",
+            "mos",
+            "pcm",
+            "nya",
+            "sna",
+            "swa",
+            "twi",
+            "wol",
+            "xho",
+            "yor",
+            "zul",
+        ]
+
+        data_paths = {
+            "v1": "https://raw.githubusercontent.com/masakhane-io/masakhane-pos/main/data",
+        }
+
+        # use all languages if explicitly set to "all"
+        if languages == ["all"]:
+            languages = supported_languages
+
+        corpora: List[Corpus] = []
+        for language in languages:
+            if language not in supported_languages:
+                log.error(f"Language '{language}' is not in the list of supported languages!")
+                log.error(f"Supported are '{supported_languages}'!")
+                log.error("Instantiate this Corpus for instance like so 'corpus = MASAKHA_POS(languages='bam')'")
+                raise Exception
+
+            language_folder = data_folder / language
+
+            # download data if necessary
+            data_path = f"{data_paths[version]}/{language}"
+            cached_path(f"{data_path}/dev.txt", language_folder)
+            cached_path(f"{data_path}/test.txt", language_folder)
+            cached_path(f"{data_path}/train.txt", language_folder)
+
+            # initialize ColumnCorpus and add it to the list
+            log.info(f"Reading data for language {language}@{version}")
+            corp = ColumnCorpus(
+                data_folder=language_folder,
+                column_format=columns,
+                encoding="utf-8",
+                in_memory=in_memory,
+                name=language,
+                **corpusargs,
+            )
+            corpora.append(corp)
+        super().__init__(
+            corpora,
+            name="africa-pos-" + "-".join(languages),
+        )
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index d642be251..56d524d04 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -808,6 +808,70 @@ def test_german_ler_corpus(tasks_base_path):
     assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"
 
 
+def test_masakha_pos_corpus(tasks_base_path):
+    # This test covers the complete MasakhaPOS dataset.
+    supported_versions = ["v1"]
+
+    supported_languages = {
+        "v1": [
+            "bam",
+            "bbj",
+            "ewe",
+            "fon",
+            "hau",
+            "ibo",
+            "kin",
+            "lug",
+            "mos",
+            "pcm",
+            "nya",
+            "sna",
+            "swa",
+            "twi",
+            "wol",
+            "xho",
+            "yor",
+            "zul",
+        ],
+    }
+
+    africa_pos_stats = {
+        "v1": {
+            "bam": {"train": 775, "dev": 154, "test": 619},
+            "bbj": {"train": 750, "dev": 149, "test": 599},
+            "ewe": {"train": 728, "dev": 145, "test": 582},
+            "fon": {"train": 810, "dev": 161, "test": 646},
+            "hau": {"train": 753, "dev": 150, "test": 601},
+            "ibo": {"train": 803, "dev": 160, "test": 642},
+            "kin": {"train": 757, "dev": 151, "test": 604},
+            "lug": {"train": 733, "dev": 146, "test": 586},
+            "mos": {"train": 757, "dev": 151, "test": 604},
+            "pcm": {"train": 752, "dev": 150, "test": 600},
+            "nya": {"train": 728, "dev": 145, "test": 582},
+            "sna": {"train": 747, "dev": 149, "test": 596},
+            "swa": {"train": 693, "dev": 138, "test": 553},
+            "twi": {"train": 785, "dev": 157, "test": 628},
+            "wol": {"train": 782, "dev": 156, "test": 625},
+            "xho": {"train": 752, "dev": 150, "test": 601},
+            "yor": {"train": 893, "dev": 178, "test": 713},
+            "zul": {"train": 753, "dev": 150, "test": 601},
+        },
+    }
+
+    def check_number_sentences(reference: int, actual: int, split_name: str, language: str, version: str):
+        assert actual == reference, f"Mismatch in number of sentences for {language}@{version}/{split_name}"
+
+    for version in supported_versions:
+        for language in supported_languages[version]:
+            corpus = flair.datasets.MASAKHA_POS(languages=language, version=version)
+
+            gold_stats = africa_pos_stats[version][language]
+
+            check_number_sentences(gold_stats["train"], len(corpus.train), "train", language, version)
+            check_number_sentences(gold_stats["dev"], len(corpus.dev), "dev", language, version)
+            check_number_sentences(gold_stats["test"], len(corpus.test), "test", language, version)
+
+
 def test_multi_file_jsonl_corpus_should_use_label_type(tasks_base_path):
     corpus = MultiFileJsonlCorpus(
         train_files=[tasks_base_path / "jsonl/train.jsonl"],
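
Usage sketch (illustrative, not part of the patch): once the diff is applied, the new class is exposed as flair.datasets.MASAKHA_POS and behaves like any other MultiCorpus, as the test above already exercises. The snippet below assumes only what the diff shows (the languages/version parameters, the train/dev/test splits, and the "pos" column format); make_label_dictionary is the standard flair Corpus API and is included here only for orientation.

    import flair.datasets

    # load two of the 18 supported languages; languages="all" loads every one
    corpus = flair.datasets.MASAKHA_POS(languages=["bam", "yor"], version="v1")

    # MASAKHA_POS is a MultiCorpus, so the usual split accessors are available
    print(len(corpus.train), len(corpus.dev), len(corpus.test))

    # build the tag dictionary for the "pos" column defined in the column format
    pos_dictionary = corpus.make_label_dictionary(label_type="pos")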