Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MasakhaPOS Dataset #3247

Merged
merged 10 commits into from
Aug 11, 2023
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
KEYPHRASE_INSPEC,
KEYPHRASE_SEMEVAL2010,
KEYPHRASE_SEMEVAL2017,
MASAKHA_POS,
NER_ARABIC_ANER,
NER_ARABIC_AQMAR,
NER_BASQUE,
Expand Down Expand Up @@ -447,6 +448,7 @@
"KEYPHRASE_INSPEC",
"KEYPHRASE_SEMEVAL2010",
"KEYPHRASE_SEMEVAL2017",
"MASAKHA_POS",
"NER_ARABIC_ANER",
"NER_ARABIC_AQMAR",
"NER_BASQUE",
Expand Down
101 changes: 101 additions & 0 deletions flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4795,3 +4795,104 @@ def __init__(
sample_missing_splits=False,
name="nermud",
)


class MASAKHA_POS(MultiCorpus):
    def __init__(
        self,
        languages: Union[str, List[str]] = "bam",
        version: str = "v1",
        base_path: Optional[Union[str, Path]] = None,
        in_memory: bool = True,
        **corpusargs,
    ) -> None:
        """Initialize the MasakhaPOS corpus available on https://github.com/masakhane-io/masakhane-pos.

        The corpus provides POS-annotated data for African languages (this loader currently supports
        the 18 languages listed below). Pass a language code or a list of language codes to initialize
        the corpus with the languages you require. If you pass "all", all supported languages will be
        initialized.

        :param languages: Language code(s) to load, or "all" for every supported language.
        :param version: Specifies version of the dataset. Currently, only "v1" is supported.
        :param base_path: Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this
            to point to a different folder but typically this should not be necessary.
        :param in_memory: If True, keeps dataset in memory giving speedups in training.
        :param corpusargs: Additional keyword arguments forwarded to each underlying ColumnCorpus.
        :raises ValueError: If an unsupported version or language is requested.
        """
        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

        # if only one language is given
        if isinstance(languages, str):
            languages = [languages]

        # column format: token in column 0, POS tag in column 1
        columns = {0: "text", 1: "pos"}

        # this dataset name (cache folder name)
        dataset_name = self.__class__.__name__.lower()

        supported_versions = ["v1"]

        if version not in supported_versions:
            log.error(f"The specified version '{version}' is not in the list of supported versions!")
            log.error(f"Supported versions are '{supported_versions}'!")
            # ValueError (a subclass of Exception) carries a message instead of a bare raise
            raise ValueError(f"Unsupported MasakhaPOS version '{version}'. Supported versions: {supported_versions}")

        data_folder = base_path / dataset_name / version

        # languages available in the v1 release of the dataset
        supported_languages = [
            "bam",
            "bbj",
            "ewe",
            "fon",
            "hau",
            "ibo",
            "kin",
            "lug",
            "mos",
            "pcm",
            "nya",
            "sna",
            "swa",
            "twi",
            "wol",
            "xho",
            "yor",
            "zul",
        ]

        data_paths = {
            "v1": "https://raw.githubusercontent.com/masakhane-io/masakhane-pos/main/data",
        }

        # use all languages if explicitly set to "all"
        if languages == ["all"]:
            languages = supported_languages

        corpora: List[Corpus] = []
        for language in languages:
            if language not in supported_languages:
                log.error(f"Language '{language}' is not in list of supported languages!")
                log.error(f"Supported are '{supported_languages}'!")
                log.error("Instantiate this Corpus for instance like so 'corpus = MASAKHA_POS(languages='bam')'")
                raise ValueError(f"Unsupported language '{language}'. Supported languages: {supported_languages}")

            language_folder = data_folder / language

            # download data if necessary
            data_path = f"{data_paths[version]}/{language}"
            for split_file in ("dev.txt", "test.txt", "train.txt"):
                cached_path(f"{data_path}/{split_file}", language_folder)

            # initialize ColumnCorpus for this language and add it to the list
            log.info(f"Reading data for language {language}@{version}")
            corp = ColumnCorpus(
                data_folder=language_folder,
                column_format=columns,
                encoding="utf-8",
                in_memory=in_memory,
                name=language,
                **corpusargs,
            )
            corpora.append(corp)
        super().__init__(
            corpora,
            name="africa-pos-" + "-".join(languages),
        )
64 changes: 64 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,70 @@ def test_german_ler_corpus(tasks_base_path):
assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"


def test_masakha_pos_corpus(tasks_base_path):
    # This test covers the complete MasakhaPOS dataset: every supported
    # language of every supported version is downloaded and its split sizes
    # are checked against the published sentence counts.
    supported_versions = ["v1"]

    supported_languages = {
        "v1": [
            "bam",
            "bbj",
            "ewe",
            "fon",
            "hau",
            "ibo",
            "kin",
            "lug",
            "mos",
            "pcm",
            "nya",
            "sna",
            "swa",
            "twi",
            "wol",
            "xho",
            "yor",
            "zul",
        ],
    }

    # gold sentence counts per version/language/split
    africa_pos_stats = {
        "v1": {
            "bam": {"train": 775, "dev": 154, "test": 619},
            "bbj": {"train": 750, "dev": 149, "test": 599},
            "ewe": {"train": 728, "dev": 145, "test": 582},
            "fon": {"train": 810, "dev": 161, "test": 646},
            "hau": {"train": 753, "dev": 150, "test": 601},
            "ibo": {"train": 803, "dev": 160, "test": 642},
            "kin": {"train": 757, "dev": 151, "test": 604},
            "lug": {"train": 733, "dev": 146, "test": 586},
            "mos": {"train": 757, "dev": 151, "test": 604},
            "pcm": {"train": 752, "dev": 150, "test": 600},
            "nya": {"train": 728, "dev": 145, "test": 582},
            "sna": {"train": 747, "dev": 149, "test": 596},
            "swa": {"train": 693, "dev": 138, "test": 553},
            "twi": {"train": 785, "dev": 157, "test": 628},
            "wol": {"train": 782, "dev": 156, "test": 625},
            "xho": {"train": 752, "dev": 150, "test": 601},
            "yor": {"train": 893, "dev": 178, "test": 713},
            "zul": {"train": 753, "dev": 150, "test": 601},
        },
    }

    def check_number_sentences(reference: int, actual: int, split_name: str, language: str, version: str):
        # report both values so a mismatch is diagnosable from the failure message
        assert actual == reference, (
            f"Mismatch in number of sentences for {language}@{version}/{split_name}: "
            f"expected {reference}, got {actual}"
        )

    for version in supported_versions:
        for language in supported_languages[version]:
            corpus = flair.datasets.MASAKHA_POS(languages=language, version=version)

            gold_stats = africa_pos_stats[version][language]

            # NOTE: arguments are passed (reference, actual) to match the helper's
            # signature — the original call sites had them swapped
            check_number_sentences(gold_stats["train"], len(corpus.train), "train", language, version)
            check_number_sentences(gold_stats["dev"], len(corpus.dev), "dev", language, version)
            check_number_sentences(gold_stats["test"], len(corpus.test), "test", language, version)


def test_multi_file_jsonl_corpus_should_use_label_type(tasks_base_path):
corpus = MultiFileJsonlCorpus(
train_files=[tasks_base_path / "jsonl/train.jsonl"],
Expand Down