From 72da117830a7dcf2de4c252eaf5353b1e65103c3 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 19 Jun 2023 09:16:44 +0200 Subject: [PATCH 1/3] recreate `to_dict` and add relations --- flair/data.py | 41 +++++++++++++++-------- flair/models/relation_classifier_model.py | 2 +- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/flair/data.py b/flair/data.py index ecbcb55ba..d73195e4f 100644 --- a/flair/data.py +++ b/flair/data.py @@ -604,6 +604,14 @@ def __len__(self) -> int: def embedding(self): return self.get_embedding() + def to_dict(self, tag_type: Optional[str] = None): + return { + "text": self.text, + "start_pos": self.start_position, + "end_pos": self.end_position, + "labels": [label.to_dict() for label in self.get_labels(tag_type)], + } + class Relation(_PartOfSentence): def __new__(self, first: Span, second: Span): @@ -664,6 +672,15 @@ def end_position(self) -> int: def embedding(self): pass + def to_dict(self, tag_type: Optional[str] = None): + return { + "from_text": self.first.text, + "to_text": self.second.text, + "from_idx": self.first.tokens[0].idx - 1, + "to_idx": self.second.tokens[0].idx - 1, + "labels": [label.to_dict() for label in self.get_labels(tag_type)], + } + class Sentence(DataPoint): """A Sentence is a list of tokens and is used to represent a sentence or text fragment.""" @@ -760,17 +777,17 @@ def __init__( def unlabeled_identifier(self): return f'Sentence[{len(self)}]: "{self.text}"' - def get_relations(self, type: str) -> List[Relation]: + def get_relations(self, label_type: Optional[str] = None) -> List[Relation]: relations: List[Relation] = [] - for label in self.get_labels(type): + for label in self.get_labels(label_type): if isinstance(label.data_point, Relation): relations.append(label.data_point) return relations - def get_spans(self, type: str) -> List[Span]: + def get_spans(self, label_type: Optional[str] = None) -> List[Span]: spans: List[Span] = [] for potential_span in self._known_spans.values(): - if isinstance(potential_span, Span) and potential_span.has_label(type): + if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)): spans.append(potential_span) return sorted(spans) @@ -937,16 +954,12 @@ def to_original_text(self) -> str: ).strip() def to_dict(self, tag_type: Optional[str] = None): - labels = [] - - if tag_type: - labels = [label.to_dict() for label in self.get_labels(tag_type)] - return {"text": self.to_original_text(), tag_type: labels} - - if self.labels: - labels = [label.to_dict() for label in self.labels] - - return {"text": self.to_original_text(), "all labels": labels} + return { + "text": self.to_original_text(), + "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self], + "entities": [span.to_dict() for span in self.get_spans(tag_type)], + "relations": [relation.to_dict() for relation in self.get_relations(tag_type)], + } def get_span(self, start: int, stop: int): span_slice = slice(start, stop) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 82fafa71f..43b7dc203 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -347,7 +347,7 @@ def _valid_entities(self, sentence: Sentence) -> Iterator[_Entity]: :return: Valid entities as `_Entity` """ for label_type, valid_labels in self.entity_label_types.items(): - for entity_span in sentence.get_spans(type=label_type): + for entity_span in sentence.get_spans(label_type=label_type): entity_label: Label = entity_span.get_label(label_type=label_type) # Only use entities labelled with the specified labels for each label type From 0e6c0b614afbb44501e8fa5837e02354804cb302 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 17 Jul 2023 13:29:23 +0200 Subject: [PATCH 2/3] add tokens --- flair/data.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/flair/data.py b/flair/data.py index d73195e4f..8779b48b6 100644 --- a/flair/data.py +++ b/flair/data.py @@ -543,6 +543,14 @@ def set_label(self, typename: str, value: str, score: float = 1.0): else: DataPoint.set_label(self, typename=typename, value=value, score=score) + def to_dict(self, tag_type: Optional[str] = None): + return { + "text": self.text, + "start_pos": self.start_position, + "end_pos": self.end_position, + "labels": [label.to_dict() for label in self.get_labels(tag_type)], + } + class Span(_PartOfSentence): """This class represents one textual span consisting of Tokens.""" @@ -957,8 +965,9 @@ def to_dict(self, tag_type: Optional[str] = None): return { "text": self.to_original_text(), "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self], - "entities": [span.to_dict() for span in self.get_spans(tag_type)], - "relations": [relation.to_dict() for relation in self.get_relations(tag_type)], + "entities": [span.to_dict(tag_type) for span in self.get_spans(tag_type)], + "relations": [relation.to_dict(tag_type) for relation in self.get_relations(tag_type)], + "tokens": [token.to_dict(tag_type) for token in self.tokens] } def get_span(self, start: int, stop: int): From 14d5a073ec68471c7b12bf213754d3284770f608 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 17 Jul 2023 16:34:37 +0200 Subject: [PATCH 3/3] black formatting and ruff fixes --- flair/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/data.py b/flair/data.py index 8779b48b6..182b12836 100644 --- a/flair/data.py +++ b/flair/data.py @@ -967,7 +967,7 @@ def to_dict(self, tag_type: Optional[str] = None): "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self], "entities": [span.to_dict(tag_type) for span in self.get_spans(tag_type)], "relations": [relation.to_dict(tag_type) for relation in self.get_relations(tag_type)], - "tokens": [token.to_dict(tag_type) for token in self.tokens] + "tokens": [token.to_dict(tag_type) for token in self.tokens], } def get_span(self, start: int, stop: int):