Merge branch 'master' into clearml_logger

alanakbik authored Aug 11, 2023
2 parents e4521a9 + 10a63dd commit f3410d7
Showing 30 changed files with 418 additions and 1,095 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ Flair ships with state-of-the-art models for a range of NLP tasks. For instance,
| Spanish | Conll-03 (4-class) | **90.54** | *90.3 [(Yu et al., 2020)](https://www.aclweb.org/anthology/2020.acl-main.577.pdf)* | [Flair Spanish 4-class NER demo](https://huggingface.co/flair/ner-spanish-large) |

Many Flair sequence tagging models (named entity recognition, part-of-speech tagging etc.) are also hosted
-on the [__🤗 HuggingFace model hub__](https://huggingface.co/models?library=flair&sort=downloads)! You can browse models, check detailed information on how they were trained, and even try each model out online!
+on the [__🤗 Hugging Face model hub__](https://huggingface.co/models?library=flair&sort=downloads)! You can browse models, check detailed information on how they were trained, and even try each model out online!
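
For illustration, loading one of these hub-hosted models works like any other Flair model load. A minimal sketch using the public SequenceTagger.load API; flair/ner-english is one of the hosted model ids:

    from flair.data import Sentence
    from flair.models import SequenceTagger

    # load a tagger from the Hugging Face model hub by its model id
    tagger = SequenceTagger.load("flair/ner-english")

    # annotate an example sentence and print the predicted entity spans
    sentence = Sentence("George Washington went to Washington.")
    tagger.predict(sentence)
    for entity in sentence.get_spans("ner"):
        print(entity)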


## Quick Start
80 changes: 59 additions & 21 deletions flair/data.py
@@ -3,10 +3,10 @@
import re
import typing
from abc import ABC, abstractmethod
-from collections import Counter, defaultdict, namedtuple
+from collections import Counter, defaultdict
from operator import itemgetter
from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union, cast
+from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast

import torch
from deprecated import deprecated
@@ -39,7 +39,11 @@ def _len_dataset(dataset: Optional[Dataset]) -> int:
return len(loader)


-BoundingBox = namedtuple("BoundingBox", ["left", "top", "right", "bottom"])
+class BoundingBox(NamedTuple):
+    left: str
+    top: int
+    right: int
+    bottom: int
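
Switching from collections.namedtuple to typing.NamedTuple keeps tuple behavior while adding per-field annotations. Note that left: str sits oddly next to its int siblings and looks like it should be int; annotations are not enforced at runtime, so this sketch assumes numeric coordinates:

    # minimal usage sketch with assumed pixel coordinates
    box = BoundingBox(left=10, top=20, right=110, bottom=220)
    print(box.left, box.bottom)  # attribute access, as with a classic namedtuple
    x0, y0, x1, y1 = box         # still unpacks like a plain tuple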


class Dictionary:
@@ -543,6 +547,14 @@ def set_label(self, typename: str, value: str, score: float = 1.0):
else:
DataPoint.set_label(self, typename=typename, value=value, score=score)

+    def to_dict(self, tag_type: Optional[str] = None):
+        return {
+            "text": self.text,
+            "start_pos": self.start_position,
+            "end_pos": self.end_position,
+            "labels": [label.to_dict() for label in self.get_labels(tag_type)],
+        }


class Span(_PartOfSentence):
"""This class represents one textual span consisting of Tokens."""
@@ -604,6 +616,14 @@ def __len__(self) -> int:
def embedding(self):
return self.get_embedding()

+    def to_dict(self, tag_type: Optional[str] = None):
+        return {
+            "text": self.text,
+            "start_pos": self.start_position,
+            "end_pos": self.end_position,
+            "labels": [label.to_dict() for label in self.get_labels(tag_type)],
+        }


class Relation(_PartOfSentence):
def __new__(self, first: Span, second: Span):
@@ -664,6 +684,15 @@ def end_position(self) -> int:
def embedding(self):
pass

+    def to_dict(self, tag_type: Optional[str] = None):
+        return {
+            "from_text": self.first.text,
+            "to_text": self.second.text,
+            "from_idx": self.first.tokens[0].idx - 1,
+            "to_idx": self.second.tokens[0].idx - 1,
+            "labels": [label.to_dict() for label in self.get_labels(tag_type)],
+        }


class Sentence(DataPoint):
"""A Sentence is a list of tokens and is used to represent a sentence or text fragment."""
@@ -702,7 +731,7 @@ def __init__(
if isinstance(use_tokenizer, Tokenizer):
tokenizer = use_tokenizer

-        elif type(use_tokenizer) == bool:
+        elif isinstance(use_tokenizer, bool):
tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()

else:
@@ -760,17 +789,17 @@ def __init__(
def unlabeled_identifier(self):
return f'Sentence[{len(self)}]: "{self.text}"'

-    def get_relations(self, type: str) -> List[Relation]:
+    def get_relations(self, label_type: Optional[str] = None) -> List[Relation]:
relations: List[Relation] = []
-        for label in self.get_labels(type):
+        for label in self.get_labels(label_type):
if isinstance(label.data_point, Relation):
relations.append(label.data_point)
return relations

-    def get_spans(self, type: str) -> List[Span]:
+    def get_spans(self, label_type: Optional[str] = None) -> List[Span]:
spans: List[Span] = []
for potential_span in self._known_spans.values():
-            if isinstance(potential_span, Span) and potential_span.has_label(type):
+            if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)):
spans.append(potential_span)
return sorted(spans)
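
With these signature changes, label_type is optional for both getters: passing a type filters as before, while no argument now returns every known relation or span. A small sketch using manual labels only, so it runs without a trained model (sentence[0:2] returning a Span is existing Flair behavior):

    from flair.data import Sentence

    sentence = Sentence("George Washington went to Washington .")
    sentence[0:2].add_label("ner", "PER")  # label the span "George Washington"

    print(sentence.get_spans("ner"))  # spans carrying a 'ner' label, as before
    print(sentence.get_spans())       # new: no argument returns all known spans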

@@ -784,7 +813,7 @@ def _add_token(self, token: Union[Token, str]):
if isinstance(token, Token):
assert token.sentence is None

-        if type(token) is str:
+        if isinstance(token, str):
token = Token(token)
token = cast(Token, token)

@@ -937,16 +966,13 @@ def to_original_text(self) -> str:
).strip()

def to_dict(self, tag_type: Optional[str] = None):
-        labels = []
-
-        if tag_type:
-            labels = [label.to_dict() for label in self.get_labels(tag_type)]
-            return {"text": self.to_original_text(), tag_type: labels}
-
-        if self.labels:
-            labels = [label.to_dict() for label in self.labels]
-
-        return {"text": self.to_original_text(), "all labels": labels}
+        return {
+            "text": self.to_original_text(),
+            "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self],
+            "entities": [span.to_dict(tag_type) for span in self.get_spans(tag_type)],
+            "relations": [relation.to_dict(tag_type) for relation in self.get_relations(tag_type)],
+            "tokens": [token.to_dict(tag_type) for token in self.tokens],
+        }
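
The rewritten Sentence.to_dict replaces the old "all labels" shape with one uniform structure that always carries sentence-level labels, entity spans, relations, and per-token dicts. A quick sketch of the new output, using manual labels and abridged values:

    from flair.data import Sentence

    sentence = Sentence("George Washington went to Washington .")
    sentence[0:2].add_label("ner", "PER")

    d = sentence.to_dict("ner")
    # d now has the keys: "text", "labels", "entities", "relations", "tokens"
    print(d["entities"][0]["text"])  # "George Washington"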

def get_span(self, start: int, stop: int):
span_slice = slice(start, stop)
@@ -1212,15 +1238,27 @@ def __init__(

# sample test data from train if none is provided
if test is None and sample_missing_splits and train and sample_missing_splits != "only_dev":
+            test_portion = 0.1
train_length = _len_dataset(train)
-            test_size: int = round(train_length / 10)
+            test_size: int = round(train_length * test_portion)
test, train = randomly_split_into_two_datasets(train, test_size)
+            log.warning(
+                "No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data",
+                test_portion,
+                test_size,
+            )

# sample dev data from train if none is provided
if dev is None and sample_missing_splits and train and sample_missing_splits != "only_test":
+            dev_portion = 0.1
train_length = _len_dataset(train)
-            dev_size: int = round(train_length / 10)
+            dev_size: int = round(train_length * dev_portion)
dev, train = randomly_split_into_two_datasets(train, dev_size)
+            log.warning(
+                "No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data",
+                dev_portion,
+                dev_size,
+            )
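
Taken together, a Corpus built from a lone train split now names its sampling fractions and warns about each carve-out. As written, %.0f%% applied to dev_portion=0.1 renders as "0%", so the message presumably intends dev_portion * 100. A sketch of the resulting behavior; FlairDatapointDataset as the in-memory dataset is an assumed import path:

    from flair.data import Corpus, Sentence
    from flair.datasets import FlairDatapointDataset  # assumed import path

    train_full = FlairDatapointDataset([Sentence(f"sample sentence {i}") for i in range(100)])

    # no dev/test given: 10% of train becomes test, then 10% of the rest becomes dev
    corpus = Corpus(train=train_full)
    print(len(corpus.train), len(corpus.dev), len(corpus.test))  # 81 9 10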

# set train dev and test data
self._train: Optional[Dataset[T_co]] = train
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
@@ -173,6 +173,7 @@
KEYPHRASE_INSPEC,
KEYPHRASE_SEMEVAL2010,
KEYPHRASE_SEMEVAL2017,
+    MASAKHA_POS,
NER_ARABIC_ANER,
NER_ARABIC_AQMAR,
NER_BASQUE,
@@ -447,6 +448,7 @@
"KEYPHRASE_INSPEC",
"KEYPHRASE_SEMEVAL2010",
"KEYPHRASE_SEMEVAL2017",
"MASAKHA_POS",
"NER_ARABIC_ANER",
"NER_ARABIC_AQMAR",
"NER_BASQUE",
2 changes: 1 addition & 1 deletion flair/datasets/base.py
@@ -229,7 +229,7 @@ def __getitem__(self, index: int = 0) -> Sentence:


def find_train_dev_test_files(data_folder, dev_file, test_file, train_file, autofind_splits=True):
-    if type(data_folder) == str:
+    if isinstance(data_folder, str):
data_folder: Path = Path(data_folder)

if train_file is not None:
5 changes: 4 additions & 1 deletion flair/datasets/document_classification.py
@@ -324,6 +324,7 @@ def __init__(
skip_header: bool = False,
encoding: str = "utf-8",
no_class_label=None,
+        sample_missing_splits: Union[bool, str] = True,
**fmtparams,
) -> None:
"""Instantiates a Corpus for text classification from CSV column formatted data.
@@ -396,7 +397,7 @@ def __init__(
else None
)

-        super().__init__(train, dev, test, name=name)
+        super().__init__(train, dev, test, name=name, sample_missing_splits=sample_missing_splits)
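
This threads the new sample_missing_splits flag through CSVClassificationCorpus, so callers can keep a lone train file intact instead of having dev and test carved from it. A hypothetical invocation; the folder path, column map, and label type are illustrative only:

    from flair.datasets import CSVClassificationCorpus

    corpus = CSVClassificationCorpus(
        "resources/my_corpus",               # assumed folder containing only train.csv
        column_name_map={0: "label", 1: "text"},
        label_type="topic",                  # illustrative label type
        skip_header=True,
        sample_missing_splits=False,         # new: leave train unsplit, dev/test stay None
    )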


class CSVClassificationDataset(FlairDataset):
@@ -1488,6 +1489,7 @@ def __init__(
tokenizer: Tokenizer = SegtokTokenizer(),
in_memory: bool = False,
encoding: str = "utf-8",
+        sample_missing_splits: bool = True,
**datasetargs,
) -> None:
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
@@ -1525,6 +1527,7 @@ def __init__(
column_name_map={0: "text", 1: "label"},
train_file=train_file,
dev_file=data_folder / "dev.tsv",
+            sample_missing_splits=sample_missing_splits,
**kwargs,
)
