Skip to content

Commit

Permalink
Merge pull request #3276 from flairNLP/GH3275-sample_missing_splits-i…
Browse files Browse the repository at this point in the history
…n-corpus-subclasses

Gh3275: sample_missing_splits in SST-2
  • Loading branch information
alanakbik authored Aug 10, 2023
2 parents a57020a + 5c21bef commit 9ea0894
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
16 changes: 14 additions & 2 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,15 +1238,27 @@ def __init__(

# sample test data from train if none is provided
if test is None and sample_missing_splits and train and sample_missing_splits != "only_dev":
test_portion = 0.1
train_length = _len_dataset(train)
test_size: int = round(train_length / 10)
test_size: int = round(train_length * test_portion)
test, train = randomly_split_into_two_datasets(train, test_size)
log.warning(
"No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data",
test_portion,
test_size,
)

# sample dev data from train if none is provided
if dev is None and sample_missing_splits and train and sample_missing_splits != "only_test":
dev_portion = 0.1
train_length = _len_dataset(train)
dev_size: int = round(train_length / 10)
dev_size: int = round(train_length * dev_portion)
dev, train = randomly_split_into_two_datasets(train, dev_size)
log.warning(
"No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data",
dev_portion,
dev_size,
)

# set train dev and test data
self._train: Optional[Dataset[T_co]] = train
Expand Down
5 changes: 4 additions & 1 deletion flair/datasets/document_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def __init__(
skip_header: bool = False,
encoding: str = "utf-8",
no_class_label=None,
sample_missing_splits: Union[bool, str] = True,
**fmtparams,
) -> None:
"""Instantiates a Corpus for text classification from CSV column formatted data.
Expand Down Expand Up @@ -396,7 +397,7 @@ def __init__(
else None
)

super().__init__(train, dev, test, name=name)
super().__init__(train, dev, test, name=name, sample_missing_splits=sample_missing_splits)


class CSVClassificationDataset(FlairDataset):
Expand Down Expand Up @@ -1488,6 +1489,7 @@ def __init__(
tokenizer: Tokenizer = SegtokTokenizer(),
in_memory: bool = False,
encoding: str = "utf-8",
sample_missing_splits: bool = True,
**datasetargs,
) -> None:
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
Expand Down Expand Up @@ -1525,6 +1527,7 @@ def __init__(
column_name_map={0: "text", 1: "label"},
train_file=train_file,
dev_file=data_folder / "dev.tsv",
sample_missing_splits=sample_missing_splits,
**kwargs,
)

Expand Down

0 comments on commit 9ea0894

Please sign in to comment.