diff --git a/flair/data.py b/flair/data.py index 839f75006..5016e8a87 100644 --- a/flair/data.py +++ b/flair/data.py @@ -6,10 +6,10 @@ from collections import Counter, defaultdict, namedtuple from operator import itemgetter from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union, cast +from typing import Dict, Iterable, List, Optional, Tuple, Union, cast import torch -from deprecated import deprecated +from deprecated import deprecated # type: ignore from torch.utils.data import Dataset, IterableDataset from torch.utils.data.dataset import ConcatDataset, Subset @@ -326,11 +326,13 @@ def get_metadata(self, key: str) -> typing.Any: def has_metadata(self, key: str) -> bool: return key in self._metadata - def add_label(self, typename: str, value: str, score: float = 1.0): + def add_label(self, typename: str, value_or_label: Union[str, Label], score: float = 1.0): + label = value_or_label if isinstance(value_or_label, Label) else Label(self, value_or_label, score) + if typename not in self.annotation_layers: - self.annotation_layers[typename] = [Label(self, value, score)] + self.annotation_layers[typename] = [label] else: - self.annotation_layers[typename].append(Label(self, value, score)) + self.annotation_layers[typename].append(label) return self @@ -421,6 +423,100 @@ def __len__(self) -> int: raise NotImplementedError +class EntityLinkingCandidate: + """Represent a single candidate returned by a CandidateGenerator""" + + def __init__( + self, + concept_id: str, + concept_name: str, + database_name: str, + score: float = 1.0, + additional_ids: Optional[Union[List[str], str]] = None, + ): + """ + :param concept_id: Identifier of the entity / concept from the knowledge base / ontology + :param concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology + :param score: Matching score of the entity / concept according to the entity mention + :param additional_ids: List of additional identifiers for the concept / entity in the KB / ontology + :param database_name: Name of the knowlege base / ontology + """ + self.concept_id = concept_id + self.concept_name = concept_name + self.database_name = database_name + self.score = score + self.additional_ids = additional_ids + + def __str__(self) -> str: + string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}" + if self.additional_ids is not None: + string += f" - {self.additional_ids}" + return string + + def __repr__(self) -> str: + return str(self) + + +class EntityLinkingLabel(Label): + """ + Label class models entity linking annotations. Each entity linking label has a data point it refers + to as well as the identifier and name of the concept / entity from a knowledge base or ontology. + Optionally, additional concepts identifier and the database name can be provided. + """ + + def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidate]): + """ + Initializes the label instance. 
+ :param data_point: Data point / span the label refers to + :param candidates: **sorted** list of candidates from candidate generator + """ + + def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x >= y): + for i, el in enumerate(lst[1:]): + if comparison(key(el), key(lst[i])): + return False + return True + + # candidates must be sorted, regardless if higher is better or not + assert is_sorted(candidates, key=lambda x: x.score) or is_sorted( + candidates, key=lambda x: x.score, comparison=lambda x, y: x <= y + ), "List of candidates must be sorted!" + + super().__init__(data_point, candidates[0].concept_id, candidates[0].score) + self.candidates = candidates + self.concept_name = self.candidates[0].concept_name + self.database_name = self.candidates[0].database_name + + def __str__(self): + return ( + f"{self.data_point.unlabeled_identifier}{flair._arrow} " + f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" + ) + + def __repr__(self): + return ( + f"{self.data_point.unlabeled_identifier}{flair._arrow} " + f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" + ) + + def __len__(self): + return len(self.data_point) + + def __eq__(self, other): + return ( + self.value == other.value + and self.data_point == other.data_point + and self.concept_name == other.concept_name + and self.identifier == other.identifier + and self.database_name == other.database_name + and self.score == other.score + ) + + @property + def identifier(self): + return f"{self.value}" + + DT = typing.TypeVar("DT", bound=DataPoint) DT2 = typing.TypeVar("DT2", bound=DataPoint) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index e549810d1..c90f36f86 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -37,6 +37,8 @@ CLL, CRAFT, CRAFT_V4, + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, DECA, FSU, GELLUS, @@ -90,6 +92,8 @@ LOCTEXT, MIRNA, NCBI_DISEASE, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, OSIRIS, PDR, S800, @@ -386,6 +390,10 @@ "LINNEAUS", "LOCTEXT", "MIRNA", + "NCBI_GENE_HUMAN_DICTIONARY", + "NCBI_TAXONOMY_DICTIONARY", + "CTD_DISEASES_DICTIONARY", + "CTD_CHEMICALS_DICTIONARY", "NCBI_DISEASE", "ONTONOTES", "OSIRIS", diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py index 0a5141942..e3eda0ce6 100644 --- a/flair/datasets/biomedical.py +++ b/flair/datasets/biomedical.py @@ -10,15 +10,9 @@ from copy import copy from operator import attrgetter from pathlib import Path -from tarfile import ( - CompressionError, - ExtractError, - HeaderError, - ReadError, - StreamError, - TarError, -) -from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from tarfile import CompressionError, ExtractError, HeaderError, ReadError, StreamError, TarError +from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union +from warnings import warn from zipfile import BadZipFile, LargeZipFile import ftfy @@ -31,6 +25,7 @@ from flair.datasets.sequence_labeling import ColumnCorpus, ColumnDataset from flair.file_utils import Tqdm, cached_path, unpack_file from flair.splitter import ( + NewlineSentenceSplitter, NoSentenceSplitter, SciSpacySentenceSplitter, SentenceSplitter, @@ -178,10 +173,7 @@ def filter_nested_entities(dataset: InternalBioNerDataset) -> None: num_entities_after = sum([len(x) for x in dataset.entities_per_document.values()]) if num_entities_before != num_entities_after: removed = num_entities_before - num_entities_after - 
logger.warning( - f"WARNING: Corpus modified by filtering nested entities. " - f"Removed {removed} entities. Keep {num_entities_after} entities." - ) + warn(f"Corpus modified by filtering nested entities. Removed {removed} entities.") def bioc_to_internal(bioc_file: Path): @@ -395,7 +387,7 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path): tag = "O" in_entity = False - whitespace_after = "+" if flair_token.whitespace_after > 0 else "-" + whitespace_after = "+" if flair_token.whitespace_after else "-" if len(token) > 0: f.write(" ".join([token, tag, whitespace_after]) + "\n") sentence_had_tokens = True @@ -460,9 +452,9 @@ def __init__( self.sentence_splitter = sentence_splitter if sentence_splitter else SciSpacySentenceSplitter() else: if sentence_splitter: - logger.warning( - f"WARNING: The corpus {self.__class__.__name__} has a pre-defined sentence splitting, " - f"thus just the tokenizer of the given sentence splitter is used" + warn( + f"The corpus {self.__class__.__name__} has a pre-defined sentence splitting, " + f"thus just the tokenizer of the given sentence splitter ist used" ) self.sentence_splitter.tokenizer = sentence_splitter.tokenizer @@ -494,6 +486,7 @@ def __init__( dev_file=dev_file.name, test_file=test_file.name, column_format=columns, + tag_to_bioes="ner", in_memory=in_memory, ) @@ -510,6 +503,404 @@ def get_subset(self, dataset: InternalBioNerDataset, split: str, split_dir: Path ) +class AbstractBiomedicalEntityLinkingDictionary(ABC): + """Base class for downloading and reading of dictionaries for named entity linking. + + A dictionary represents all entities of a knowledge base and their associated ids. + """ + + def __init__( + self, + base_path: Union[str, Path] = None, + ): + """:param base_path: Path to the corpus on your machine""" + if base_path is None: + base_path = flair.cache_root / "datasets" + else: + base_path = Path(base_path) + + # this dataset name + dataset_name = self.__class__.__name__.lower() + data_folder = base_path / dataset_name + self.dataset_file = data_folder / f"{dataset_name}_parsed.txt" + + # check if there is a parsed_dict file in cache + if not self.dataset_file.exists(): + logger.info("Preprocess and cache dictionary `%s` file: %s", dataset_name, self.dataset_file) + data_file = self.download_dictionary(data_folder) + + with open(self.dataset_file, "w", encoding="utf-8") as f: + for cui, name in self.parse_dictionary(data_file): + f.write(f"{cui}||{name}\n") + + @property + @abstractmethod + def database_name(self) -> str: + """Name of the database represented by the dictionary""" + + @abstractmethod + def download_dictionary(self, data_dir: Path) -> Path: + """Download dictionary""" + + @abstractmethod + def parse_dictionary(self, original_file: Path): + """Parse data into HunFlair format""" + + def stream(self) -> Iterator[Tuple[str, str]]: + """Stream preprocessed dictionary""" + + with open(self.dataset_file) as fp: + for line in fp: + line = line.strip() + if line == "": + continue + assert "||" in line, "Preprocessed EntityLinkingDictionary must have lines in the format: `cui||name`" + cui, name = line.split("||") + name = name.lower() + yield (name, cui) + + +class ParsedBiomedicalEntityLinkingDictionary(AbstractBiomedicalEntityLinkingDictionary): + """ + Base dictionary with data already in preprocessed format, i.e. every line in the file must + be formatted as follows: + + concept_id||concept_name + + If multiple concept ids are associated to a given name they have to be separated by a `|`, e.g. 
+ + 7157||TP53|tumor protein p53 + """ + + def __init__(self, path: Path, database_name: str): + self.dataset_file = path + self._database_name = database_name + + @property + def database_name(self): + return self._database_name + + def download_dictionary(self): + pass + + def parse_dictionary(self): + pass + + +class CTD_DISEASES_DICTIONARY(AbstractBiomedicalEntityLinkingDictionary): + """ + Dictionary for named entity linking on diseases using the Comparative + Toxicogenomics Database (CTD). + + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Union[str, Path] = None, + ): + """ + :param base_path: Path to the corpus on your machine""" + super(CTD_DISEASES_DICTIONARY, self).__init__(base_path=base_path) + + @property + def database_name(self): + return "CTD-DISEASES" + + def download_dictionary(self, data_dir: Path) -> Path: + data_url = "https://ctdbase.org/reports/CTD_diseases.tsv.gz" + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=data_dir / "CTD_diseases.tsv") + + return data_dir / "CTD_diseases.tsv" + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + CTD_DISEASES_COLUMNS = [ + "symbol", + "identifier", + "alternative_identifiers", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + "slim_mappings", + ] + + with open(original_file, mode="r", encoding="utf-8") as f: + # parse every line + with open(original_file, mode="r", encoding="utf-8") as f: + for line in f: + if line.startswith("#"): + continue + + entries = [] + # parse line + values = line.strip().split("\t") + row = dict(zip(CTD_DISEASES_COLUMNS, values)) + + original_identifiers = [row["identifier"]] + + original_identifiers += [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + identifier = "|".join(original_identifiers) + + if original_identifiers == "MESH:C": + return None + + if row.get("symbol") is not None: + entries.append((identifier, row["symbol"])) + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] + + for synonym in synonyms: + entries.append((identifier, synonym)) + + for e in entries: + yield e + + +class CTD_CHEMICALS_DICTIONARY(AbstractBiomedicalEntityLinkingDictionary): + """ + Dictionary for named entity linking on chemicals using the Comparative + Toxicogenomics Database (CTD). 
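A note on the data-model changes earlier in this diff: `add_label` now also accepts a ready-made `Label` object, which is what allows attaching an `EntityLinkingLabel` built from sorted `EntityLinkingCandidate`s to a span. A minimal usage sketch (sentence, span index and candidate values are made up for illustration):

```python
from flair.data import EntityLinkingCandidate, EntityLinkingLabel, Sentence

sentence = Sentence("Mutations in TP53 are associated with many cancers.")
span = sentence[2:3]  # span over the token "TP53"

# candidates must be passed in sorted order (here: best score first)
candidates = [
    EntityLinkingCandidate(concept_id="7157", concept_name="TP53", database_name="NCBI-GENE-HUMAN", score=0.95),
    EntityLinkingCandidate(concept_id="7158", concept_name="TP53BP1", database_name="NCBI-GENE-HUMAN", score=0.42),
]

# add_label's second argument may now be a plain string or a Label instance
span.add_label("link", EntityLinkingLabel(data_point=span, candidates=candidates))
print(span.get_label("link"))
```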
+ + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Union[str, Path] = None, + ): + """ + :param base_path: Path to the corpus on your machine""" + super(CTD_CHEMICALS_DICTIONARY, self).__init__(base_path=base_path) + + @property + def database_name(self): + return "CTD-CHEMICALS" + + def download_dictionary(self, data_dir: Path) -> Path: + data_url = "https://ctdbase.org/reports/CTD_chemicals.tsv.gz" + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=data_dir / "CTD_chemicals.tsv") + + return data_dir / "CTD_chemicals.tsv" + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + CTD_CHEMICALS_COLUMNS = [ + "symbol", + "identifier", + "casrn", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + ] + + with open(original_file, mode="r", encoding="utf-8") as f: + for line in f: + if line.startswith("#"): + continue + + entries = [] + # parse line + values = line.strip().split("\t") + row: dict = dict(zip(CTD_CHEMICALS_COLUMNS, values)) + + identifier = row["identifier"] + + if ( + row.get("symbol") is not None and row.get("symbol") != "MESH:D013749" + ): ## This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. + entries.append((identifier, row["symbol"])) + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] + + for synonym in synonyms: + if synonym == row.get("symbol"): + continue + + entries.append((identifier, synonym)) + + for e in entries: + yield e + + +class NCBI_GENE_HUMAN_DICTIONARY(AbstractBiomedicalEntityLinkingDictionary): + """ + Dictionary for named entity linking on diseases using the NCBI Gene ontology. + + Note that this dictionary only represents human genes - gene from different species + aren't included! 
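The `concept_id||concept_name` cache format described here is shared by all dictionary classes, and `ParsedBiomedicalEntityLinkingDictionary` reads such a file directly without any download step. A small sketch with a made-up file and entries:

```python
from pathlib import Path

from flair.datasets.biomedical import ParsedBiomedicalEntityLinkingDictionary

# hypothetical pre-parsed dictionary: one `concept_id||concept_name` pair per line;
# extra ids packed into the id field separated by `|` (assumption based on how the
# candidate generators later split the id field)
dict_file = Path("/tmp/my_genes.txt")
dict_file.write_text("7157||TP53\n7157|11998||tumor protein p53\n", encoding="utf-8")

dictionary = ParsedBiomedicalEntityLinkingDictionary(path=dict_file, database_name="MY-GENE-KB")
for name, cui in dictionary.stream():  # names are lower-cased by `stream`
    print(cui, "->", name)
```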
+ + Fur further information can be found at https://www.ncbi.nlm.nih.gov/gene/ + """ + + def __init__( + self, + base_path: Union[str, Path] = None, + ): + """ + :param base_path: Path to the corpus on your machine""" + super(NCBI_GENE_HUMAN_DICTIONARY, self).__init__(base_path=base_path) + + def _is_invalid_name(self, name: str) -> bool: + """ + Determine if a name should be skipped + """ + EMPTY_ENTRY_TEXT = [ + "when different from all specified ones in Gene.", + "Record to support submission of GeneRIFs for a gene not in Gene", + ] + + newentry = name == "NEWENTRY" + empty = name == "" + text_comment = any(e in name for e in EMPTY_ENTRY_TEXT) + + return any([newentry, empty, text_comment]) + + @property + def database_name(self): + return "NCBI-GENE-HUMAN" + + def download_dictionary(self, data_dir: Path) -> Path: + data_url = "https://ftp.ncbi.nih.gov/gene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz" + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=data_dir / "Homo_sapiens.gene_info") + + return data_dir / "Homo_sapiens.gene_info" + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + NCBI_GENE_SYNONYMS_FIELDS = tuple( + [ + "Symbol_from_nomenclature_authority", + "Full_name_from_nomenclature_authority", + "description", + "Synonyms", + "Other_designations", + ] + ) + + with open(original_file, mode="r", encoding="utf-8") as f: + line = f.readline() + header = line.strip().split("\t") + for line in f: + if line.startswith("#"): + continue + + entries = [] + # parse line + values = line.strip().split("\t") + row: dict = dict(zip(header, values)) + + # parse row + identifier = row["GeneID"] + symbol = row["Symbol"] + + if not self._is_invalid_name(symbol): + entries.append((str(identifier), symbol)) + + # get synonyms + synonyms = set() + + for field in NCBI_GENE_SYNONYMS_FIELDS: + names = row.get(field, "-") + if names in ["-", symbol]: + continue + + names_list = [name.replace("'", "") for name in names.split("|")] + names_list = [n for n in names_list if not self._is_invalid_name(n)] + + synonyms.update(names_list) + + for name in synonyms: + entries.append((str(identifier), name)) + + for e in entries: + yield e + + +class NCBI_TAXONOMY_DICTIONARY(AbstractBiomedicalEntityLinkingDictionary): + """ + Dictionary for named entity linking on organisms / species using the NCBI taxonomy ontology. 
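To make the synonym handling in `NCBI_GENE_HUMAN_DICTIONARY.parse_dictionary` concrete, here is a self-contained sketch of the same per-row logic applied to one made-up, abbreviated `gene_info` line (the real parser additionally strips quotes and filters placeholder entries such as `NEWENTRY`):

```python
# tab-separated row, '|'-separated synonym fields, '-' marks an empty field
header = ["GeneID", "Symbol", "Synonyms", "description"]
line = "7157\tTP53\tBCC7|LFS1|P53\ttumor protein p53"

row = dict(zip(header, line.split("\t")))
entries = [(row["GeneID"], row["Symbol"])]

synonyms = set()
for field in ("Synonyms", "description"):
    names = row.get(field, "-")
    if names in ("-", row["Symbol"]):
        continue
    synonyms.update(n for n in names.split("|") if n)

entries.extend((row["GeneID"], name) for name in synonyms)
print(entries)  # [('7157', 'TP53'), ('7157', 'BCC7'), ...] (set order may vary)
```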
+ + Further information about the ontology can be found at https://www.ncbi.nlm.nih.gov/taxonomy + """ + + def __init__( + self, + base_path: Union[str, Path] = None, + ): + """ + :param base_path: Path to the corpus on your machine""" + super(NCBI_TAXONOMY_DICTIONARY, self).__init__(base_path=base_path) + + @property + def database_name(self): + return "NCBI-TAXONOMY" + + def download_dictionary(self, data_dir: Path) -> Path: + data_url = "https://ftp.ncbi.nih.gov/pub/taxonomy/new_taxdump/new_taxdump.tar.gz" + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, data_dir) + + return data_dir / "names.dmp" + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + NCBI_TAXONOMY_SYNSET = [ + "genbank common name", + "common name", + "scientific name", + "equivalent name", + "synonym", + "acronym", + "blast name", + "genbank", + "genbank synonym", + "genbank acronym", + "includes", + "type material", + ] + + with open(original_file, mode="r", encoding="utf-8") as f: + curr_identifier = None + names = [] + + for line in f: + # parse line + parsed_line = {} + elements = [e.strip() for e in line.strip().split("|")] + parsed_line["identifier"] = elements[0] + parsed_line["name"] = elements[1] if elements[2] == "" else elements[2] + parsed_line["field"] = elements[3] + + if parsed_line["name"] in ["all", "root"]: + continue + + if parsed_line["field"] in ["authority", "in-part", "type material"]: + continue + + if parsed_line["field"] not in NCBI_TAXONOMY_SYNSET: + raise ValueError(f"Field {parsed_line['field']} unknown!") + + if curr_identifier is None: + curr_identifier = parsed_line["identifier"] + + if curr_identifier == parsed_line["identifier"]: + synonym = parsed_line["name"] + names.append(synonym) + + elif curr_identifier != parsed_line["identifier"]: + for name in names: + yield (curr_identifier, name) + + curr_identifier = parsed_line["identifier"] + names = [] + synonym = parsed_line["name"] + names.append(synonym) + + class BIO_INFER(ColumnCorpus): """Original BioInfer corpus. 
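The `names.dmp` parser above groups all synonym rows of one taxonomy identifier before yielding them. The following standalone sketch reproduces that grouping on made-up rows of the pipe-delimited format (a sentinel is used here to flush the final group):

```python
# made-up names.dmp rows: tax_id | name | unique name | name class |
lines = [
    "9606\t|\tHomo sapiens\t|\t\t|\tscientific name\t|",
    "9606\t|\thuman\t|\t\t|\tgenbank common name\t|",
    "10090\t|\tMus musculus\t|\t\t|\tscientific name\t|",
]

curr_identifier, names = None, []
for line in lines + [None]:  # None acts as a sentinel that flushes the last group
    if line is not None:
        elements = [e.strip() for e in line.strip().split("|")]
        identifier = elements[0]
        name = elements[1] if elements[2] == "" else elements[2]
    if line is None or (curr_identifier is not None and identifier != curr_identifier):
        for n in names:
            print(curr_identifier, "->", n)
        names = []
    if line is not None:
        curr_identifier = identifier
        names.append(name)
```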
@@ -539,28 +930,25 @@ def __init__( data_folder = base_path / dataset_name train_file = data_folder / "train.conll" - test_file = data_folder / "test.conll" - if not (train_file.exists() and test_file.exists()): + if not (train_file.exists()): corpus_folder = self.download_dataset(data_folder) - sentence_splitter = NoSentenceSplitter(tokenizer=SpaceTokenizer()) + corpus_data = self.parse_dataset(corpus_folder) - train_data = self.parse_dataset(corpus_folder / "BioInfer-train.xml") - test_data = self.parse_dataset(corpus_folder / "BioInfer-test.xml") + sentence_splitter = NoSentenceSplitter(tokenizer=SpaceTokenizer()) conll_writer = CoNLLWriter(sentence_splitter=sentence_splitter) - conll_writer.write_to_conll(train_data, train_file) - conll_writer.write_to_conll(test_data, test_file) + conll_writer.write_to_conll(corpus_data, train_file) super().__init__(data_folder, columns, in_memory=in_memory) @classmethod def download_dataset(cls, data_dir: Path) -> Path: - data_url = "https://github.com/metalrt/ppi-dataset/archive/refs/heads/master.zip" + data_url = "http://mars.cs.utu.fi/BioInfer/files/BioInfer_corpus_1.1.1.zip" data_path = cached_path(data_url, data_dir) unpack_file(data_path, data_dir) - return data_dir / "ppi-dataset-master/csv_output" + return data_dir / "BioInfer_corpus_1.1.1.xml" @classmethod def parse_dataset(cls, original_file: Path): @@ -571,19 +959,66 @@ def parse_dataset(cls, original_file: Path): sentence_elems = tree.xpath("//sentence") for s_id, sentence in enumerate(sentence_elems): sentence_id = str(s_id) - documents[sentence_id] = sentence.attrib["text"] + token_id_to_span = {} + sentence_text = "" entities_per_document[sentence_id] = [] - for entity in sentence.xpath(".//entity"): - char_offsets = re.split("-|,", entity.attrib["charOffset"]) - start_token = int(char_offsets[0]) - end_token = int(char_offsets[-1]) - entities_per_document[sentence_id].append( - Entity( - char_span=(start_token, end_token), - entity_type=entity.attrib["type"], + for token in sentence.xpath(".//token"): + token_text = "".join(token.xpath(".//subtoken/@text")) + token_id = ".".join(token.attrib["id"].split(".")[1:]) + + if not sentence_text: + token_id_to_span[token_id] = (0, len(token_text)) + sentence_text = token_text + else: + token_id_to_span[token_id] = ( + len(sentence_text) + 1, + len(token_text) + len(sentence_text) + 1, ) - ) + sentence_text += " " + token_text + documents[sentence_id] = sentence_text + + entities = [ + e for e in sentence.xpath(".//entity") if not e.attrib["type"].isupper() + ] # all caps entity type apparently marks event trigger + + for entity in entities: + token_nums = [] + entity_character_starts = [] + entity_character_ends = [] + + for subtoken in entity.xpath(".//nestedsubtoken"): + token_id_parts = subtoken.attrib["id"].split(".") + token_id = ".".join(token_id_parts[1:3]) + + token_nums.append(int(token_id_parts[2])) + entity_character_starts.append(token_id_to_span[token_id][0]) + entity_character_ends.append(token_id_to_span[token_id][1]) + + if token_nums and entity_character_starts and entity_character_ends: + entity_tokens = list(zip(token_nums, entity_character_starts, entity_character_ends)) + + start_token = entity_tokens[0] + last_entity_token = entity_tokens[0] + for entity_token in entity_tokens[1:]: + if not (entity_token[0] - 1) == last_entity_token[0]: + entities_per_document[sentence_id].append( + Entity( + char_span=(start_token[1], last_entity_token[2]), + entity_type=entity.attrib["type"], + ) + ) + start_token = entity_token + + 
last_entity_token = entity_token + + if start_token: + entities_per_document[sentence_id].append( + Entity( + char_span=(start_token[1], last_entity_token[2]), + entity_type=entity.attrib["type"], + ) + ) return InternalBioNerDataset(documents=documents, entities_per_document=entities_per_document) @@ -599,22 +1034,17 @@ def split_url() -> str: return "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/bioinfer" def to_internal(self, data_dir: Path) -> InternalBioNerDataset: - corpus_folder = BIO_INFER.download_dataset(data_dir) - train_data = BIO_INFER.parse_dataset(corpus_folder / "BioInfer-train.xml") - test_data = BIO_INFER.parse_dataset(corpus_folder / "BioInfer-test.xml") + original_file = BIO_INFER.download_dataset(data_dir) + corpus = BIO_INFER.parse_dataset(original_file) entity_type_mapping = { "Individual_protein": GENE_TAG, "Gene/protein/RNA": GENE_TAG, "Gene": GENE_TAG, "DNA_family_or_group": GENE_TAG, - "Protein_family_or_group": GENE_TAG, } - train_data = filter_and_map_entities(train_data, entity_type_mapping) - test_data = filter_and_map_entities(test_data, entity_type_mapping) - - return merge_datasets([train_data, test_data]) + return filter_and_map_entities(corpus, entity_type_mapping) @deprecated(version="0.13", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)") @@ -653,9 +1083,9 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool train_data_path = cached_path(train_data_url, download_dir) unpack_file(train_data_path, download_dir) - test_data_url = "http://www.nactem.ac.uk/GENIA/current/Shared-tasks/JNLPBA/Evaluation/Genia4ERtest.tar.gz" - test_data_path = cached_path(test_data_url, download_dir) - unpack_file(test_data_path, download_dir) + train_data_url = "http://www.nactem.ac.uk/GENIA/current/Shared-tasks/JNLPBA/Evaluation/Genia4ERtest.tar.gz" + train_data_path = cached_path(train_data_url, download_dir) + unpack_file(train_data_path, download_dir) train_file = download_dir / "Genia4ERtask2.iob2" shutil.copy(train_file, data_folder / "train.conll") @@ -666,6 +1096,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool super().__init__( data_folder, columns, + tag_to_bioes="ner", in_memory=in_memory, comment_symbol="#", ) @@ -1187,7 +1618,7 @@ class KaewphanCorpusHelper: @staticmethod def download_cll_dataset(data_folder: Path): - data_url = "https://github.com/hu-ner/hunflair-corpora/raw/main/cll/CLL_corpus.tar.gz" + data_url = "http://bionlp-www.utu.fi/cell-lines/CLL_corpus.tar.gz" data_path = cached_path(data_url, data_folder) unpack_file(data_path, data_folder) @@ -1232,7 +1663,7 @@ def prepare_and_save_dataset(nersuite_folder: Path, output_file: Path): @staticmethod def download_gellus_dataset(data_folder: Path): - data_url = "https://github.com/hu-ner/hunflair-corpora/raw/main/gellus/Gellus_corpus.tar.gz" + data_url = "http://bionlp-www.utu.fi/cell-lines/Gellus_corpus.tar.gz" data_path = cached_path(data_url, data_folder) unpack_file(data_path, data_folder) @@ -1606,6 +2037,7 @@ def __init__( super().__init__(data_folder, columns, in_memory=in_memory) + @staticmethod def download_dataset(data_dir: Path): data_url = "https://biocreative.bioinformatics.udel.edu/media/store/files/2014/chemdner_corpus.tar.gz" @@ -1616,7 +2048,8 @@ def download_dataset(data_dir: Path): class HUNER_CHEMICAL_CHEMDNER(HunerDataset): """HUNER version of the CHEMDNER corpus containing chemical annotations.""" - def __init__(self, *args, **kwargs) -> None: + def 
__init__(self, *args, download_folder=None, **kwargs) -> None: + self.download_folder = download_folder or CHEMDNER.default_dir / "original" super().__init__(*args, **kwargs) @staticmethod @@ -1624,11 +2057,11 @@ def split_url() -> str: return "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/chemdner" def to_internal(self, data_dir: Path) -> InternalBioNerDataset: - os.makedirs(str(data_dir), exist_ok=True) - CHEMDNER.download_dataset(data_dir) - train_data = bioc_to_internal(data_dir / "chemdner_corpus" / "training.bioc.xml") - dev_data = bioc_to_internal(data_dir / "chemdner_corpus" / "development.bioc.xml") - test_data = bioc_to_internal(data_dir / "chemdner_corpus" / "evaluation.bioc.xml") + os.makedirs(str(self.download_folder), exist_ok=True) + CHEMDNER.download_dataset(self.download_folder) + train_data = bioc_to_internal(self.download_folder / "chemdner_corpus" / "training.bioc.xml") + dev_data = bioc_to_internal(self.download_folder / "chemdner_corpus" / "development.bioc.xml") + test_data = bioc_to_internal(self.download_folder / "chemdner_corpus" / "evaluation.bioc.xml") all_data = merge_datasets([train_data, dev_data, test_data]) all_data = filter_and_map_entities( all_data, @@ -1660,11 +2093,14 @@ def __init__( self, base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, + tokenizer: Tokenizer = None, ) -> None: """Initialize the IEPA corpus. :param base_path: Path to the corpus on your machine :param in_memory: If True, keeps dataset in memory giving speedups in training. + :param tokenizer: Custom implementation of :class:`Tokenizer` which + segments sentences into tokens (default :class:`SciSpacyTokenizer`) """ base_path = flair.cache_root / "datasets" if base_path is None else Path(base_path) @@ -1676,62 +2112,31 @@ def __init__( data_folder = base_path / dataset_name - train_file = data_folder / "train.conll" - test_file = data_folder / "test.conll" + if tokenizer is None: + tokenizer = SciSpacyTokenizer() - if not (train_file.exists() and test_file.exists()): - corpus_folder = self.download_dataset(data_folder) - sentence_splitter = NoSentenceSplitter(tokenizer=SpaceTokenizer()) + sentence_splitter = NewlineSentenceSplitter(tokenizer=tokenizer) - train_data = self.parse_dataset(corpus_folder / "IEPA-train.xml") - test_data = self.parse_dataset(corpus_folder / "IEPA-test.xml") + train_file = data_folder / f"{sentence_splitter.name}_train.conll" + + if not (train_file.exists()): + download_dir = data_folder / "original" + os.makedirs(download_dir, exist_ok=True) + self.download_dataset(download_dir) + + all_data = bioc_to_internal(download_dir / "iepa_bioc.xml") conll_writer = CoNLLWriter(sentence_splitter=sentence_splitter) - conll_writer.write_to_conll(train_data, train_file) - conll_writer.write_to_conll(test_data, test_file) + conll_writer.write_to_conll(all_data, train_file) super().__init__(data_folder, columns, in_memory=in_memory) @staticmethod def download_dataset(data_dir: Path): - data_url = "https://github.com/metalrt/ppi-dataset/archive/refs/heads/master.zip" + data_url = "http://corpora.informatik.hu-berlin.de/corpora/brat2bioc/iepa_bioc.xml.zip" data_path = cached_path(data_url, data_dir) unpack_file(data_path, data_dir) - return data_dir / "ppi-dataset-master/csv_output" - - @classmethod - def parse_dataset(cls, original_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} - - tree = etree.parse(str(original_file)) - document_elems = tree.xpath("//document") - for 
document in document_elems: - document_id = "_".join(document.attrib["id"].split(".")) - document_text = "" - entities_per_document[document_id] = [] - sentence_elems = document.xpath(".//sentence") - for sentence in sentence_elems: - sentence_text = sentence.attrib["text"] - if document_text == "": - document_text = sentence_text - else: - document_text += " " + sentence_text - for entity in sentence.xpath(".//entity"): - char_offsets = re.split("-|,", entity.attrib["charOffset"]) - start_token = int(char_offsets[0]) - end_token = int(char_offsets[-1]) - entities_per_document[document_id].append( - Entity( - char_span=(start_token, end_token), - entity_type="Protein", - ) - ) - documents[document_id] = document_text - - return InternalBioNerDataset(documents=documents, entities_per_document=entities_per_document) - class HUNER_GENE_IEPA(HunerDataset): """HUNER version of the IEPA corpus containing gene annotations.""" @@ -1743,17 +2148,17 @@ def __init__(self, *args, **kwargs) -> None: def split_url() -> str: return "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/iepa" - def to_internal(self, data_dir: Path) -> InternalBioNerDataset: - corpus_folder = IEPA.download_dataset(data_dir) - train_data = IEPA.parse_dataset(corpus_folder / "IEPA-train.xml") - test_data = IEPA.parse_dataset(corpus_folder / "IEPA-test.xml") + def get_corpus_sentence_splitter(self) -> SentenceSplitter: + return NewlineSentenceSplitter(tokenizer=SciSpacyTokenizer()) - entity_type_mapping = {"Protein": GENE_TAG} + def to_internal(self, data_dir: Path) -> InternalBioNerDataset: + os.makedirs(str(data_dir), exist_ok=True) + IEPA.download_dataset(data_dir) - train_data = filter_and_map_entities(train_data, entity_type_mapping) - test_data = filter_and_map_entities(test_data, entity_type_mapping) + all_data = bioc_to_internal(data_dir / "iepa_bioc.xml") + all_data = filter_and_map_entities(all_data, {"Protein": GENE_TAG}) - return merge_datasets([train_data, test_data]) + return all_data @deprecated(version="0.13", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)") @@ -1805,7 +2210,7 @@ def __init__( @staticmethod def download_and_parse_dataset(data_dir: Path): - data_url = "https://sourceforge.net/projects/linnaeus/files/Corpora/manual-corpus-species-1.0.tar.gz" + data_url = "https://iweb.dl.sourceforge.net/project/linnaeus/Corpora/manual-corpus-species-1.0.tar.gz" data_path = cached_path(data_url, data_dir) unpack_file(data_path, data_dir) @@ -2014,7 +2419,7 @@ def __init__( @staticmethod def download_dataset(data_dir: Path): - data_url = "https://github.com/hu-ner/hunflair-corpora/raw/main/variome/hvp_bioc.xml.zip" + data_url = "http://corpora.informatik.hu-berlin.de/corpora/brat2bioc/hvp_bioc.xml.zip" data_path = cached_path(data_url, data_dir) unpack_file(data_path, data_dir) @@ -2517,11 +2922,11 @@ def __init__( @classmethod def download_dataset(cls, data_dir: Path) -> Path: - url = "https://github.com/hu-ner/hunflair-corpora/raw/main/osiris/OSIRIScorpusv02.tar" + url = "http://ibi.imim.es/OSIRIScorpusv02.tar" data_path = cached_path(url, data_dir) unpack_file(data_path, data_dir) - return data_dir + return data_dir / "OSIRIScorpusv02" @classmethod def parse_dataset(cls, corpus_folder: Path, fix_annotation=True): @@ -2577,7 +2982,7 @@ def split_url() -> str: def to_internal(self, data_dir: Path) -> InternalBioNerDataset: original_file = OSIRIS.download_dataset(data_dir) - corpus = OSIRIS.parse_dataset(original_file / "OSIRIScorpusv02") + corpus = 
OSIRIS.parse_dataset(original_file) entity_type_mapping = {"ge": GENE_TAG} return filter_and_map_entities(corpus, entity_type_mapping) @@ -3817,6 +4222,7 @@ def __init__( super().__init__(data_folder, columns, in_memory=in_memory) + @staticmethod @abstractmethod def download_corpus(data_folder: Path) -> Tuple[Path, Path, Path]: @@ -3900,19 +4306,35 @@ class BIONLP2013_CG(BioNLPCorpus): @staticmethod def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]: - url = "https://github.com/openbiocorpora/bionlp-st-2013-cg/archive/refs/heads/master.zip" + train_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_CG_training_data.tar.gz" + dev_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_CG_development_data.tar.gz" + test_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_CG_test_data.tar.gz" + + download_folder = download_folder / "original" - cached_path(url, download_folder) + cached_path(train_url, download_folder) + cached_path(dev_url, download_folder) + cached_path(test_url, download_folder) unpack_file( - download_folder / "master.zip", + download_folder / "BioNLP-ST_2013_CG_training_data.tar.gz", + download_folder, + keep=False, + ) + unpack_file( + download_folder / "BioNLP-ST_2013_CG_development_data.tar.gz", + download_folder, + keep=False, + ) + unpack_file( + download_folder / "BioNLP-ST_2013_CG_test_data.tar.gz", download_folder, keep=False, ) - train_folder = download_folder / "bionlp-st-2013-cg-master/original-data/train" - dev_folder = download_folder / "bionlp-st-2013-cg-master/original-data/devel" - test_folder = download_folder / "bionlp-st-2013-cg-master/original-data/test" + train_folder = download_folder / "BioNLP-ST_2013_CG_training_data" + dev_folder = download_folder / "BioNLP-ST_2013_CG_development_data" + test_folder = download_folder / "BioNLP-ST_2013_CG_test_data" return train_folder, dev_folder, test_folder @@ -3973,6 +4395,7 @@ def __init__( super().__init__(data_folder, columns, in_memory=in_memory) + @staticmethod @abstractmethod def download_corpus(data_folder: Path): @@ -4132,6 +4555,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "BC4CHEMD", data_folder, tag_type=CHEMICAL_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) @@ -4162,6 +4586,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "BC2GM").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "BC2GM", data_folder, tag_type=GENE_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) @@ -4192,9 +4617,11 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "JNLPBA").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "JNLPBA", data_folder, tag_type=GENE_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) + class BIOBERT_CHEMICAL_BC5CDR(ColumnCorpus): """BC5CDR corpus with chemical annotations as used in the evaluation of BioBERT. 
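All corpus classes touched in this diff rely on the same `cached_path` + `unpack_file` helpers from `flair.file_utils`, as `BIONLP2013_CG.download_corpus` above does three times. A minimal sketch of that pattern (URL and folder name are placeholders, not part of this PR):

```python
from pathlib import Path

import flair
from flair.file_utils import cached_path, unpack_file


def download_and_unpack(url: str, data_dir: Path) -> Path:
    archive = cached_path(url, data_dir)  # download once, reuse the cached copy afterwards
    unpack_file(archive, data_dir)        # extract next to the download
    return data_dir


# placeholder usage
download_and_unpack("https://example.org/my_corpus.tar.gz", flair.cache_root / "datasets" / "my_corpus")
```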
@@ -4222,9 +4649,11 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "BC5CDR-chem").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "BC5CDR-chem", data_folder, tag_type=CHEMICAL_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) + class BIOBERT_DISEASE_BC5CDR(ColumnCorpus): """BC5CDR corpus with disease annotations as used in the evaluation of BioBERT. @@ -4252,9 +4681,11 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "BC5CDR-disease").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "BC5CDR-disease", data_folder, tag_type=DISEASE_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) + class BIOBERT_DISEASE_NCBI(ColumnCorpus): """NCBI disease corpus as used in the evaluation of BioBERT. @@ -4282,9 +4713,11 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "NCBI-disease").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "NCBI-disease", data_folder, tag_type=DISEASE_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) + class BIOBERT_SPECIES_LINNAEUS(ColumnCorpus): """Linneaeus corpus with species annotations as used in the evaluation of BioBERT. @@ -4312,9 +4745,11 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "linnaeus").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "linnaeus", data_folder, tag_type=SPECIES_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) + class BIOBERT_SPECIES_S800(ColumnCorpus): """S800 corpus with species annotations as used in the evaluation of BioBERT. 
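The BioBERT evaluation corpora modified here remain plain `ColumnCorpus` objects after these whitespace-only changes, so downstream usage is unchanged; a short sketch (data is downloaded and converted on first instantiation):

```python
from flair.datasets import BIOBERT_DISEASE_NCBI

corpus = BIOBERT_DISEASE_NCBI()
print(corpus)           # train/dev/test sizes
print(corpus.train[0])  # first training sentence with its tags
```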
@@ -4342,6 +4777,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, in_memory: bool if not (common_path / "s800").exists(): BioBertHelper.download_corpora(common_path) BioBertHelper.convert_and_write(common_path / "s800", data_folder, tag_type=SPECIES_TAG) + super().__init__(data_folder, columns, in_memory=in_memory) @@ -4841,6 +5277,7 @@ def entity_type_predicate(member): corpus = constructor_func(sentence_splitter=sentence_splitter) self.huner_corpora.append(corpus) + except (CompressionError, ExtractError, HeaderError, ReadError, StreamError, TarError): logger.exception( f"Error while processing Tar file from corpus {name}:\n{sys.exc_info()[1]}\n\n", exc_info=False diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 2110e13bb..1f566554e 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast import torch @@ -29,7 +30,7 @@ class TransformerDocumentEmbeddings(DocumentEmbeddings, TransformerEmbeddings): def __init__( self, - model: str = "bert-base-uncased", # set parameters with different default values + model: Union[str, Path] = "bert-base-uncased", # set parameters with different default values layers: str = "-1", layer_mean: bool = False, is_token_embedding: bool = False, diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py new file mode 100644 index 000000000..78434c784 --- /dev/null +++ b/flair/models/biomedical_entity_linking.py @@ -0,0 +1,1264 @@ +import logging +import os +import re +import stat +import string +import subprocess +import tempfile +import time +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import Enum, auto +from pathlib import Path +from typing import Dict, Iterator, List, Optional, Tuple, Type, Union, cast + +import joblib +import numpy as np +import scipy +import torch +from huggingface_hub import hf_hub_download +from scipy.sparse import csr_matrix +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm + +import flair +from flair.data import EntityLinkingCandidate, EntityLinkingLabel, Label, Sentence, Span +from flair.datasets import ( + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, +) +from flair.datasets.biomedical import ( + AbstractBiomedicalEntityLinkingDictionary, + ParsedBiomedicalEntityLinkingDictionary, +) +from flair.embeddings import TransformerDocumentEmbeddings +from flair.file_utils import cached_path + +FAISS_VERSION = "1.7.4" + +try: + import faiss +except ImportError as error: + raise ImportError( + f"You need to install faiss to run the biomedical entity linking: `pip install faiss-cpu=={FAISS_VERSION}`" + ) from error + + +logger = logging.getLogger("flair") + + +PRETRAINED_DENSE_MODELS = [ + "cambridgeltl/SapBERT-from-PubMedBERT-fulltext", +] + +# Dense + sparse retrieval +PRETRAINED_HYBRID_MODELS = { + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-sapbert-ncbi-disease": "disease", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-biobert-ncbi-disease": "disease", + "dmis-lab/biosyn-biobert-bc5cdr-chemical": "chemical", + "dmis-lab/biosyn-biobert-bc2gn": "gene", + "dmis-lab/biosyn-sapbert-bc2gn": "gene", +} + 
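The module-level constants introduced above can be used to sanity-check a model choice before building a retriever, assuming they stay public; a sketch:

```python
from flair.models.biomedical_entity_linking import (
    PRETRAINED_DENSE_MODELS,
    PRETRAINED_HYBRID_MODELS,
)

model_name = "dmis-lab/biosyn-sapbert-bc5cdr-disease"

if model_name in PRETRAINED_HYBRID_MODELS:
    entity_type = PRETRAINED_HYBRID_MODELS[model_name]
    print(f"{model_name} ships a sparse encoder and targets '{entity_type}' mentions")
elif model_name in PRETRAINED_DENSE_MODELS:
    print(f"{model_name} is dense-only; a sparse encoder must be fitted on the dictionary")
else:
    print(f"{model_name} is not a known pretrained entity-linking model")
```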
+PRETRAINED_MODELS = list(PRETRAINED_HYBRID_MODELS) + PRETRAINED_DENSE_MODELS + +# just in case we add: fuzzy search, Levenstein, ... +STRING_MATCHING_MODELS = ["exact-string-match"] + +MODELS = PRETRAINED_MODELS + STRING_MATCHING_MODELS + +ENTITY_TYPES = ["disease", "chemical", "gene", "species"] + +ENTITY_TYPE_TO_LABELS = { + "disease": "diseases", + "gene": "genes", + "species": "species", + "chemical": "chemical", +} + +ENTITY_TYPE_TO_HYBRID_MODEL = { + "disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease", + "chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical", + "gene": "dmis-lab/biosyn-sapbert-bc2gn", +} + +# for now we always fall back to SapBERT, +# but we should train our own models at some point +ENTITY_TYPE_TO_DENSE_MODEL = { + entity_type: "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" for entity_type in ENTITY_TYPES +} + + +ENTITY_TYPE_TO_DICTIONARY = { + "gene": "ncbi-gene", + "species": "ncbi-taxonomy", + "disease": "ctd-diseases", + "chemical": "ctd-chemicals", +} + +ENTITY_TYPE_TO_ANNOTATION_LAYER = { + "disease": "diseases", + "gene": "genes", + "chemical": "chemicals", + "species": "species", +} + +BIOMEDICAL_DICTIONARIES: Dict[str, Type] = { + "ctd-diseases": CTD_DISEASES_DICTIONARY, + "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, + "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, + "ncbi-taxonomy": NCBI_TAXONOMY_DICTIONARY, +} + +MODEL_NAME_TO_DICTIONARY = { + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "ctd-disease", + "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-disease", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemical", + "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-disease", + "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-biobert-bc2gn": "ncbi-gene", + "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", +} + + +DEFAULT_SPARSE_WEIGHT = 0.5 + + +class SimilarityMetric(Enum): + """Similarity metrics""" + + INNER_PRODUCT = faiss.METRIC_INNER_PRODUCT + # L2 = faiss.METRIC_L2 + COSINE = auto() + + +def timeit(func): + """ + This function shows the execution time of the function object passed + """ + + def wrap_func(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + elapsed = round(time.time() - start, 4) + class_name, func_name = func.__qualname__.split(".") + logger.info("%s: %s took ~%s", class_name, func_name, elapsed) + return result + + return wrap_func + + +class AbstractEntityPreprocessor(ABC): + """ + A pre-processor used to transform / clean both entity mentions and entity names + This class provides the basic interface for such transformations + and must provide a `name` attribute to uniquely identify the type of preprocessing applied. + """ + + @property + @abstractmethod + def name(self) -> str: + """ + This is needed to correctly cache different multiple version of the dictionary + """ + + @abstractmethod + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: + """ + Processes the given entity mention and applies the transformation procedure to it. + + :param entity_mention: entity mention under investigation + :param sentence: sentence in which the entity mentioned occurred + :result: Cleaned / transformed string representation of the given entity mention + """ + + @abstractmethod + def process_entity_name(self, entity_name: str) -> str: + """ + Processes the given entity name (originating from a knowledge base / ontology) and + applies the transformation procedure to it. 
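The `timeit` helper defined above is a plain logging decorator; note that it assumes a `Class.method` qualified name when splitting `__qualname__`. A tiny usage sketch with a hypothetical class:

```python
import time

from flair.models.biomedical_entity_linking import timeit


class DictionaryIndexer:
    @timeit  # logs roughly: "DictionaryIndexer: build took ~0.1"
    def build(self):
        time.sleep(0.1)


DictionaryIndexer().build()
```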
+ + :param entity_name: entity mention given as DataPoint + :result: Cleaned / transformed string representation of the given entity mention + """ + + @abstractmethod + def initialize(self, sentences: List[Sentence]): + """ + Initializes the pre-processor for a batch of sentences, which is may be necessary for + more sophisticated transformations. + + :param sentences: List of sentences that will be processed. + """ + + +class EntityPreprocessor(AbstractEntityPreprocessor): + """ + Entity preprocessor adapted from: + Sung et al. 2020, Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + + The preprocessor provides basic string transformation options including lower-casing, + removal of punctuations symbols, etc. + """ + + def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): + """ + Initializes the mention preprocessor. + + :param lowercase: Indicates whether to perform lowercasing or not (True by default) + :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) + """ + self.lowercase = lowercase + self.remove_punctuation = remove_punctuation + self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) + + @property + def name(self): + return "biosyn" + + def initialize(self, sentences): + pass + + def process_entity_name(self, entity_name: str) -> str: + if self.lowercase: + entity_name = entity_name.lower() + + if self.remove_punctuation: + name_parts = self.rmv_puncts_regex.split(entity_name) + entity_name = " ".join(name_parts).strip() + + return entity_name.strip() + + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: + return self.process_entity_name(entity_mention.data_point.text) + + +class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): + """ + Entity preprocessor which uses Ab3P, an (biomedical) abbreviation definition detector: + Abbreviation definition identification based on automatic precision estimates. + Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. 
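For a quick feel of what the BioSyn-style `EntityPreprocessor` defined just above does to entity names, consider this sketch; the output comment is what the lower-casing plus punctuation splitting should produce:

```python
from flair.models.biomedical_entity_linking import EntityPreprocessor

preprocessor = EntityPreprocessor(lowercase=True, remove_punctuation=True)
print(preprocessor.process_entity_name("Non-Hodgkin's Lymphoma"))  # -> "non hodgkin s lymphoma"
```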
+ PubMed ID: 18817555 + https://github.com/ncbi-nlp/Ab3P + """ + + def __init__( + self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[AbstractEntityPreprocessor] = None + ) -> None: + """ + Creates the mention pre-processor + + :param ab3p_path: Path to the folder containing the Ab3P implementation + :param word_data_dir: Path to the word data directory + :param preprocessor: Basic entity preprocessor + """ + self.ab3p_path = ab3p_path + self.word_data_dir = word_data_dir + self.preprocessor = preprocessor + self.abbreviation_dict: Dict[str, Dict[str, str]] = {} + + @property + def name(self): + return f"ab3p_{self.preprocessor.name}" + + def initialize(self, sentences: List[Sentence]) -> None: + self.abbreviation_dict = self._build_abbreviation_dict(sentences) + + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: + sentence_text = sentence.to_tokenized_string().strip() + tokens = [token.text for token in cast(Span, entity_mention.data_point).tokens] + + parsed_tokens = [] + for token in tokens: + if self.preprocessor is not None: + token = self.preprocessor.process_entity_name(token) + + if sentence_text in self.abbreviation_dict: + if token.lower() in self.abbreviation_dict[sentence_text]: + parsed_tokens.append(self.abbreviation_dict[sentence_text][token.lower()]) + continue + + if len(token) != 0: + parsed_tokens.append(token) + + return " ".join(parsed_tokens) + + def process_entity_name(self, entity_name: str) -> str: + # Ab3P works on sentence-level and not on a single entity mention / name + # - so we just apply the wrapped text pre-processing here (if configured) + if self.preprocessor is not None: + return self.preprocessor.process_entity_name(entity_name) + + return entity_name + + @classmethod + def load(cls, ab3p_path: Path = None, preprocessor: Optional[AbstractEntityPreprocessor] = None): + data_dir = flair.cache_root / "ab3p" + if not data_dir.exists(): + data_dir.mkdir(parents=True) + + word_data_dir = data_dir / "word_data" + if not word_data_dir.exists(): + word_data_dir.mkdir() + + if ab3p_path is None: + ab3p_path = cls.download_ab3p(data_dir, word_data_dir) + + return cls(ab3p_path, word_data_dir, preprocessor) + + @classmethod + def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: + """Downloads the Ab3P tool and all necessary data files.""" + + # Download word data for Ab3P if not already downloaded + ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" + + ab3p_files = [ + "Ab3P_prec.dat", + "Lf1chSf", + "SingTermFreq.dat", + "cshset_wrdset3.ad", + "cshset_wrdset3.ct", + "cshset_wrdset3.ha", + "cshset_wrdset3.nm", + "cshset_wrdset3.str", + "hshset_Lf1chSf.ad", + "hshset_Lf1chSf.ha", + "hshset_Lf1chSf.nm", + "hshset_Lf1chSf.str", + "hshset_stop.ad", + "hshset_stop.ha", + "hshset_stop.nm", + "hshset_stop.str", + "stop", + ] + for file in ab3p_files: + cached_path(ab3p_url + file, word_data_dir) + + # Download Ab3P executable + ab3p_path = cached_path("https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir) + + # Make Ab3P executable + ab3p_path.chmod(ab3p_path.stat().st_mode | stat.S_IXUSR) + return ab3p_path + + def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: + """ + Processes the given sentences with the Ab3P tool. 
The function returns a (nested) dictionary + containing the abbreviations found for each sentence, e.g.: + + { + "Respiratory syncytial viruses ( RSV ) are a subgroup of the paramyxoviruses.": + {"RSV": "Respiratory syncytial viruses"}, + "Rous sarcoma virus ( RSV ) is a retrovirus.": + {"RSV": "Rous sarcoma virus"} + } + + :param sentences: list of sentences + :result abbreviation_dict: abbreviations and their resolution detected in each input sentence + """ + abbreviation_dict: Dict = defaultdict(dict) + + # Create a temp file which holds the sentences we want to process with Ab3P + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: + for sentence in sentences: + temp_file.write(sentence.to_tokenized_string() + "\n") + temp_file.flush() + + # Temporarily create path file in the current working directory for Ab3P + with open(os.path.join(os.getcwd(), "path_Ab3P"), "w") as path_file: + path_file.write(str(self.word_data_dir) + "/\n") + + # Run Ab3P with the temp file containing the dataset + # https://pylint.pycqa.org/en/latest/user_guide/messages/warning/subprocess-run-check.html + try: + result = subprocess.run( + [self.ab3p_path, temp_file.name], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + ) + except subprocess.CalledProcessError: + logger.error( + """The abbreviation resolver Ab3P could not be run on your system. To ensure maximum accuracy, please + install Ab3P yourself. See https://github.com/ncbi-nlp/Ab3P""" + ) + else: + line = result.stdout.decode("utf-8") + if "Path file for type cshset does not exist!" in line: + logger.error( + "Error when using Ab3P for abbreviation resolution. A file named path_Ab3p needs to exist in your current directory containing the path to the WordData directory for Ab3P to work!" + ) + elif "Cannot open" in line: + logger.error( + "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" + ) + elif "failed to open" in line: + logger.error( + "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" + ) + + lines = line.split("\n") + cur_sentence = None + for line in lines: + if len(line.split("|")) == 3: + if cur_sentence is None: + continue + + sf, lf, _ = line.split("|") + sf = sf.strip().lower() + lf = lf.strip().lower() + abbreviation_dict[cur_sentence][sf] = lf + + elif len(line.strip()) > 0: + cur_sentence = line + else: + cur_sentence = None + + finally: + # remove the path file + os.remove(os.path.join(os.getcwd(), "path_Ab3P")) + + return abbreviation_dict + + +class BiomedicalEntityLinkingDictionary: + """ + Class to load named entity dictionaries: either pre-defined or from a path on disk. + For the latter, every line in the file must be formatted as follows: + + concept_id||concept_name + + If multiple concept ids are associated to a given name they must be separated by a `|`, e.g. 
+ + 7157||TP53|tumor protein p53 + """ + + def __init__( + self, reader: Union[AbstractBiomedicalEntityLinkingDictionary, ParsedBiomedicalEntityLinkingDictionary] + ): + self.reader = reader + + @classmethod + def load( + cls, dictionary_name_or_path: Union[Path, str], database_name: Optional[str] = None + ) -> "BiomedicalEntityLinkingDictionary": + """Load dictionary: either pre-definded or from path""" + + if isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cast(str, dictionary_name_or_path) + + if ( + dictionary_name_or_path not in ENTITY_TYPE_TO_DICTIONARY + and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES + ): + raise ValueError( + f"Unkwnon dictionary `{dictionary_name_or_path}`!" + f" Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)}" + " If you want to pass a local path please use the `Path` class, " + "i.e. `model_name_or_path=Path(my_path)`" + ) + + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) + + reader = BIOMEDICAL_DICTIONARIES[str(dictionary_name_or_path)]() + + else: + # use custom dictionary file + assert ( + database_name is not None + ), "When providing a path to a custom dictionary you must specify the `database_name`!" + reader = ParsedBiomedicalEntityLinkingDictionary(path=dictionary_name_or_path, database_name=database_name) + + return cls(reader=reader) + + @property + def database_name(self) -> str: + """Database name of the dictionary""" + + return self.reader.database_name + + def stream(self) -> Iterator[Tuple[str, str]]: + """ + Stream entries from preprocessed dictionary + """ + + for entry in self.reader.stream(): + yield entry + + +class AbstractCandidateGenerator(ABC): + """ + Base class for a candidate generator, i.e. given a mention of an entity, find matching + entries from the dictionary. + """ + + @abstractmethod + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + """ + Returns the top-k entity / concept identifiers for each entity mention. 
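`BiomedicalEntityLinkingDictionary.load` accepts either one of the predefined names (an entity type such as `"disease"` or a dictionary key such as `"ctd-diseases"`) or a `Path` to a custom `concept_id||concept_name` file together with an explicit `database_name`. A usage sketch (the custom path is a placeholder; the predefined call downloads and caches the CTD dump on first use):

```python
from pathlib import Path

from flair.models.biomedical_entity_linking import BiomedicalEntityLinkingDictionary

# predefined dictionary, resolved via ENTITY_TYPE_TO_DICTIONARY / BIOMEDICAL_DICTIONARIES
disease_dict = BiomedicalEntityLinkingDictionary.load("disease")
print(disease_dict.database_name)  # "CTD-DISEASES"

# custom dictionary file: pass a Path and name the knowledge base it comes from
custom_dict = BiomedicalEntityLinkingDictionary.load(Path("/data/my_dictionary.txt"), database_name="MY-KB")
for name, cui in custom_dict.stream():
    ...
```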
+ + :param entity_mentions: Entity mentions + :param top_k: Number of best-matching entities from the knowledge base to return + :result: List containing a list of entity linking candidates per entity mention from the input + """ + + def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> EntityLinkingCandidate: + """Get nice container with all info about entity linking candidate""" + + concept_name = candidate[0] + concept_id = candidate[1] + score = candidate[2] + + if "|" in concept_id: + labels = concept_id.split("|") + concept_id = labels[0] + additional_labels = labels[1:] + else: + additional_labels = None + + return EntityLinkingCandidate( + concept_id=concept_id, + concept_name=concept_name, + score=score, + additional_ids=additional_labels, + database_name=database_name, + ) + + +class ExactMatchCandidateGenerator(AbstractCandidateGenerator): + """ + Candidate generator using exact string matching as search criterion + """ + + def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): + # Build index which maps concept / entity names to concept / entity ids + self.dictionary = dictionary + self.name_to_id_index = dict(list(dictionary.stream())) + + @classmethod + def load(cls, dictionary_name_or_path: Union[str, Path]) -> "ExactMatchCandidateGenerator": + """Compatibility function""" + return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) + + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + candidates: List[List[EntityLinkingCandidate]] = [] + for mention in entity_mentions: + dict_entry = self.name_to_id_index.get(mention) + if not dict_entry: + candidates.append([]) + continue + + candidates.append([ + self.build_candidate( + candidate=(mention, dict_entry, 1.0), + database_name=self.dictionary.database_name + ) + ]) + + return candidates + + +class BigramTfIDFVectorizer: + """ + Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level + Implementation adapted from: + + Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 + https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 + """ + + def __init__(self): + self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) + + def fit(self, names: List[str]): + """Learn vocabulary""" + self.encoder.fit(names) + return self + + def transform(self, names: Union[List[str], np.ndarray]) -> csr_matrix: + """Convert strings to sparse vectors""" + embeddings = self.encoder.transform(names) + return embeddings + + def __call__(self, mentions: Union[List[str], np.ndarray]) -> np.ndarray: + """Short for `transform`""" + return self.transform(mentions) + + def save(self, path: Path): + """Save vectorizer to disk""" + joblib.dump(self.encoder, str(path)) + + @classmethod + def load(cls, path: Union[Path, str]) -> "BigramTfIDFVectorizer": + """Instantiate from path""" + newVectorizer = cls() + + # with open(path, "rb") as fin: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + newVectorizer.encoder = joblib.load(str(path)) + # logger.info("Sparse encoder loaded from %s", path) + + return newVectorizer + + +class BiEncoderCandidateGenerator(AbstractCandidateGenerator): + """ + Candidate generator using both dense (transformer-based) and (optionally) sparse vector representations, + to search candidates in a knowledge base / dictionary. 
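`ExactMatchCandidateGenerator` is the simplest generator: mentions are looked up verbatim in the name-to-id index built from the dictionary, so queries should be normalized the same way as the dictionary names (which `stream` lower-cases). A sketch:

```python
from flair.models.biomedical_entity_linking import ExactMatchCandidateGenerator

generator = ExactMatchCandidateGenerator.load("ctd-diseases")  # downloads/caches the dictionary on first use

mentions = ["breast cancer", "some unseen mention"]
results = generator.search(entity_mentions=mentions, top_k=1)  # top_k is irrelevant for exact matching

for mention, candidates in zip(mentions, results):
    if candidates:
        best = candidates[0]
        print(mention, "->", f"{best.database_name}:{best.concept_id}", f"({best.concept_name})")
    else:
        print(mention, "-> no exact match")
```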
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: Union[str, Path],
+        dictionary_name_or_path: Union[str, Path],
+        similarity_metric: SimilarityMetric = SimilarityMetric.COSINE,
+        preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()),
+        max_length: int = 25,
+        batch_size: int = 1024,
+        hybrid_search: bool = False,
+        sparse_weight: Optional[float] = None,
+        force_hybrid_search: bool = False,
+        dictionary: Optional[BiomedicalEntityLinkingDictionary] = None,
+    ):
+        """
+        Initializes the BiEncoderCandidateGenerator.
+
+        :param model_name_or_path: Name of or path to the transformer model to be used
+        :param dictionary_name_or_path: Name of or path to the dictionary holding the entity names and identifiers
+        :param similarity_metric: Metric used to compare mention and candidate embeddings
+        :param preprocessor: Preprocessing strategy for entity mentions and names
+        :param max_length: Maximum number of input tokens to the transformer model
+        :param batch_size: Number of entity mentions/names to embed in one forward pass
+        :param hybrid_search: Whether to use sparse embeddings in addition to the dense ones
+        :param sparse_weight: Weight to balance sparse and dense similarity scores (if `None`, a default weight is used)
+        :param force_hybrid_search: If the pre-trained model is not hybrid (dense + sparse), fit a sparse encoder
+        :param dictionary: Optionally pass a custom dictionary
+        """
+        self.model_name_or_path = model_name_or_path
+        self.dictionary_name_or_path = dictionary_name_or_path
+        self.preprocessor = preprocessor
+        self.similarity_metric = similarity_metric
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self.hybrid_search = hybrid_search
+        self.sparse_weight = sparse_weight
+        self.force_hybrid_search = force_hybrid_search
+        if self.force_hybrid_search:
+            self.hybrid_search = True
+
+        # allow to pass custom dictionary
+        if dictionary is not None:
+            self.dictionary = dictionary
+        else:
+            self.dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)
+
+        self.dictionary_data: List[Tuple[str, str]] = [
+            (self.preprocessor.process_entity_name(name), cui) for name, cui in self.dictionary.stream()
+        ]
+
+        # Load encoders
+        self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False)
+        self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None
+
+        if self.hybrid_search:
+            self.sparse_encoder, self.sparse_weight = self._get_sparse_encoder_and_weight(
+                model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path
+            )
+
+        self.indices = self._load_indices(
+            model_name_or_path=model_name_or_path,
+            dictionary_name_or_path=dictionary_name_or_path,
+        )
+
+    @property
+    def higher_is_better(self):
+        """
+        Determine whether a higher similarity score means a better match.
+        E.g.
for L2 lower is better, while INNER_PRODUCT higher is better + """ + + return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] + + def _get_cache_name(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> str: + """Fixed name for caching""" + + # Check for embedded dictionary in cache + dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] + file_name = f"{str(model_name_or_path).split('/')[-1]}_{dictionary_name}" + pp_name = self.preprocessor.name if self.preprocessor is not None else "null" + + return f"{file_name}-{pp_name}" + + @timeit + def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: + """Fit sparse encoder to current dictionary""" + + logger.info( + "BiEncoderCandidateGenerator: hybrid model has no pretrained sparse encoder. Fit to dictionary `%s`", + self.dictionary_name_or_path, + ) + sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary_data]) + # sparse_encoder.save(Path(sparse_encoder_path)) + # torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) + + return sparse_encoder + + def _handle_sparse_encoder( + self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] + ) -> BigramTfIDFVectorizer: + """If necessary fit and cache sparse encoder""" + + if isinstance(model_name_or_path, str): + cache_name = self._get_cache_name( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + path = flair.cache_root / "models" / f"{cache_name}-sparse-encoder.pk" + else: + path = model_name_or_path / "sparse_encoder.pk" + + if path.exists(): + sparse_encoder = BigramTfIDFVectorizer.load(path) + else: + sparse_encoder = self.fit_sparse_encoder() + # logger.info("Save fitted sparse encoder to %s", path) + sparse_encoder.save(path) + + return sparse_encoder + + def _get_sparse_encoder_and_weight( + self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] + ) -> Tuple[BigramTfIDFVectorizer, float]: + sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") + sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") + + if isinstance(model_name_or_path, str) and model_name_or_path in PRETRAINED_HYBRID_MODELS: + model_name_or_path = cast(str, model_name_or_path) + + if not os.path.exists(sparse_encoder_path): + sparse_encoder_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_encoder.pk", + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) + + sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + + if not os.path.exists(sparse_weight_path): + sparse_weight_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_weight.pt", + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) + sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() + else: + sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + sparse_encoder = self._handle_sparse_encoder( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + + return sparse_encoder, sparse_weight + + def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: + """ + Create sparse embeddings from array of entity mentions/names. 
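+
+        Minimal sketch (assumes the hybrid / sparse setup is active, i.e. ``self.sparse_encoder`` is fitted)::
+
+            embeddings = self.embed_sparse(np.array(["tp53", "tumor protein p53"]))
+            # one character-bigram TF-IDF row per input name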
+ + :param inputs: Numpy array of entity / concept names + :returns Numpy array containing the sparse embeddings of the names + """ + if self.sparse_encoder is None: + raise AssertionError("Error while using the model") + + return self.sparse_encoder(inputs) + + def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: + """ + Create dense embeddings from array of entity mentions/names. + + :param names: Numpy array of entity / concept names + :param batch_size: Batch size used while embedding the name + :param show_progress: bool to toggle progress bar + :return: Numpy array containing the dense embeddings of the names + """ + self.dense_encoder.eval() # prevent dropout + + dense_embeds = [] + + with torch.no_grad(): + if show_progress: + iterations = tqdm( + range(0, len(inputs), batch_size), + desc=f"Embedding `{self.dictionary.database_name}`", + ) + else: + iterations = range(0, len(inputs), batch_size) + + for start in iterations: + # Create batch + end = min(start + batch_size, len(inputs)) + batch = [Sentence(name) for name in inputs[start:end]] + + # embed batch + self.dense_encoder.embed(batch) + + dense_embeds += [name.embedding.cpu().detach().numpy() for name in batch] + + if flair.device.type == "cuda": + torch.cuda.empty_cache() + + return np.array(dense_embeds) + + # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... + def get_dense_index(self, names: np.ndarray, path: Path) -> faiss.Index: + """Load or create dense index and save it to disk""" + + if path.exists(): + index = faiss.read_index(str(path)) + + else: + embeddings = self.embed_dense(inputs=np.array(names), batch_size=self.batch_size, show_progress=True) + + index = faiss.IndexFlatIP(embeddings.shape[1]) + index.add(embeddings) + + if self.similarity_metric == SimilarityMetric.COSINE: + faiss.normalize_L2(embeddings) + + faiss.write_index(index, str(path)) + + return index + + def get_sparse_index(self, names: np.ndarray, path: Path) -> csr_matrix: + """Load or create sparse index and save it to disk""" + + if path.exists(): + index = scipy.sparse.load_npz(str(path)) + else: + index = self.embed_sparse(inputs=names) + + scipy.sparse.save_npz(str(path), index) + # index.save_index # HNSWLIB + # index.save # ANNOY + + return index + + def _load_indices(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> Dict: + """Load cached indices if available, otherwise compute embeddings, build index and cache""" + + cache_name = self._get_cache_name( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + + cache_folder = flair.cache_root / "datasets" / cache_name + cache_folder.mkdir(parents=True, exist_ok=True) + + indices = {} + + logger.info( + "BiEncoderCandidateGenerator: initialize %s %s", + self.dictionary.database_name, + "indices" if self.hybrid_search else "index", + ) + + for index_type in ["sparse", "dense"]: + if index_type == "sparse" and not self.hybrid_search: + continue + + extension = "bin" if index_type == "dense" else "npz" + file_name = f"index-{index_type}.{extension}" + + index_cache_file = cache_folder / file_name + + names = np.array([n for n, _ in self.dictionary_data]) + + if index_type == "dense": + indices[index_type] = self.get_dense_index( + names=names, path=index_cache_file + ) + + else: + indices[index_type] = self.get_sparse_index( + names=names, path=index_cache_file + ) + + return indices + + @timeit + def 
search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: + """ + Find candidates with sparse representations + + :param entity_mentions: list of entity mentions (~ queries) + :param top_k: number of candidates to retrieve per mention + """ + assert ( + self.sparse_encoder is not None + ), "BiEncoderCandidateGenerator has no `sparse_encoder`! Pass `force_hybrid_search=True` at initialization" + + mention_embeddings = self.sparse_encoder(entity_mentions) + + if self.similarity_metric == SimilarityMetric.COSINE: + score_matrix = cosine_similarity(mention_embeddings, self.indices["sparse"], dense_output=False) + elif self.similarity_metric == SimilarityMetric.INNER_PRODUCT: + score_matrix = mention_embeddings.dot(self.indices["sparse"].T) + + score_matrix = score_matrix.toarray() + + num_mentions = score_matrix.shape[0] + + unsorted_indices = np.argpartition(score_matrix, -top_k)[:, -top_k:] + unsorted_scores = score_matrix[np.arange(num_mentions)[:, None], unsorted_indices] + + sorted_score_matrix_indices = np.argsort(-unsorted_scores) + + idxs = unsorted_indices[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] + dists = unsorted_scores[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] + + return idxs, dists + + @timeit + def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: + """ + Find candidates with dense representations (FAISS) + + :param entity_mentions: list of entity mentions (~ queries) + :param top_k: number of candidates to retrieve + """ + + # Compute dense embedding for the given entity mention + mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.batch_size) + + if self.similarity_metric == SimilarityMetric.COSINE: + faiss.normalize_L2(mention_dense_embeds) + + # Get candidates from dense embeddings + dists, ids = self.indices["dense"].search(mention_dense_embeds, top_k) + + return ids, dists + + def combine_dense_and_sparse_results( + self, + dense_ids: np.ndarray, + dense_scores: np.ndarray, + sparse_ids: np.ndarray, + sparse_scores: np.ndarray, + top_k: int = 1, + ): + """ + Expand dense results with sparse ones (that are not already in the dense) and re-weight the + score as: dense_score + sparse_weight * sparse_scores + """ + + hybrid_ids = [] + hybrid_scores = [] + for i in range(dense_ids.shape[0]): + mention_ids = dense_ids[i] + mention_scores = dense_scores[i] + + mention_spare_ids = sparse_ids[i] + mention_sparse_scores = sparse_scores[i] + + for sparse_id, sparse_score in zip(mention_spare_ids, mention_sparse_scores): + if sparse_id not in mention_ids: + mention_ids = np.append(mention_ids, sparse_id) + mention_scores = np.append(mention_scores, self.sparse_weight * sparse_score) + else: + index = np.where(mention_ids == sparse_id)[0][0] + mention_scores[index] += self.sparse_weight * sparse_score + + rerank_indices = np.argsort(-mention_scores if self.higher_is_better else mention_scores) + mention_ids = mention_ids[rerank_indices][:top_k] + mention_scores = mention_scores[rerank_indices][:top_k] + hybrid_ids.append(mention_ids.tolist()) + hybrid_scores.append(mention_scores.tolist()) + + return hybrid_scores, hybrid_ids + + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + """ + Returns the top-k entity / concept identifiers for each entity mention. 
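+
+        With hybrid search enabled, sparse results are merged into the dense ones and each candidate
+        is re-scored as ``dense_score + sparse_weight * sparse_score``
+        (see ``combine_dense_and_sparse_results``). Illustrative call, assuming ``generator`` is an
+        instance of this class::
+
+            candidates = generator.search(entity_mentions=["mercury"], top_k=3)
+            best = candidates[0][0]  # highest-ranked EntityLinkingCandidate for "mercury"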
+ + :param entity_mentions: Entity mentions + :param top_k: Number of best-matching entities from the knowledge base to return + :result: List containing a list of entity linking candidates per entity mention from the input + """ + + ids, scores = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) + + if self.hybrid_search and self.sparse_encoder is not None: + sparse_ids, sparse_scores = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k) + + scores, ids = self.combine_dense_and_sparse_results( + dense_ids=ids, + dense_scores=scores, + sparse_scores=sparse_scores, + sparse_ids=sparse_ids, + top_k=top_k, + ) + + return [ + [ + self.build_candidate( + candidate=self.dictionary_data[i] + (score,), database_name=self.dictionary.database_name + ) + for i, score in zip(mention_ids, mention_scores) + ] + for mention_ids, mention_scores in zip(ids, scores) + ] + + +class BiomedicalEntityLinker: + """Entity linking model for the biomedical domain""" + + def __init__( + self, + candidate_generator: AbstractCandidateGenerator, + preprocessor: AbstractEntityPreprocessor, + entity_type: str, + ): + self.preprocessor = preprocessor + self.candidate_generator = candidate_generator + self.entity_type = entity_type + self.annotation_layers = [ENTITY_TYPE_TO_ANNOTATION_LAYER.get(self.entity_type, "ner")] + + def extract_mentions( + self, + sentences: List[Sentence], + annotation_layers: Optional[List[str]] = None, + ) -> Tuple[List[int], List[Span], List[str], List[str]]: + """Unpack all mentions in sentences for batch search.""" + + source = [] + data_points = [] + mentions = [] + mention_annotation_layers = [] + + # use default annotation layers only if are not provided + annotation_layers = annotation_layers if annotation_layers is not None else self.annotation_layers + + for i, sentence in enumerate(sentences): + for annotation_layer in annotation_layers: + for entity in sentence.get_labels(annotation_layer): + source.append(i) + data_points.append(entity.data_point) + mentions.append( + self.preprocessor.process_mention(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text, + ) + mention_annotation_layers.append(annotation_layer) + + # assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" + + return source, data_points, mentions, mention_annotation_layers + + def predict( + self, + sentences: Union[List[Sentence], Sentence], + annotation_layers: Optional[List[str]] = None, + top_k: int = 1, + ) -> None: + """ + Predicts the best matching top-k entity / concept identifiers of all named entities annotated + with tag input_entity_annotation_layer. 
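+
+        Illustrative sketch (assumes the sentence already carries NER annotations, e.g. from the
+        ``hunflair`` tagger)::
+
+            linker = BiomedicalEntityLinker.load("disease")
+            linker.predict(sentence, top_k=5)
+            for label in sentence.get_labels():
+                print(label)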
+ + :param sentences: One or more sentences to run the prediction on + :param annotation_layers: List of annotation layers to extract entity mentions + :param top_k: Number of best-matching entity / concept identifiers + """ + # make sure sentences is a list of sentences + if not isinstance(sentences, list): + sentences = [sentences] + + if self.preprocessor is not None: + self.preprocessor.initialize(sentences) + + source, data_points, mentions, mentions_annotation_layers = self.extract_mentions( + sentences=sentences, annotation_layers=annotation_layers + ) + + # no mentions: nothing to do here + if len(mentions) > 0: + # Retrieve top-k concept / entity candidates + candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) + + # Add a label annotation for each candidate + for i, data_point, mention_candidates, mentions_annotation_layer in zip( + source, data_points, candidates, mentions_annotation_layers + ): + sentences[i].add_label( + typename=mentions_annotation_layer, + value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), + ) + + @classmethod + def load( + cls, + model_name_or_path: Union[str, Path], + dictionary_name_or_path: Optional[Union[str, Path]] = None, + hybrid_search: bool = True, + max_length: int = 25, + batch_size: int = 1024, + similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), + force_hybrid_search: bool = False, + sparse_weight: float = DEFAULT_SPARSE_WEIGHT, + entity_type: Optional[str] = None, + dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, + ) -> "BiomedicalEntityLinker": + """ + Loads a model for biomedical named entity normalization. + See __init__ method for detailed docstring on arguments + """ + if not isinstance(model_name_or_path, str): + raise AssertionError(f"String matching model name has to be an " + f"string (and not {type(model_name_or_path)}") + model_name_or_path = cast(str, model_name_or_path) + + if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cls.__get_dictionary_path( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + + if isinstance(model_name_or_path, str): + model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( + model_name_or_path=model_name_or_path, + entity_type=entity_type, + hybrid_search=hybrid_search, + force_hybrid_search=force_hybrid_search, + ) + else: + assert entity_type is not None, "When using a custom model you must specify `entity_type`" + assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! 
Must be one of: {ENTITY_TYPES}" + + if model_name_or_path == "exact-string-match": + candidate_generator: AbstractCandidateGenerator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) + else: + candidate_generator = BiEncoderCandidateGenerator( + model_name_or_path=model_name_or_path, + dictionary_name_or_path=dictionary_name_or_path, + hybrid_search=hybrid_search, + similarity_metric=similarity_metric, + max_length=max_length, + batch_size=batch_size, + sparse_weight=sparse_weight, + preprocessor=preprocessor, + force_hybrid_search=force_hybrid_search, + dictionary=dictionary, + ) + + logger.info( + "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", + dictionary_name_or_path, + entity_type + ) + + return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) + + @staticmethod + def __get_model_path_and_entity_type( + model_name_or_path: Union[str, Path], + entity_type: Optional[str] = None, + hybrid_search: bool = False, + force_hybrid_search: bool = False, + ) -> Tuple[Union[str, Path], str]: + """ + Try to figure out what model the user wants + """ + + if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: + raise ValueError( + f"Unknown model `{model_name_or_path}`!" + f" Available entity types are: {ENTITY_TYPES}" + " If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`" + ) + + if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext": + assert entity_type is not None, f"For model {model_name_or_path} you must specify `entity_type`" + + if hybrid_search: + # load model by entity_type + if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + model_name_or_path = cast(str, model_name_or_path) + + # check if we have a hybrid pre-trained model + if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: + entity_type = model_name_or_path + model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] + else: + # check if user really wants to use hybrid search anyway + if not force_hybrid_search: + logger.warning( + "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for" + " hybrid search: no sparse search will be performed." + " If you want to use sparse search please pass `force_hybrid_search=True`:" + " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", + model_name_or_path, + DEFAULT_SPARSE_WEIGHT, + ) + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + else: + if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: + logger.warning( + "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse" + " search will be performed." + " If you want to use sparse search please pass `force_hybrid_search=True`:" + " we will fit a sparse encoder for you. 
The default value of `sparse_weight` is `%s`.", + model_name_or_path, + DEFAULT_SPARSE_WEIGHT, + ) + + model_name_or_path = cast(str, model_name_or_path) + entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] + + else: + if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + model_name_or_path = cast(str, model_name_or_path) + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + + assert ( + entity_type is not None + ), f"Impossible to determine entity type for model `{model_name_or_path}`: please specify via `entity_type`" + + return model_name_or_path, entity_type + + @staticmethod + def __get_dictionary_path( + model_name_or_path: str, + dictionary_name_or_path: Optional[Union[str, Path]] = None, + ) -> Union[str, Path]: + """ + Try to figure out what dictionary (depending on the model) the user wants + """ + + if model_name_or_path in STRING_MATCHING_MODELS and dictionary_name_or_path is None: + raise ValueError( + "When using a string-matching candidate generator you must specify `dictionary_name_or_path`!" + ) + + if dictionary_name_or_path is not None and isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cast(str, dictionary_name_or_path) + + if dictionary_name_or_path in ENTITY_TYPES: + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[dictionary_name_or_path] + else: + if model_name_or_path in MODEL_NAME_TO_DICTIONARY: + dictionary_name_or_path = MODEL_NAME_TO_DICTIONARY[model_name_or_path] + elif model_name_or_path in ENTITY_TYPE_TO_DICTIONARY: + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[model_name_or_path] + else: + raise ValueError( + f"When using a custom model you need to specify a dictionary. Available options are: {ENTITY_TYPES}. Or provide a path to a dictionary file." + ) + + return dictionary_name_or_path diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md new file mode 100644 index 000000000..b7c814570 --- /dev/null +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -0,0 +1,59 @@ +# HunFlair Tutorial 3: Entity Linking + +After adding named entity recognition tags to your sentence, you can run named entity linking on these annotations. 
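+
+The linking models retrieve candidates from a dictionary of entity names and identifiers (e.g. CTD diseases /
+chemicals, NCBI gene / taxonomy). If you just want to inspect such a dictionary, you can load it by name and
+stream its `(name, identifier)` pairs; the snippet below is a minimal sketch using the pre-packaged `"disease"`
+dictionary:
+```python
+from flair.models.biomedical_entity_linking import BiomedicalEntityLinkingDictionary
+
+dictionary = BiomedicalEntityLinkingDictionary.load("disease")
+name, identifier = next(dictionary.stream())
+print(name, identifier)
+```
+The full pipeline, combining HunFlair NER with entity linking, looks like this: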
+```python
+from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
+from flair.nn import Classifier
+from flair.tokenization import SciSpacyTokenizer
+from flair.data import Sentence
+
+sentence = Sentence(
+    "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
+    "a neurodegenerative disease, which is exacerbated by exposure to high "
+    "levels of mercury in dolphin populations.",
+    use_tokenizer=SciSpacyTokenizer()
+)
+
+ner_tagger = Classifier.load("hunflair")
+ner_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("disease")
+nen_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("gene")
+nen_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("chemical")
+nen_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("species", entity_type="species")
+nen_tagger.predict(sentence)
+
+for tag in sentence.get_labels():
+    print(tag)
+```
+This should print:
+~~~
+Span[4:5]: "ABCD1" → Gene (0.9575)
+Span[4:5]: "ABCD1" → abcd1 - NCBI-GENE-HUMAN:215 (14.5503)
+Span[7:11]: "X-linked adrenoleukodystrophy" → Disease (0.9867)
+Span[7:11]: "X-linked adrenoleukodystrophy" → x linked adrenoleukodystrophy - CTD-DISEASES:MESH:D000326 (13.9717)
+Span[13:15]: "neurodegenerative disease" → Disease (0.8865)
+Span[13:15]: "neurodegenerative disease" → neurodegenerative disease - CTD-DISEASES:MESH:D019636 (14.2779)
+Span[25:26]: "mercury" → Chemical (0.9456)
+Span[25:26]: "mercury" → mercury - CTD-CHEMICALS:MESH:D008628 (14.9185)
+Span[27:28]: "dolphin" → Species (0.8082)
+Span[27:28]: "dolphin" → marine dolphins - NCBI-TAXONOMY:9726 (14.473)
+~~~
+The output contains both the NER annotations and their entity / concept identifiers according to
+a knowledge base or ontology. We have pre-configured combinations of models and dictionaries for
+"disease", "chemical", "gene" and "species" (for "species" you additionally need to pass
+`entity_type="species"`, as in the example above).
+
+You can also provide your own model and dictionary:
+```python
+from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
+
+nen_tagger = BiomedicalEntityLinker.load("name_or_path_to_your_model", dictionary_name_or_path="name_or_path_to_your_dictionary")
+nen_tagger = BiomedicalEntityLinker.load("path_to_custom_disease_model", dictionary_name_or_path="disease")
+```
+You can use any combination of the provided models and dictionaries as well as your own.
diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py
new file mode 100644
index 000000000..a285af5c7
--- /dev/null
+++ b/tests/test_biomedical_entity_linking.py
@@ -0,0 +1,66 @@
+# from flair.data import Sentence
+# from flair.models.biomedical_entity_linking import (
+#     BiomedicalEntityLinker,
+#     BiomedicalEntityLinkingDictionary,
+# )
+# from flair.nn import Classifier
+
+
+# def test_bel_dictionary():
+#     """
+#     Check data in dictionary is what we expect.
+#     Hard to define a good test as dictionaries are DYNAMIC,
+#     i.e.
they can change over time +# """ + +# dictionary = BiomedicalEntityLinkingDictionary.load("disease") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) + +# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-disease") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemical") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BiomedicalEntityLinkingDictionary.load("chemical") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BiomedicalEntityLinkingDictionary.load("species") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BiomedicalEntityLinkingDictionary.load("gene") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + + +# def test_biomedical_entity_linking(): + +# sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# tagger = Classifier.load("hunflair") +# tagger.predict(sentence) + +# disease_linker = BiomedicalEntityLinker.load("disease", hybrid_search=True) +# disease_linker.predict(sentence) + +# gene_linker = BiomedicalEntityLinker.load("gene", hybrid_search=False) + +# breakpoint() + + +# if __name__ == "__main__": +# # test_bel_dictionary() +# test_biomedical_entity_linking() diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py index c515e068c..fbff952ae 100644 --- a/tests/test_datasets_biomedical.py +++ b/tests/test_datasets_biomedical.py @@ -182,7 +182,7 @@ def assert_conll_writer_output( assert contents == expected_output -def test_filter_nested_entities(caplog): +def test_filter_nested_entities(recwarn): entities_per_document = { "d0": [Entity((0, 1), "t0"), Entity((2, 3), "t1")], "d1": [Entity((0, 6), "t0"), Entity((2, 3), "t1"), Entity((4, 5), "t2")], @@ -204,9 +204,11 @@ def test_filter_nested_entities(caplog): } dataset = InternalBioNerDataset(documents={}, entities_per_document=entities_per_document) - caplog.set_level(logging.WARNING) filter_nested_entities(dataset) - assert "WARNING: Corpus modified by filtering nested entities." in caplog.text + + assert len(recwarn.list) == 1 + assert isinstance(recwarn.list[0].message, UserWarning) + assert "Corpus modified by filtering nested entities." in recwarn.list[0].message.args[0] for key, entities in dataset.entities_per_document.items(): assert key in target