diff --git a/flair/data.py b/flair/data.py index 04e10a8dd..0665f5a64 100644 --- a/flair/data.py +++ b/flair/data.py @@ -547,25 +547,9 @@ def set_label(self, typename: str, value: str, score: float = 1.0): class Span(_PartOfSentence): """This class represents one textual span consisting of Tokens.""" - def __new__(self, tokens: List[Token]): - # check if the span already exists. If so, return it - unlabeled_identifier = self._make_unlabeled_identifier(tokens) - if unlabeled_identifier in tokens[0].sentence._known_spans: - span = tokens[0].sentence._known_spans[unlabeled_identifier] - return span - - # else make a new span - else: - span = super().__new__(self) - span.initialized = False - tokens[0].sentence._known_spans[unlabeled_identifier] = span - return span - def __init__(self, tokens: List[Token]) -> None: - if not self.initialized: - super().__init__(tokens[0].sentence) - self.tokens = tokens - self.initialized: bool = True + super().__init__(tokens[0].sentence) + self.tokens = tokens @property def start_position(self) -> int: @@ -606,26 +590,10 @@ def embedding(self): class Relation(_PartOfSentence): - def __new__(self, first: Span, second: Span): - # check if the relation already exists. If so, return it - unlabeled_identifier = self._make_unlabeled_identifier(first, second) - if unlabeled_identifier in first.sentence._known_spans: - span = first.sentence._known_spans[unlabeled_identifier] - return span - - # else make a new relation - else: - span = super().__new__(self) - span.initialized = False - first.sentence._known_spans[unlabeled_identifier] = span - return span - def __init__(self, first: Span, second: Span) -> None: - if not self.initialized: - super().__init__(sentence=first.sentence) - self.first: Span = first - self.second: Span = second - self.initialized: bool = True + super().__init__(sentence=first.sentence) + self.first: Span = first + self.second: Span = second def __repr__(self) -> str: return str(self) @@ -692,7 +660,7 @@ def __init__( self.tokens: List[Token] = [] # private field for all known spans - self._known_spans: Dict[str, _PartOfSentence] = {} + self._known_parts: Dict[str, _PartOfSentence] = {} self.language_code: Optional[str] = language_code @@ -769,7 +737,7 @@ def get_relations(self, type: str) -> List[Relation]: def get_spans(self, type: str) -> List[Span]: spans: List[Span] = [] - for potential_span in self._known_spans.values(): + for potential_span in self._known_parts.values(): if isinstance(potential_span, Span) and potential_span.has_label(type): spans.append(potential_span) return sorted(spans) @@ -949,8 +917,7 @@ def to_dict(self, tag_type: Optional[str] = None): return {"text": self.to_original_text(), "all labels": labels} def get_span(self, start: int, stop: int): - span_slice = slice(start, stop) - return self[span_slice] + return self[start:stop] @typing.overload def __getitem__(self, idx: int) -> Token: @@ -960,9 +927,27 @@ def __getitem__(self, idx: int) -> Token: def __getitem__(self, s: slice) -> Span: ... + @typing.overload + def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation: + ... + def __getitem__(self, subscript): - if isinstance(subscript, slice): - return Span(self.tokens[subscript]) + if isinstance(subscript, tuple): + first, second = subscript + identifier = "" + if isinstance(first, Span) and isinstance(second, Span): + identifier = Relation._make_unlabeled_identifier(first, second) + if identifier not in self._known_parts: + self._known_parts[identifier] = Relation(first, second) + + return self._known_parts[identifier] + elif isinstance(subscript, slice): + identifier = Span._make_unlabeled_identifier(self.tokens[subscript]) + + if identifier not in self._known_parts: + self._known_parts[identifier] = Span(self.tokens[subscript]) + + return self._known_parts[identifier] else: return self.tokens[subscript] @@ -1108,11 +1093,11 @@ def remove_labels(self, typename: str): token.remove_labels(typename) # labels also need to be deleted at all known spans - for span in self._known_spans.values(): + for span in self._known_parts.values(): span.remove_labels(typename) # remove spans without labels - self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0} + self._known_parts = {k: v for k, v in self._known_parts.items() if len(v.labels) > 0} # delete labels at object itself super().remove_labels(typename) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 5525ae45a..e3058cec6 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -26,7 +26,6 @@ Corpus, FlairDataset, MultiCorpus, - Relation, Sentence, Token, get_spans_from_bio, @@ -682,9 +681,7 @@ def _convert_lines_to_sentence( tail_end = int(indices[3]) label = indices[4] # head and tail span indices are 1-indexed and end index is inclusive - relation = Relation( - first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end] - ) + relation = sentence[sentence[head_start - 1 : head_end], sentence[tail_start - 1 : tail_end]] remapped = self._remap_label(label) if remapped != "O": relation.add_label(typename="relation", value=remapped) diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index a6b7f6c80..4f4932539 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -38,7 +38,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span: """ span_start: int = self.__tokens_start_pos.index(span[0]) span_end: int = self.__tokens_end_pos.index(span[1]) - return Span(self.tokens[span_start : span_end + 1]) + return self.sentence[span_start : span_end + 1] class RegexpTagger: diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 82fafa71f..b188ee9ab 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -377,11 +377,9 @@ def _entity_pair_permutations( """ valid_entities: List[_Entity] = list(self._valid_entities(sentence)) - # Use a dictionary to find gold relation annotations for a given entity pair - relation_to_gold_label: Dict[str, str] = { - relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value - for relation in sentence.get_relations(self.label_type) - } + # ensure that all existing relations without label have the label set to zero_tag_value. + for relation in sentence.get_relations(self.label_type): + relation.set_label(self.label_type, relation.get_label(self.label_type, self.zero_tag_value).value) # Yield head and tail entity pairs from the cross product of all entities for head, tail in itertools.product(valid_entities, repeat=2): @@ -398,9 +396,8 @@ def _entity_pair_permutations( continue # Obtain gold label, if existing - original_relation: Relation = Relation(first=head.span, second=tail.span) - gold_label: Optional[str] = relation_to_gold_label.get(original_relation.unlabeled_identifier) - + gold_relation = sentence[head.span, tail.span] + gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value=None).value yield head, tail, gold_label def _encode_sentence( @@ -481,7 +478,7 @@ def _encode_sentence_for_inference( tail=tail, gold_label=gold_label if gold_label is not None else self.zero_tag_value, ) - original_relation: Relation = Relation(first=head.span, second=tail.span) + original_relation: Relation = sentence[head.span, tail.span] yield masked_sentence, original_relation def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]: diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py index 83e063b72..745b0a0fd 100644 --- a/flair/models/relation_extractor_model.py +++ b/flair/models/relation_extractor_model.py @@ -79,7 +79,7 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]: ): continue - relation = Relation(span_1, span_2) + relation = sentence[span_1, span_2] if self.training and self.train_on_gold_pairs_only and relation.get_label(self.label_type).value == "O": continue entity_pairs.append(relation) diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py index de431c468..51f9533e3 100644 --- a/flair/models/tars_model.py +++ b/flair/models/tars_model.py @@ -404,10 +404,9 @@ def _get_tars_formatted_sentence(self, label, sentence): for entity_label in sentence.get_labels(self.label_type): if entity_label.value == label: - new_span = Span( - [tars_sentence.get_token(token.idx + label_length) for token in entity_label.data_point] - ) - new_span.add_label(self.static_label_type, value="entity") + start_pos = entity_label.data_point[0].idx + label_length - 1 + end_pos = entity_label.data_point[-1].idx + label_length + tars_sentence[start_pos:end_pos].add_label(self.static_label_type, value="entity") tars_sentence.copy_context_from_sentence(sentence) return tars_sentence @@ -588,9 +587,10 @@ def predict( # only add if all tokens have no label if tag_this: # make and add a corresponding predicted span - predicted_span = Span( - [sentence.get_token(token.idx - label_length) for token in label.data_point] - ) + start_pos = label.data_point.data_point[0].idx - label_length - 1 + end_pos = label.data_point.data_point[-1].idx - label_length + + predicted_span = sentence[start_pos:end_pos] predicted_span.add_label(label_name, value=label.value, score=label.score) # set indices so that no token can be tagged twice diff --git a/tests/test_labels.py b/tests/test_labels.py index 210a21588..0357725b7 100644 --- a/tests/test_labels.py +++ b/tests/test_labels.py @@ -189,9 +189,9 @@ def test_relation_tags(): sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .") # create two relation label - Relation(sentence[0:4], sentence[7:8]).add_label("rel", "located in") - Relation(sentence[0:2], sentence[3:4]).add_label("rel", "university of") - Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition") + sentence[sentence[0:4], sentence[7:8]].add_label("rel", "located in") + sentence[sentence[0:2], sentence[3:4]].add_label("rel", "university of") + sentence[sentence[0:2], sentence[3:4]].add_label("syntactic", "apposition") # there should be two relation labels labels: List[Label] = sentence.get_labels("rel") diff --git a/tests/test_sentence.py b/tests/test_sentence.py index 3e3142264..ad5e85470 100644 --- a/tests/test_sentence.py +++ b/tests/test_sentence.py @@ -1,3 +1,6 @@ +import copy +import pickle + from flair.data import Sentence @@ -73,3 +76,37 @@ def test_start_end_position_pretokenized() -> None: (10, 18), (19, 20), ] + + +def test_spans_support_deepcopy() -> None: + sentence = Sentence(["I", "live", "in", "Vienna", "."]) + sentence[3:4].add_label("ner", "LOC") + + _ = copy.deepcopy(sentence) + + +def test_spans_support_pickle() -> None: + sentence = Sentence(["I", "live", "in", "Vienna", "."]) + sentence[3:4].add_label("ner", "LOC") + + pickle_data = pickle.dumps(sentence) + _ = pickle.loads(pickle_data) + + +def test_relations_support_deepcopy() -> None: + sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"]) + sentence[0:1].add_label("ner", "LOC") + sentence[5:6].add_label("ner", "LOC") + sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital") + + _ = copy.deepcopy(sentence) + + +def test_relations_support_pickle() -> None: + sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"]) + sentence[0:1].add_label("ner", "LOC") + sentence[5:6].add_label("ner", "LOC") + sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital") + + pickle_data = pickle.dumps(sentence) + _ = pickle.loads(pickle_data)