Merge pull request #3247 from flairNLP/add-africa-pos-dataset
Add support for MasakhaPOS Dataset
alanakbik committed Aug 11, 2023
2 parents fd1ea0f + 2ddae63 commit 10a63dd
Showing 3 changed files with 167 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
@@ -173,6 +173,7 @@
    KEYPHRASE_INSPEC,
    KEYPHRASE_SEMEVAL2010,
    KEYPHRASE_SEMEVAL2017,
    MASAKHA_POS,
    NER_ARABIC_ANER,
    NER_ARABIC_AQMAR,
    NER_BASQUE,
@@ -447,6 +448,7 @@
"KEYPHRASE_INSPEC",
"KEYPHRASE_SEMEVAL2010",
"KEYPHRASE_SEMEVAL2017",
"MASAKHA_POS",
"NER_ARABIC_ANER",
"NER_ARABIC_AQMAR",
"NER_BASQUE",
101 changes: 101 additions & 0 deletions flair/datasets/sequence_labeling.py
@@ -4795,3 +4795,104 @@ def __init__(
            sample_missing_splits=False,
            name="nermud",
        )


class MASAKHA_POS(MultiCorpus):
    def __init__(
        self,
        languages: Union[str, List[str]] = "bam",
        version: str = "v1",
        base_path: Optional[Union[str, Path]] = None,
        in_memory: bool = True,
        **corpusargs,
    ) -> None:
        """Initialize the MasakhaPOS corpus available at https://github.com/masakhane-io/masakhane-pos.

        The dataset covers 20 African languages; this loader currently supports 18 of them. Pass a language
        code or a list of language codes to initialize the corpus with the languages you require. If you pass
        "all", all supported languages will be initialized.
        :param languages: Language code (e.g. "bam") or list of language codes, or "all".
        :param version: Version of the dataset. Currently, only "v1" is supported.
        :param base_path: Default is None, meaning that the corpus gets auto-downloaded and loaded. You can
            override this to point to a different folder, but typically this should not be necessary.
        :param in_memory: If True, keeps the dataset in memory, giving speedups in training.
        """
        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

        # if only one language is given
        if isinstance(languages, str):
            languages = [languages]

        # column format
        columns = {0: "text", 1: "pos"}

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        supported_versions = ["v1"]

        if version not in supported_versions:
            log.error(f"The specified version '{version}' is not in the list of supported versions!")
            log.error(f"Supported versions are '{supported_versions}'!")
            raise ValueError(f"Unsupported MasakhaPOS version '{version}'")

        data_folder = base_path / dataset_name / version

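        # supported language codes; each matches a data folder in the masakhane-pos repository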
        supported_languages = [
            "bam",
            "bbj",
            "ewe",
            "fon",
            "hau",
            "ibo",
            "kin",
            "lug",
            "mos",
            "pcm",
            "nya",
            "sna",
            "swa",
            "twi",
            "wol",
            "xho",
            "yor",
            "zul",
        ]

        data_paths = {
            "v1": "https://raw.githubusercontent.com/masakhane-io/masakhane-pos/main/data",
        }

        # use all languages if explicitly set to "all"
        if languages == ["all"]:
            languages = supported_languages

        corpora: List[Corpus] = []
        for language in languages:
            if language not in supported_languages:
                log.error(f"Language '{language}' is not in the list of supported languages!")
                log.error(f"Supported languages are '{supported_languages}'!")
                log.error('Instantiate this corpus for instance like so: corpus = MASAKHA_POS(languages="bam")')
                raise ValueError(f"Unsupported MasakhaPOS language '{language}'")

            language_folder = data_folder / language

            # download data if necessary
            data_path = f"{data_paths[version]}/{language}"
            cached_path(f"{data_path}/dev.txt", language_folder)
            cached_path(f"{data_path}/test.txt", language_folder)
            cached_path(f"{data_path}/train.txt", language_folder)

            # initialize ColumnCorpus and add it to the list
            log.info(f"Reading data for language {language}@{version}")
            corp = ColumnCorpus(
                data_folder=language_folder,
                column_format=columns,
                encoding="utf-8",
                in_memory=in_memory,
                name=language,
                **corpusargs,
            )
            corpora.append(corp)
        super().__init__(
            corpora,
            name="africa-pos-" + "-".join(languages),
        )
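
For orientation, a minimal usage sketch (not part of this commit), assuming the class is exported from flair.datasets as wired up in the __init__.py change above:

from flair.datasets import MASAKHA_POS

# load a single language ...
corpus = MASAKHA_POS(languages="bam")
print(corpus)

# ... or several at once; the result is a MultiCorpus of per-language ColumnCorpus objects
multi_corpus = MASAKHA_POS(languages=["hau", "yor"], version="v1")
print(multi_corpus)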
64 changes: 64 additions & 0 deletions tests/test_datasets.py
@@ -808,6 +808,70 @@ def test_german_ler_corpus(tasks_base_path):
    assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"


def test_masakha_pos_corpus(tasks_base_path):
    # This test covers the complete MasakhaPOS dataset.
    supported_versions = ["v1"]

    supported_languages = {
        "v1": [
            "bam",
            "bbj",
            "ewe",
            "fon",
            "hau",
            "ibo",
            "kin",
            "lug",
            "mos",
            "pcm",
            "nya",
            "sna",
            "swa",
            "twi",
            "wol",
            "xho",
            "yor",
            "zul",
        ],
    }

    africa_pos_stats = {
        "v1": {
            "bam": {"train": 775, "dev": 154, "test": 619},
            "bbj": {"train": 750, "dev": 149, "test": 599},
            "ewe": {"train": 728, "dev": 145, "test": 582},
            "fon": {"train": 810, "dev": 161, "test": 646},
            "hau": {"train": 753, "dev": 150, "test": 601},
            "ibo": {"train": 803, "dev": 160, "test": 642},
            "kin": {"train": 757, "dev": 151, "test": 604},
            "lug": {"train": 733, "dev": 146, "test": 586},
            "mos": {"train": 757, "dev": 151, "test": 604},
            "pcm": {"train": 752, "dev": 150, "test": 600},
            "nya": {"train": 728, "dev": 145, "test": 582},
            "sna": {"train": 747, "dev": 149, "test": 596},
            "swa": {"train": 693, "dev": 138, "test": 553},
            "twi": {"train": 785, "dev": 157, "test": 628},
            "wol": {"train": 782, "dev": 156, "test": 625},
            "xho": {"train": 752, "dev": 150, "test": 601},
            "yor": {"train": 893, "dev": 178, "test": 713},
            "zul": {"train": 753, "dev": 150, "test": 601},
        },
    }

    def check_number_sentences(reference: int, actual: int, split_name: str, language: str, version: str):
        assert actual == reference, f"Mismatch in number of sentences for {language}@{version}/{split_name}"

    for version in supported_versions:
        for language in supported_languages[version]:
            corpus = flair.datasets.MASAKHA_POS(languages=language, version=version)

            gold_stats = africa_pos_stats[version][language]

            check_number_sentences(gold_stats["train"], len(corpus.train), "train", language, version)
            check_number_sentences(gold_stats["dev"], len(corpus.dev), "dev", language, version)
            check_number_sentences(gold_stats["test"], len(corpus.test), "test", language, version)


def test_multi_file_jsonl_corpus_should_use_label_type(tasks_base_path):
    corpus = MultiFileJsonlCorpus(
        train_files=[tasks_base_path / "jsonl/train.jsonl"],
