Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add biomedical entity normalization #3180

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
641a3c0
Initial version (already adapted to recent Flair API changes)
Mar 14, 2023
9779abf
Revise mention text pre-processing: define general interface and adap…
Mar 14, 2023
8da7d75
Refactor entity linking model structure
Mar 15, 2023
e34c831
Update documentation
Mar 22, 2023
f54925c
Introduce separate methods for pre-processing (1) entity mentions fro…
Mar 23, 2023
90a0acb
Merge branch 'master' into bio-entity-normalization
alanakbik Apr 21, 2023
f1f51fd
Fix formatting
alanakbik Apr 21, 2023
f2f21d3
feat(test): biomedical entity linking
Apr 26, 2023
82c1b8b
fix(requirements): add faiss
Apr 26, 2023
2e3cda3
fix(test): hold on w/ automatic tests for now
Apr 26, 2023
adb231e
fix(bionel): start major refactoring
Apr 26, 2023
c80f1be
fix(bionel): major refactor
Apr 27, 2023
d10d297
fix(bionel): assign entity type
May 2, 2023
25ba2dd
fix(biencoder): set sparse encoder and weight
May 2, 2023
4525d3b
fix(bionel): address comments
May 11, 2023
3a5913d
fix(candidate_generator): container for search result
May 12, 2023
734d895
fix(predict): default annotation layer iff not provided by use
May 19, 2023
d79f871
fix(label): scores can be >= or <=
May 19, 2023
118fb95
fix(candidate): parametrize database name
May 19, 2023
1fcfddf
feat(candidate_generator): cache sparse encoder
May 22, 2023
9322c1b
fix(candidate_generator): minor improvements
May 23, 2023
071f51e
feat(linking_candidate): pretty print
May 24, 2023
a23f360
fix(candidate_generator): check sparse encoder for sparse search
May 24, 2023
ce29290
chore: crystal clear dictionary name
Jun 1, 2023
0d65336
feat(candidate_generator): add sparse index
Jun 1, 2023
02812f0
fix(candidate_generator): KISS: sparse search w/ scipy sparse matrices
Jun 2, 2023
ca6eee8
Minor update to comments and documentation
Jul 12, 2023
6c8f219
Fix tests and type annotations
Jul 12, 2023
2fa43cc
Merge branch 'master' into bio-entity-normalization
Jul 12, 2023
d90d92d
Merge
Jul 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 101 additions & 5 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from collections import Counter, defaultdict, namedtuple
from operator import itemgetter
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union, cast
from typing import Dict, Iterable, List, Optional, Tuple, Union, cast

import torch
from deprecated import deprecated
from deprecated import deprecated # type: ignore
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data.dataset import ConcatDataset, Subset

Expand Down Expand Up @@ -326,11 +326,13 @@ def get_metadata(self, key: str) -> typing.Any:
def has_metadata(self, key: str) -> bool:
return key in self._metadata

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value_or_label: Union[str, Label], score: float = 1.0):
label = value_or_label if isinstance(value_or_label, Label) else Label(self, value_or_label, score)

if typename not in self.annotation_layers:
self.annotation_layers[typename] = [Label(self, value, score)]
self.annotation_layers[typename] = [label]
else:
self.annotation_layers[typename].append(Label(self, value, score))
self.annotation_layers[typename].append(label)

return self

Expand Down Expand Up @@ -421,6 +423,100 @@ def __len__(self) -> int:
raise NotImplementedError


class EntityLinkingCandidate:
"""Represent a single candidate returned by a CandidateGenerator"""

def __init__(
self,
concept_id: str,
concept_name: str,
database_name: str,
score: float = 1.0,
additional_ids: Optional[Union[List[str], str]] = None,
):
"""
:param concept_id: Identifier of the entity / concept from the knowledge base / ontology
:param concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology
:param score: Matching score of the entity / concept according to the entity mention
:param additional_ids: List of additional identifiers for the concept / entity in the KB / ontology
:param database_name: Name of the knowlege base / ontology
"""
self.concept_id = concept_id
self.concept_name = concept_name
self.database_name = database_name
self.score = score
self.additional_ids = additional_ids

def __str__(self) -> str:
string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}"
if self.additional_ids is not None:
string += f" - {self.additional_ids}"
return string

def __repr__(self) -> str:
return str(self)


class EntityLinkingLabel(Label):
"""
Label class models entity linking annotations. Each entity linking label has a data point it refers
to as well as the identifier and name of the concept / entity from a knowledge base or ontology.
Optionally, additional concepts identifier and the database name can be provided.
"""

def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidate]):
"""
Initializes the label instance.
:param data_point: Data point / span the label refers to
:param candidates: **sorted** list of candidates from candidate generator
"""

def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x >= y):
for i, el in enumerate(lst[1:]):
if comparison(key(el), key(lst[i])):
return False
return True

# candidates must be sorted, regardless if higher is better or not
assert is_sorted(candidates, key=lambda x: x.score) or is_sorted(
candidates, key=lambda x: x.score, comparison=lambda x, y: x <= y
), "List of candidates must be sorted!"

super().__init__(data_point, candidates[0].concept_id, candidates[0].score)
self.candidates = candidates
self.concept_name = self.candidates[0].concept_name
self.database_name = self.candidates[0].database_name

def __str__(self):
return (
f"{self.data_point.unlabeled_identifier}{flair._arrow} "
f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})"
)

def __repr__(self):
return (
f"{self.data_point.unlabeled_identifier}{flair._arrow} "
f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})"
)

def __len__(self):
return len(self.data_point)

def __eq__(self, other):
return (
self.value == other.value
and self.data_point == other.data_point
and self.concept_name == other.concept_name
and self.identifier == other.identifier
and self.database_name == other.database_name
and self.score == other.score
)

@property
def identifier(self):
return f"{self.value}"


DT = typing.TypeVar("DT", bound=DataPoint)
DT2 = typing.TypeVar("DT2", bound=DataPoint)

Expand Down
8 changes: 8 additions & 0 deletions flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
CLL,
CRAFT,
CRAFT_V4,
CTD_CHEMICALS_DICTIONARY,
CTD_DISEASES_DICTIONARY,
DECA,
FSU,
GELLUS,
Expand Down Expand Up @@ -90,6 +92,8 @@
LOCTEXT,
MIRNA,
NCBI_DISEASE,
NCBI_GENE_HUMAN_DICTIONARY,
NCBI_TAXONOMY_DICTIONARY,
OSIRIS,
PDR,
S800,
Expand Down Expand Up @@ -386,6 +390,10 @@
"LINNEAUS",
"LOCTEXT",
"MIRNA",
"NCBI_GENE_HUMAN_DICTIONARY",
"NCBI_TAXONOMY_DICTIONARY",
"CTD_DISEASES_DICTIONARY",
"CTD_CHEMICALS_DICTIONARY",
"NCBI_DISEASE",
"ONTONOTES",
"OSIRIS",
Expand Down
Loading