From c08a6d8de35bd7cb1279f8b8181a66818524739a Mon Sep 17 00:00:00 2001
From: Adam Sutton
Date: Mon, 4 Aug 2025 14:30:52 +0100
Subject: [PATCH 01/12] initial commit for embedding linker

---
 .../components/linking/embedding_linker.py | 286 ++++++++++++++++++
 1 file changed, 286 insertions(+)
 create mode 100644 medcat-v2/medcat/components/linking/embedding_linker.py

diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py
new file mode 100644
index 000000000..d35367286
--- /dev/null
+++ b/medcat-v2/medcat/components/linking/embedding_linker.py
@@ -0,0 +1,286 @@
+from medcat.cdb import CDB
+from medcat.config.config import Config, ComponentConfig
+from medcat.components.types import CoreComponentType, AbstractCoreComponent
+from medcat.tokenizing.tokens import MutableEntity, MutableDocument
+from medcat.tokenizing.tokenizers import BaseTokenizer
+from typing import Optional, Iterator, Any
+from medcat.vocab import Vocab
+from torch import Tensor
+from transformers import AutoTokenizer, AutoModel
+from medcat.utils.postprocessing import create_main_ann
+from tqdm import tqdm
+from medcat.tokenizing.spacy_impl.tokens import Entity
+import torch.nn.functional as F
+import torch
+import logging
+import numpy as np
+import math
+import copy
+logger = logging.getLogger(__name__)
+
+class Linker(AbstractCoreComponent):
+    name = "embedding_linker"
+    DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+
+    # NOTE: NEED TO IMPLEMENT
+    # the arguments provided to the init method in order
+    @classmethod
+    def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
+                      model_load_path: Optional[str]) -> list[Any]:
+        extra = cdb.config.components.linking.additional or {}
+        emb_name = extra.get("embedding_model_name", cls.DEFAULT_MODEL)
+        max_len = extra.get("max_length", 64)
+        return [
+            cdb,
+            cdb.config,
+            emb_name,
+            max_len
+        ]
+
+    # NOTE: NEED TO IMPLEMENT
+    # the keyword arguments to the init method
+    @classmethod
+    def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
+                        model_load_path: Optional[str]) -> dict[str, Any]:
+        extra = cdb.config.components.linking.additional or {}
+        emb_name = extra.get("embedding_model_name", cls.DEFAULT_MODEL)
+        max_len = extra.get("max_length", 64)
+        return {
+            "cdb": cdb,
+            "config": cdb.config,
+            "embedding_model_name": emb_name,
+            "max_length": max_len,
+        }
+
+    def __init__(self,
+                 cdb: CDB,
+                 config: Config,
+                 embedding_model_name: str = DEFAULT_MODEL,
+                 max_length = 64,) -> None:
+        """Initializes the embedding linker with a CDB and configuration.
+        Args:
+            cdb (CDB): The concept database to use.
+            config (Config): The base config.
+            embedding_model_name (Optional[str]): The name of the embedding model to use. Default is "sentence-transformers/all-MiniLM-L6-v2"
+            max_length (int): The maximum length of the input sequences for the embedding model. Default is 64.
+        """
+        self.cdb = cdb
+        self.config = config
+        self.max_length = max_length
+        self.embedding_model_name = embedding_model_name
+        extra = self.cdb.config.components.linking.additional or {}
+        extra.setdefault("embedding_model_name", self.DEFAULT_MODEL)
+        extra.setdefault("max_length", max_length)
+
+    def embed_names(self, embedding_model_name: str, batch_size: int = 4096) -> None:
+        """Obtain embeddings for all names in the CDB using the specified
+        embedding model and store them in the name2info.context_vectors
+        Args:
+            embedding_model_name (str): The name of the embedding model to use.
+            batch_size (int): The size of the batches to use when embedding names. Default 4096
+        """
+        if embedding_model_name == self.embedding_model_name:
+            logger.debug("Using the same embedding model for training.")
+        else:
+            self.embedding_model_name = embedding_model_name
+            self._load_transformers(embedding_model_name)
+        names = list(self.cdb.name2info.keys())
+        # embed each name in batches. Because there can be 3+ million names
+        total_batches = math.ceil(len(names) / batch_size)
+        for names in tqdm(self._batch_data(names, batch_size), total=total_batches + 1, desc="Embedding names"):
+            with torch.no_grad():
+                # removing ~ from names, as it is used to indicate a space in the CDB
+                names_to_embed = [name.replace("~", " ") for name in names]
+                batch_dict = self.tokenizer(names_to_embed, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device)
+                outputs = self.model(**batch_dict)
+                embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+                for name, embedding in zip(names, embeddings):
+                    name_info = self.cdb.name2info[name]
+                    name_info["context_vectors"] = embedding.cpu()
+                    self.cdb.name2info[name] = name_info
+        logger.debug("Embedding names done, total: %d", len(names))
+
+
+    def get_type(self) -> CoreComponentType:
+        return CoreComponentType.linking
+
+    def _batch_data(self, data, batch_size=4096) -> Iterator[list]:
+        for i in range(0, len(data), batch_size):
+            yield data[i:i + batch_size]
+
+    def _load_transformers(self,
+                           embedding_model_name: str = DEFAULT_MODEL) -> None:
+        """Load the transformers model and tokenizer.
+        No need to load a transformer model until it's required.
+        Args:
+            embedding_model_name (str): The name of the embedding model to load. Default is "sentence-transformers/all-MiniLM-L6-v2"
+        """
+        if not hasattr(self, "model") or not hasattr(self, "tokenizer") or embedding_model_name != self.embedding_model_name:
+            self.embedding_model_name = embedding_model_name
+            self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
+            self.model = AutoModel.from_pretrained(embedding_model_name)
+            self.model.eval()
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.model.to(self.device)
+            logger.debug(f"Loaded embedding model: {embedding_model_name} on device: {self.device}")
+
+    def _score_names(self, names_to_score: list[str]):
+        """Predicts the appropriate names for detected names comparing embedding similarity
+        Args:
+            name (str): The detected name to score.
+        Returns:
+            list[tuple[str, float, Tensor]]: A list of tuples containing the predicted name, similarity, and embedding."""
+        if not hasattr(self, "context_matrix"):
+            raise ValueError("Embeddings have not been initialised. Please run `cat._pipeline._components[-1].embed_names` first.")
+
+        with torch.no_grad():
+            batch_dict = self.tokenizer(names_to_score, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device)
+            outputs = self.model(**batch_dict)
+            detected_name_embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+            detected_name_embeddings = F.normalize(detected_name_embeddings, p=2, dim=1)
+            scores = detected_name_embeddings @ self.context_matrix.T
+
+            argmax_indices = torch.argmax(scores, dim=1)
+            predicted_names = [self.names[i] for i in argmax_indices]
+            max_similarities = scores.gather(1, argmax_indices.unsqueeze(1)).squeeze(1)
+            return [
+                (pred_name, score.item(), emb.cpu())
+                for pred_name, score, emb in zip(predicted_names, max_similarities, detected_name_embeddings)
+            ]
+
+    def _disambiguate_entity_by_vector_similarity(self,
+                                                  potential_name_cui_pairs: list[tuple[str, str]],
+                                                  detected_name_embedding: Tensor) -> str:
+        """Disambiguate entities based on vector similarity.
+        If there are multiple potential cuis, try to find the one with the highest similarity to the detected name.
+        Args:
+            name (str): The detected name.
+            potential_name_cui_pairs (list[tuple[str, str]]): List of tuples containing CUI and preferred name pairs.
+            name_embedding (Tensor): The embedding of the detected name.
+        Returns:
+            str: The CUI with the highest similarity to the detected name.
+        """
+        names = [name for _, name in potential_name_cui_pairs]
+        # we have to embed CUI preferred names because they might not exist
+        batch_dict = self.tokenizer(names, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device)
+        outputs = self.model(**batch_dict)
+        name_vectors = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+        name_vectors = F.normalize(name_vectors, p=2, dim=1).cpu()
+        scores = detected_name_embedding @ name_vectors.T
+        best_idx = torch.argmax(scores).item()
+        return potential_name_cui_pairs[best_idx][0]
+
+
+
+    def _disambiguate_entity(self,
+                             entity: MutableEntity,
+                             cuis: list[str],
+                             name_embedding: Tensor) -> str:
+        """Disambiguation where multiple cuis are linked to the same name.
+        Try to choose the best one based on cui2preffered names.
+        If there is multiple potential cuis still then try scoring the highest again
+        Args:
+            entity (MutableEntity): The entity to disambiguate.
+            cui (str): The CUI to disambiguate to.
+            embedding (Tensor): The embedding of the detected name.
+        Returns:
+            str: The disambiguated CUI.
+        """
+        # if theres only one CUI, just return it
+        if len(cuis) == 1:
+            return cuis[0]
+        # collect all preferred name / cui pairs first
+        potential_name_cui_pairs = [(cui, self.cdb.cui2info[cui]['preferred_name'].replace("~", " ")) for cui in cuis]
+        # if there are multiple, try to find all that matches the detected name
+        name = entity.detected_name or entity.base.text
+        name = name.replace("~", " ")
+        matching_cuis = [cui for cui, preferred_name in potential_name_cui_pairs if preferred_name.lower() == name.lower()]
+        # if there are multiple matching, just return the first one
+        # if there are mulitple preferred names then I'm not sure how to choose
+        if len(matching_cuis) == 1:
+            return matching_cuis[0][0]
+        else:
+            # no perfect names match, so disambiguate by vector similarities
+            return self._disambiguate_entity_by_vector_similarity(potential_name_cui_pairs, name_embedding)
+
+    def _process_entity_inference(
+        self,
+        entities: MutableEntity,
+    ) -> Iterator[MutableEntity]:
+        """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary"""
+        # I don't think we have to concern ourselves with link candidates from the NER step.
+        # Check does it have a detected name, if not just use the base text
+        names_to_score = [entity.detected_name or entity.base.text for entity in entities]
+        names_to_score = [name.replace("~", " ") for name in names_to_score]
+        results = self._score_names(names_to_score)
+
+        for entity, (predicted_name, similarity, embedding) in zip(entities, results):
+            # is there a better way to get cui2name mapping?
+            # this isn't a one to one mapping, so we just take the first one
+            predicted_cuis = list(self.cdb.name2info[predicted_name]["per_cui_status"].keys())
+            # filter out unwanted cuis
+            cnf_l = self.config.components.linking
+            predicted_cuis = [cui for cui in predicted_cuis if cnf_l.filters.check_filters(cui)]
+            # if there are no cuis, just skip the entity
+            if not predicted_cuis:
+                continue
+            predicted_cui = self._disambiguate_entity(entity, predicted_cuis, embedding)
+            entity.cui = predicted_cui
+            entity.context_similarity = similarity
+            yield entity
+
+    def _inference(self, doc: MutableDocument) -> Iterator[MutableEntity]:
+        # doing this here so it isn't done on each entity
+        self.names = list(self.cdb.name2info.keys())
+        self.context_matrix = torch.stack([self.cdb.name2info[name]["context_vectors"] for name in self.cdb.name2info]).to(self.device)
+        for entities in self._batch_data(doc.ner_ents):
+            logger.debug("Linker started with entities: %s", len(entities))
+            yield from self._process_entity_inference(entities)

+    def _check_similarity(self, cui: str, context_similarity: float) -> bool:
+        th_type = self.config.components.linking.similarity_threshold_type
+        threshold = self.config.components.linking.similarity_threshold
+        if th_type == 'static':
+            return context_similarity >= threshold
+        if th_type == 'dynamic':
+            conf = self.cdb.cui2info[cui]['average_confidence']
+            return context_similarity >= conf * threshold
+        return False
+
+    def _last_token_pool(self, last_hidden_states: Tensor,
+                         attention_mask: Tensor) -> Tensor:
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            return last_hidden_states[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_states.shape[0]
+            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+    def __call__(self, doc: MutableDocument) -> MutableDocument:
+        # Reset main entities, will be recreated later
+        doc.linked_ents.clear()
+        self._load_transformers(self.embedding_model_name)
+
+        cnf_l = self.config.components.linking
+
+        if cnf_l.train:
+            logger.warning("Attemping to train an embedding linker. This is not required.")
+        linked_entities = self._inference(doc)
+        # evaluating generator here because the `all_ents` list gets
+        # cleared afterwards otherwise
+        le = list(linked_entities)
+
+        doc.ner_ents.clear()
+        doc.ner_ents.extend(le)
+        create_main_ann(doc)
+
+        return doc
+
+    @classmethod
+    def create_new_component(
+            cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
+            cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
+    ) -> 'Linker':
+        return cls(cdb, cdb.config)
\ No newline at end of file

From 41ba8c0ea0bdb926d87d6ebbf8a75b6a898e263a Mon Sep 17 00:00:00 2001
From: Adam Sutton
Date: Mon, 18 Aug 2025 00:27:03 +0100
Subject: [PATCH 02/12] update to embedding logic and additional configurations

---
 .../components/linking/embedding_linker.py | 523 ++++++++++++------
 medcat-v2/medcat/config/config.py          |  29 +
 2 files changed, 381 insertions(+), 171 deletions(-)

diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py
index d35367286..a5a6559dd 100644
--- a/medcat-v2/medcat/components/linking/embedding_linker.py
+++ b/medcat-v2/medcat/components/linking/embedding_linker.py
@@ -1,62 +1,27 @@
 from medcat.cdb import CDB
-from medcat.config.config import Config, ComponentConfig
+from medcat.config.config import Config, ComponentConfig, EmbeddingLinking
 from medcat.components.types import CoreComponentType, AbstractCoreComponent
-from medcat.tokenizing.tokens import MutableEntity, MutableDocument
+from medcat.tokenizing.tokens import MutableEntity, MutableDocument, MutableToken
 from medcat.tokenizing.tokenizers import BaseTokenizer
-from typing import Optional, Iterator, Any
+from typing import Optional, Iterator, cast, Iterable
 from medcat.vocab import Vocab
 from torch import Tensor
 from transformers import AutoTokenizer, AutoModel
 from medcat.utils.postprocessing import create_main_ann
 from tqdm import tqdm
-from medcat.tokenizing.spacy_impl.tokens import Entity
+from collections import defaultdict
 import torch.nn.functional as F
 import torch
 import logging
-import numpy as np
 import math
-import copy
 logger = logging.getLogger(__name__)
 
 class Linker(AbstractCoreComponent):
     name = "embedding_linker"
-    DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-
-    # NOTE: NEED TO IMPLEMENT
-    # the arguments provided to the init method in order
-    @classmethod
-    def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
-                      model_load_path: Optional[str]) -> list[Any]:
-        extra = cdb.config.components.linking.additional or {}
-        emb_name = extra.get("embedding_model_name", cls.DEFAULT_MODEL)
-        max_len = extra.get("max_length", 64)
-        return [
-            cdb,
-            cdb.config,
-            emb_name,
-            max_len
-        ]
-
-    # NOTE: NEED TO IMPLEMENT
-    # the keyword arguments to the init method
-    @classmethod
-    def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
-                        model_load_path: Optional[str]) -> dict[str, Any]:
-        extra = cdb.config.components.linking.additional or {}
-        emb_name = extra.get("embedding_model_name", cls.DEFAULT_MODEL)
-        max_len = extra.get("max_length", 64)
-        return {
-            "cdb": cdb,
-            "config": cdb.config,
-            "embedding_model_name": emb_name,
-            "max_length": max_len,
-        }
 
     def __init__(self,
                  cdb: CDB,
-                 config: Config,
-                 embedding_model_name: str = DEFAULT_MODEL,
-                 max_length 
= 64,) -> None: + config: Config) -> None: """Initializes the embedding linker with a CDB and configuration. Args: cdb (CDB): The concept database to use. @@ -66,13 +31,68 @@ def __init__(self, """ self.cdb = cdb self.config = config - self.max_length = max_length - self.embedding_model_name = embedding_model_name - extra = self.cdb.config.components.linking.additional or {} - extra.setdefault("embedding_model_name", self.DEFAULT_MODEL) - extra.setdefault("max_length", max_length) + if not isinstance(config.components.linking, EmbeddingLinking): + raise TypeError("Linking config must be an EmbeddingLinking instance") + self.cnf_l: EmbeddingLinking = config.components.linking + self.max_length = self.cnf_l.max_token_length + self.embedding_model_name = self.cnf_l.embedding_model_name + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def embed_names(self, embedding_model_name: str, batch_size: int = 4096) -> None: + # these only need to be populated when called for embedding or inference + self._name_keys = None + self._cui_keys = None + self._names_context_matrix = None + self._cui_context_matrix = None + + # used for filters, and if the name contains a valid cui see: _set_filters + self._last_include_set = None + self._last_exclude_set = None + self._allowed_mask = None + self._name_has_allowed_cui = None + + self._cui_to_idx = {cui: idx for idx, cui in enumerate(self.cui_keys)} + self._name_to_idx = {name: idx for idx, name in enumerate(self.name_keys)} + self._name_to_cui_idxs = [ + [ self._cui_to_idx[cui] + for cui in self.cdb.name2info[name].get("per_cui_status", {}).keys() + if cui in self._cui_to_idx ] + for name in self._name_keys + ] + + def embed_cui_names(self, + embedding_model_name: str, + ) -> None: + """Obtain embeddings for all prefered_names in the CDB using the specified + embedding model and store them in the name2info.context_vectors + Args: + embedding_model_name (str): The name of the embedding model to use. + batch_size (int): The size of the batches to use when embedding names. Default 4096 + """ + if embedding_model_name == self.embedding_model_name and "cui_embeddings" in self.cdb.addl_info: + logger.warning("Using the same model for embedding names.") + else: + self.embedding_model_name = embedding_model_name + self._load_transformers(embedding_model_name) + # use the preferred name, if not take the longest name + # cui_names = [self.cdb.cui2info[cui]["preferred_name"] for cui in self._name_keys] + cui_names = [max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys] + # embed each name in batches. 
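# (A minimal, self-contained sketch of the batching pattern used throughout
# this file; `batch_data` here mirrors the `_batch_data` helper, and the
# sample data is made up for illustration.)
from typing import Iterator

def batch_data(data: list, batch_size: int = 4096) -> Iterator[list]:
    # Yield successive slices so millions of names are never embedded at once.
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

assert list(batch_data(list(range(10)), 4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]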
Because there can be 3+ million names + total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size) + all_embeddings = [] + for names in tqdm(self._batch_data(cui_names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding cuis' preferred names"): + with torch.no_grad(): + # removing ~ from names, as it is used to indicate a space in the CDB + names_to_embed = [name.replace("~", " ") for name in names] + embeddings= self._embed(names_to_embed, self.device) + all_embeddings.append(embeddings.cpu()) + # cat all batches into one tensor + all_embeddings = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["cui_embeddings"] = all_embeddings + self.cdb.addl_info["cui_to_idx"] = {cui: idx for idx, cui in enumerate(self._cui_keys)} + logger.debug("Embedding cui names done, total: %d", len(names)) + + def embed_names(self, + embedding_model_name: str) -> None: """Obtain embeddings for all names in the CDB using the specified embedding model and store them in the name2info.context_vectors Args: @@ -80,167 +100,289 @@ def embed_names(self, embedding_model_name: str, batch_size: int = 4096) -> None batch_size (int): The size of the batches to use when embedding names. Default 4096 """ if embedding_model_name == self.embedding_model_name: - logger.debug("Using the same embedding model for training.") + logger.debug("Using the same model for embedding names.") else: self.embedding_model_name = embedding_model_name self._load_transformers(embedding_model_name) names = list(self.cdb.name2info.keys()) # embed each name in batches. Because there can be 3+ million names - total_batches = math.ceil(len(names) / batch_size) - for names in tqdm(self._batch_data(names, batch_size), total=total_batches + 1, desc="Embedding names"): + total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) + all_embeddings = [] + for names in tqdm(self._batch_data(names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding names"): with torch.no_grad(): # removing ~ from names, as it is used to indicate a space in the CDB names_to_embed = [name.replace("~", " ") for name in names] - batch_dict = self.tokenizer(names_to_embed, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device) - outputs = self.model(**batch_dict) - embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) - embeddings = F.normalize(embeddings, p=2, dim=1) - for name, embedding in zip(names, embeddings): - name_info = self.cdb.name2info[name] - name_info["context_vectors"] = embedding.cpu() - self.cdb.name2info[name] = name_info + embeddings = self._embed(names_to_embed, self.device) + all_embeddings.append(embeddings.cpu()) + all_embeddings = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["name_embeddings"] = all_embeddings + self.cdb.addl_info["name_to_idx"] = {name: idx for idx, name in enumerate(self.name_keys)} logger.debug("Embedding names done, total: %d", len(names)) def get_type(self) -> CoreComponentType: return CoreComponentType.linking - def _batch_data(self, data, batch_size=4096) -> Iterator[list]: + def _batch_data(self, data, batch_size=512) -> Iterator[list]: for i in range(0, len(data), batch_size): yield data[i:i + batch_size] def _load_transformers(self, - embedding_model_name: str = DEFAULT_MODEL) -> None: + embedding_model_name) -> None: """Load the transformers model and tokenizer. No need to load a transformer model until it's required. 
Args: embedding_model_name (str): The name of the embedding model to load. Default is "sentence-transformers/all-MiniLM-L6-v2" """ - if not hasattr(self, "model") or not hasattr(self, "tokenizer") or embedding_model_name != self.embedding_model_name: - self.embedding_model_name = embedding_model_name + if not hasattr(self, "model") or not hasattr(self, "tokenizer") or embedding_model_name != self.cnf_l.embedding_model_name: + self.cnf_l.embedding_model_name = embedding_model_name self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name) self.model = AutoModel.from_pretrained(embedding_model_name) self.model.eval() - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + gpu_device = self.cnf_l.gpu_device + self.device = torch.device(gpu_device or ("cuda" if torch.cuda.is_available() else "cpu")) self.model.to(self.device) logger.debug(f"Loaded embedding model: {embedding_model_name} on device: {self.device}") - - def _score_names(self, names_to_score: list[str]): - """Predicts the appropriate names for detected names comparing embedding similarity - Args: - name (str): The detected name to score. - Returns: - list[tuple[str, float, Tensor]]: A list of tuples containing the predicted name, similarity, and embedding.""" - if not hasattr(self, "context_matrix"): - raise ValueError("Embeddings have not been initialised. Please run `cat._pipeline._components[-1].embed_names` first.") - - with torch.no_grad(): - batch_dict = self.tokenizer(names_to_score, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device) - outputs = self.model(**batch_dict) - detected_name_embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) - detected_name_embeddings = F.normalize(detected_name_embeddings, p=2, dim=1) - scores = detected_name_embeddings @ self.context_matrix.T - - argmax_indices = torch.argmax(scores, dim=1) - predicted_names = [self.names[i] for i in argmax_indices] - max_similarities = scores.gather(1, argmax_indices.unsqueeze(1)).squeeze(1) - return [ - (pred_name, score.item(), emb.cpu()) - for pred_name, score, emb in zip(predicted_names, max_similarities, detected_name_embeddings) - ] - def _disambiguate_entity_by_vector_similarity(self, - potential_name_cui_pairs: list[tuple[str, str]], - detected_name_embedding: Tensor) -> str: - """Disambiguate entities based on vector similarity. - If there are multiple potential cuis, try to find the one with the highest similarity to the detected name. + def _embed(self, + to_embed: list[str], + device) -> Tensor: + """Embeds a list of strings + """ + batch_dict = self.tokenizer(to_embed, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(device) + outputs = self.model(**batch_dict) + outputs = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) + outputs = F.normalize(outputs, p=2, dim=1) + return outputs.half() + + def _get_context_tokens(self, + entity: MutableEntity, + doc: MutableDocument, + size: int + ) -> tuple[list[MutableToken], + list[MutableToken], + list[MutableToken]]: + """Get context tokens for an entity + Args: - name (str): The detected name. - potential_name_cui_pairs (list[tuple[str, str]]): List of tuples containing CUI and preferred name pairs. - name_embedding (Tensor): The embedding of the detected name. + entity (BaseEntity): The entity to look for. + doc (BaseDocument): The document look in. + size (int): The size of the entity. 
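# (A toy sketch of the windowing done by `_get_context_tokens`: it reduces to
# plain list slicing around the entity span; the tokens below are made up.)
tokens = ["pt", "has", "severe", "chest", "pain", "today"]
start_ind, end_ind, size = 3, 4, 2                   # entity: chest pain
left = tokens[max(0, start_ind - size):start_ind]    # ['has', 'severe']
center = tokens[start_ind:end_ind + 1]               # ['chest', 'pain']
right = tokens[end_ind + 1:end_ind + 1 + size]       # ['today']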
+ Returns: - str: The CUI with the highest similarity to the detected name. + tuple[list[BaseToken], list[BaseToken], list[BaseToken]]: + The tokens on the left, centre, and right. """ - names = [name for _, name in potential_name_cui_pairs] - # we have to embed CUI preferred names because they might not exist - batch_dict = self.tokenizer(names, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(self.device) - outputs = self.model(**batch_dict) - name_vectors = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) - name_vectors = F.normalize(name_vectors, p=2, dim=1).cpu() - scores = detected_name_embedding @ name_vectors.T - best_idx = torch.argmax(scores).item() - return potential_name_cui_pairs[best_idx][0] + start_ind = entity.base.start_index + end_ind = entity.base.end_index + _left_tokens = doc[max(0, start_ind - size):start_ind] + tokens_left = [tkn for tkn in _left_tokens] + tokens_center: list[MutableToken] = list( + cast(Iterable[MutableToken], entity)) + _right_tokens = doc[end_ind + 1:end_ind + 1 + size] + tokens_right = [tkn for tkn in _right_tokens] + return tokens_left, tokens_center, tokens_right + + def _get_context_vectors(self, + doc: MutableDocument, + entities: list[MutableEntity], + size: int) -> Tensor: + """Get context vectors for all detected concepts based on their raw text or detected names. + + Args: + doc (BaseDocument): The document look in. + size (int): The size of the entity. + Returns: + tuple[list[BaseToken], list[BaseToken], list[BaseToken]]: + The tokens on the left, centre, and right.""" + texts = [] + for entity in entities: + tokens_left, tokens_center, tokens_right = self._get_context_tokens(entity, doc, size) + tokens = tokens_left + tokens_center + tokens_right + text = " ".join(token.base.text for token in tokens) + texts.append(text) + return self._embed(texts, self.device) - def _disambiguate_entity(self, - entity: MutableEntity, - cuis: list[str], - name_embedding: Tensor) -> str: - """Disambiguation where multiple cuis are linked to the same name. - Try to choose the best one based on cui2preffered names. - If there is multiple potential cuis still then try scoring the highest again + def _set_filters(self) -> None: + include_set = self.cnf_l.filters.cuis + exclude_set = self.cnf_l.filters.cuis_exclude + + # Check if sets changed (avoid recomputation if same) + if (include_set == self._last_include_set and + exclude_set == self._last_exclude_set): + return + + n = len(self._name_keys) + allowed_mask = torch.empty(n, dtype=torch.bool, device=self.device) + + if include_set: + # if in include set, ignore exclude set. + allowed_mask[:] = False + include_cui_idxs = {self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx} + include_idxs = [ + name_idx + for name_idx, name_cui_idxs in enumerate(self._name_to_cui_idxs) + if any(cui in include_cui_idxs for cui in name_cui_idxs) + ] + allowed_mask[torch.tensor(include_idxs, dtype=torch.long, device=self.device)] = True + else: + # only look at exclude if there's no include set + allowed_mask[:] = True + if exclude_set: + exclude_cui_idxs = {self._cui_to_idx[cui] for cui in exclude_set if cui in self._cui_to_idx} + exclude_idxs = [i for i, name_cui_idxs in enumerate(self._name_to_cui_idxs) if any(ci in exclude_cui_idxs for ci in name_cui_idxs)] + allowed_mask[torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)] = False + + # checking if a name has at least 1 cui related to it. Might as well do this cheeck here. 
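# (A toy sketch of the include/exclude masking built in `_set_filters`,
# assuming small made-up data; the real code does the same at CDB scale.)
import torch

name_to_cui_idxs = [[0], [0, 1], [2]]    # candidate CUI indices per name
include_cui_idxs = {1}                   # e.g. derived from filters.cuis

allowed_mask = torch.zeros(len(name_to_cui_idxs), dtype=torch.bool)
for name_idx, cui_idxs in enumerate(name_to_cui_idxs):
    allowed_mask[name_idx] = any(ci in include_cui_idxs for ci in cui_idxs)
# allowed_mask -> tensor([False,  True, False])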
+ _has_cuis_all = torch.tensor( + [bool(self.cdb.name2info[name]["per_cui_status"]) for name in self.name_keys], + device=self.device + ) + self._valid_names = (_has_cuis_all & allowed_mask) + self._last_include_set = include_set + self._last_exclude_set = exclude_set + + def _disambiguate_by_cui(self, + cui_candidates: list[str], + scores: Tensor): + """Disambiguate a detected concept by a list of potential cuis Args: - entity (MutableEntity): The entity to disambiguate. - cui (str): The CUI to disambiguate to. - embedding (Tensor): The embedding of the detected name. + cuis (list[str]): Potential cuis + cui_to_idx (dict[str, int]): Mapping of cui to relevant idx position + scores (Tensor): Scores for the detected cui2info concepts similarity + cui_keys (list[str]): idx_to_cui inverse Returns: - str: The disambiguated CUI. + tuple[str, int]: + The CUI and its similarity """ - # if theres only one CUI, just return it - if len(cuis) == 1: - return cuis[0] - # collect all preferred name / cui pairs first - potential_name_cui_pairs = [(cui, self.cdb.cui2info[cui]['preferred_name'].replace("~", " ")) for cui in cuis] - # if there are multiple, try to find all that matches the detected name - name = entity.detected_name or entity.base.text - name = name.replace("~", " ") - matching_cuis = [cui for cui, preferred_name in potential_name_cui_pairs if preferred_name.lower() == name.lower()] - # if there are multiple matching, just return the first one - # if there are mulitple preferred names then I'm not sure how to choose - if len(matching_cuis) == 1: - return matching_cuis[0][0] - else: - # no perfect names match, so disambiguate by vector similarities - return self._disambiguate_entity_by_vector_similarity(potential_name_cui_pairs, name_embedding) - - def _process_entity_inference( - self, - entities: MutableEntity, - ) -> Iterator[MutableEntity]: - """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary""" - # I don't think we have to concern ourselves with link candidates from the NER step. - # Check does it have a detected name, if not just use the base text - names_to_score = [entity.detected_name or entity.base.text for entity in entities] - names_to_score = [name.replace("~", " ") for name in names_to_score] - results = self._score_names(names_to_score) - - for entity, (predicted_name, similarity, embedding) in zip(entities, results): - # is there a better way to get cui2name mapping? - # this isn't a one to one mapping, so we just take the first one - predicted_cuis = list(self.cdb.name2info[predicted_name]["per_cui_status"].keys()) - # filter out unwanted cuis - cnf_l = self.config.components.linking - predicted_cuis = [cui for cui in predicted_cuis if cnf_l.filters.check_filters(cui)] - # if there are no cuis, just skip the entity - if not predicted_cuis: - continue - predicted_cui = self._disambiguate_entity(entity, predicted_cuis, embedding) + cui_idxs = [self._cui_to_idx[cui] for cui in cui_candidates] + candidate_scores = scores[cui_idxs] + candidate_idx = torch.argmax(candidate_scores).item() + best_idx = cui_idxs[candidate_idx] + + predicted_cui = self._cui_keys[best_idx] + similarity = candidate_scores[candidate_idx].item() + return predicted_cui, similarity + + def _inference_by_names( + self, + doc: MutableDocument, + entities: list[MutableEntity]) -> Iterator[MutableEntity]: + """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary. + Args: + doc (BaseDocument): The document look in. 
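# (Because every embedding is L2-normalised, the `a @ b.T` products used in
# this linker are cosine similarities; a self-contained sketch with random
# vectors standing in for real context/name embeddings.)
import torch
import torch.nn.functional as F

queries = F.normalize(torch.randn(2, 384), dim=1)   # detected-context vectors
names = F.normalize(torch.randn(5, 384), dim=1)     # CDB name vectors
scores = queries @ names.T                          # (2, 5) cosines in [-1, 1]
best = torch.argmax(scores, dim=1)                  # best name per entity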
+ name_keys (list[str]): list of all names2info + cui_keys (list[str]): list of all cuis2info + context_matrix: Tensor of context matrix we're planning to use could be all names from name2info, + or prefered names from cui2info[cui]["preferred_name"] + Yields: + entity (MutableEntity): Entity with a relevant cui prediction - or skip if it's not suitable.""" + detected_context_vectors = self._get_context_vectors(doc, entities, self.cnf_l.context_window_size) + + # score all detected contexts vs all names, handle in the loop each individual case + names_scores = detected_context_vectors @ self.names_context_matrix.T + cui_scores = detected_context_vectors @ self.cui_context_matrix.T + sorted_indices = torch.argsort(names_scores, dim=1, descending=True) + + for i, entity in enumerate(entities): + link_candidates = [cui for cui in entity.link_candidates if self.cnf_l.filters.check_filters(cui)] + if self.cnf_l.use_ner_link_candidates and len(link_candidates) == 1: + best_idx = self._cui_to_idx[link_candidates[0]] + predicted_cui = link_candidates[0] + similarity = names_scores[i, best_idx].item() + elif self.cnf_l.use_ner_link_candidates and len(link_candidates) > 1: + name_to_cuis = defaultdict(list) + for cui in link_candidates: + for name in self.cdb.cui2info[cui]["names"]: + name_to_cuis[name].append(cui) + + name_idxs = [self._name_to_idx[name] for name in name_to_cuis] + indexed_scores = names_scores[i, name_idxs] + + best_local_pos = torch.argmax(indexed_scores).item() + best_global_idx = name_idxs[best_local_pos] + similarity = names_scores[i, best_global_idx].item() + best_name = self.name_keys[best_global_idx] + best_cuis = name_to_cuis[best_name] + if (len(best_cuis) == 1): + predicted_cui = best_cuis[0] + else: + predicted_cui, _ = self._disambiguate_by_cui( + best_cuis, + cui_scores[i,:] + ) + else: + row_sorted = sorted_indices[i] # sorted candidate indices for entity i + + # Find the first candidate in this row with CUIs + first_true_pos = torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][0].item() + + # Get global index + name + top_name_idx = row_sorted[first_true_pos].item() + similarity = names_scores[i, top_name_idx].item() + detected_name = self.name_keys[top_name_idx] + cuis = self.cdb.name2info[detected_name]["per_cui_status"] + + # Disambiguate by CUI + predicted_cui, _ = self._disambiguate_by_cui( + cuis, cui_scores[i,:] + ) + entity.cui = predicted_cui entity.context_similarity = similarity + yield entity + + def _inference_by_cui( + self, + doc: MutableDocument, + entities: list[MutableEntity] + ) -> Iterator[MutableEntity]: + """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary. + Args: + doc (BaseDocument): The document look in. 
+ name_keys (list[str]): list of all names2info + cui_keys (list[str]): list of all cuis2info + context_matrix: Tensor of context matrix we're planning to use + embedded from names in cui2info[cui]["preferred_name"] + Yields: + entity (MutableEntity): Entity with a relevant cui prediction - or skip if it's not suitable.""" + # 14 is a nice average between contexts in the context based linker + detected_context_vectors = self._get_context_vectors(doc, entities, self.cnf_l.context_window_size) + cui_to_idx = {cui: idx for idx, cui in enumerate(self.cui_keys)} + # score all detected contexts vs all cui preferred names, handle in the loop each individual case + scores = detected_context_vectors @ self.cui_context_matrix.T + sorted_indices = torch.argsort(scores, dim=1, descending=True) + for i, entity in enumerate(entities): + # might as well filter here rather than later + link_candidates = [cui for cui in entity.link_candidates if self.cnf_l.filters.check_filters(cui)] + if self.cnf_l.use_ner_link_candidates and len(link_candidates) == 1: + best_idx = cui_to_idx[link_candidates[0]] + entity.cui = link_candidates[0] + + similarity = scores[i, best_idx].item() + entity.context_similarity = similarity + elif self.cnf_l.use_ner_link_candidates and len(link_candidates) > 1: + predicted_cui, similarity = self._disambiguate_by_cui( + link_candidates, + scores[i,:] + ) + entity.cui = predicted_cui + entity.context_similarity = similarity + else: + # no link candidates -> i.e. filtered or none from NER + # therefore: score vs all cui preffered names! + top_cui_idx = sorted_indices[i, 0].item() + entity.cui = self.cui_keys[top_cui_idx] + entity.context_similarity = scores[i, top_cui_idx].item() - def _inference(self, doc: MutableDocument) -> Iterator[MutableEntity]: - # doing this here so it isn't done on each entity - self.names = list(self.cdb.name2info.keys()) - self.context_matrix = torch.stack([self.cdb.name2info[name]["context_vectors"] for name in self.cdb.name2info]).to(self.device) - for entities in self._batch_data(doc.ner_ents): - logger.debug("Linker started with entities: %s", len(entities)) - yield from self._process_entity_inference(entities) + yield entity def _check_similarity(self, cui: str, context_similarity: float) -> bool: - th_type = self.config.components.linking.similarity_threshold_type - threshold = self.config.components.linking.similarity_threshold + th_type = self.cnf_l.similarity_threshold_type + threshold = self.cnf_l.similarity_threshold if th_type == 'static': return context_similarity >= threshold if th_type == 'dynamic': @@ -257,26 +399,65 @@ def _last_token_pool(self, last_hidden_states: Tensor, sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = last_hidden_states.shape[0] return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + def _build_context_matrices(self): + self._name_keys = list(self.cdb.name2info) + self._cui_keys = list(self.cdb.cui2info) + + if "name_embeddings" in self.cdb.addl_info: + self._names_context_matrix = self.cdb.addl_info["name_embeddings"].half().to(self.device) + if "cui_embeddings" in self.cdb.addl_info: + self._cui_context_matrix = self.cdb.addl_info["cui_embeddings"].half().to(self.device) def __call__(self, doc: MutableDocument) -> MutableDocument: # Reset main entities, will be recreated later doc.linked_ents.clear() + self._load_transformers(self.embedding_model_name) - - cnf_l = self.config.components.linking - - if cnf_l.train: + if self.cnf_l.train: 
logger.warning("Attemping to train an embedding linker. This is not required.") - linked_entities = self._inference(doc) - # evaluating generator here because the `all_ents` list gets - # cleared afterwards otherwise - le = list(linked_entities) + + inference = self._inference_by_cui + if self.cnf_l.linking_strategy == "names": + inference = self._inference_by_names + # filters are only done this way when infering by names + self._set_filters() + + all_ents = doc.ner_ents + le = [] + with torch.no_grad(): + for entities in self._batch_data(all_ents, self.cnf_l.linking_batch_size): + le.extend(list(inference(doc, entities))) doc.ner_ents.clear() doc.ner_ents.extend(le) create_main_ann(doc) return doc + + @property + def name_keys(self): + if self._name_keys is None: + self._build_context_matrices() + return self._name_keys + + @property + def cui_keys(self): + if self._cui_keys is None: + self._build_context_matrices() + return self._cui_keys + + @property + def names_context_matrix(self): + if self._names_context_matrix is None: + self._build_context_matrices() + return self._names_context_matrix + + @property + def cui_context_matrix(self): + if self._cui_context_matrix is None: + self._build_context_matrices() + return self._cui_context_matrix @classmethod def create_new_component( diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 6f21bad0c..d9fbf4a9d 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -376,6 +376,35 @@ class Linking(ComponentConfig): class Config: extra = 'allow' +class EmbeddingLinking(Linking): + + """The embedding linker never needs to be trained in its + current implementation.""" + train: bool = False + """Similarity between context bert-like vector and names or + cui preferred names""" + similarity_threshold: float = 0.25 + """Name of the embedding model. It must be downloadable from + huggingface linked from an appropriate file directory""" + embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2" + """Max number of tokens to be embedded from a name.""" + max_token_length: int = 64 + """How many pieces names can be embedded at once, useful when + embedding name2info names, cui2info names""" + embedding_batch_size: int = 4096 + """How many entities to be linked at once""" + linking_batch_size: int = 512 + """Choose the linking method, via all names or a single name + representing a cui. Defaults to cuis if this is changed""" + linking_strategy: str = "names" + """Choose a device for the linking model to be stored. If None + then an appropriate GPU device that is available will be chosen""" + gpu_device: Optional[Any] = None + """Choose the window size to get context vectors.""" + context_window_size: int = 11 + """Link candidates are provided by some NER steps. 
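# (A hypothetical sketch of tuning these fields on a loaded model pack; the
# `cat` handle is an assumption, only the field names come from this class.)
# cnf_l = cat.config.components.linking      # an EmbeddingLinking instance
# cnf_l.similarity_threshold = 0.4           # reject weaker links
# cnf_l.context_window_size = 11             # tokens either side of the span
# cnf_l.use_ner_link_candidates = True       # trust candidates from NER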
This will flag if + you want to trust them or not.""" + use_ner_link_candidates: bool = True class Preprocessing(SerialisableBaseModel): """The preprocessing part of the config""" From 3b0827689742e1da4ad4015330a5874173271059 Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Thu, 21 Aug 2025 23:18:39 +0100 Subject: [PATCH 03/12] handling no link candidates along with fixes --- .../components/linking/embedding_linker.py | 234 +++++++++--------- medcat-v2/medcat/config/config.py | 6 +- 2 files changed, 118 insertions(+), 122 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index a5a6559dd..ea986d593 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -38,20 +38,21 @@ def __init__(self, self.embedding_model_name = self.cnf_l.embedding_model_name self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._name_keys = list(self.cdb.name2info) + self._cui_keys = list(self.cdb.cui2info) + # these only need to be populated when called for embedding or inference - self._name_keys = None - self._cui_keys = None self._names_context_matrix = None self._cui_context_matrix = None - # used for filters, and if the name contains a valid cui see: _set_filters + # used for filters and name embedding, and if the name contains a valid cui see: _set_filters self._last_include_set = None self._last_exclude_set = None self._allowed_mask = None self._name_has_allowed_cui = None - self._cui_to_idx = {cui: idx for idx, cui in enumerate(self.cui_keys)} - self._name_to_idx = {name: idx for idx, name in enumerate(self.name_keys)} + self._cui_to_idx = {cui: idx for idx, cui in enumerate(self._cui_keys)} + self._name_to_idx = {name: idx for idx, name in enumerate(self._name_keys)} self._name_to_cui_idxs = [ [ self._cui_to_idx[cui] for cui in self.cdb.name2info[name].get("per_cui_status", {}).keys() @@ -73,8 +74,7 @@ def embed_cui_names(self, else: self.embedding_model_name = embedding_model_name self._load_transformers(embedding_model_name) - # use the preferred name, if not take the longest name - # cui_names = [self.cdb.cui2info[cui]["preferred_name"] for cui in self._name_keys] + # Use the longest name cui_names = [max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys] # embed each name in batches. 
Because there can be 3+ million names total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size) @@ -88,7 +88,6 @@ def embed_cui_names(self, # cat all batches into one tensor all_embeddings = torch.cat(all_embeddings, dim=0) self.cdb.addl_info["cui_embeddings"] = all_embeddings - self.cdb.addl_info["cui_to_idx"] = {cui: idx for idx, cui in enumerate(self._cui_keys)} logger.debug("Embedding cui names done, total: %d", len(names)) def embed_names(self, @@ -111,12 +110,11 @@ def embed_names(self, for names in tqdm(self._batch_data(names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding names"): with torch.no_grad(): # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [name.replace("~", " ") for name in names] + names_to_embed = [name.replace(self.config.general.separator, " ") for name in names] embeddings = self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) all_embeddings = torch.cat(all_embeddings, dim=0) self.cdb.addl_info["name_embeddings"] = all_embeddings - self.cdb.addl_info["name_to_idx"] = {name: idx for idx, name in enumerate(self.name_keys)} logger.debug("Embedding names done, total: %d", len(names)) @@ -155,13 +153,11 @@ def _embed(self, outputs = F.normalize(outputs, p=2, dim=1) return outputs.half() - def _get_context_tokens(self, + def _get_context(self, entity: MutableEntity, doc: MutableDocument, size: int - ) -> tuple[list[MutableToken], - list[MutableToken], - list[MutableToken]]: + ) -> str: """Get context tokens for an entity Args: @@ -176,14 +172,14 @@ def _get_context_tokens(self, start_ind = entity.base.start_index end_ind = entity.base.end_index - _left_tokens = doc[max(0, start_ind - size):start_ind] - tokens_left = [tkn for tkn in _left_tokens] - tokens_center: list[MutableToken] = list( - cast(Iterable[MutableToken], entity)) - _right_tokens = doc[end_ind + 1:end_ind + 1 + size] - tokens_right = [tkn for tkn in _right_tokens] - - return tokens_left, tokens_center, tokens_right + left_most_token = doc[max(0, start_ind - size)] + left_index = left_most_token.base.char_index + + right_most_token = doc[min(len(doc) - 1, end_ind + size)] + right_index = right_most_token.base.char_index + len(right_most_token.base.text) + + snippet = doc.base.text[left_index:right_index] + return snippet def _get_context_vectors(self, doc: MutableDocument, @@ -199,9 +195,7 @@ def _get_context_vectors(self, The tokens on the left, centre, and right.""" texts = [] for entity in entities: - tokens_left, tokens_center, tokens_right = self._get_context_tokens(entity, doc, size) - tokens = tokens_left + tokens_center + tokens_right - text = " ".join(token.base.text for token in tokens) + text = self._get_context(entity, doc, size) texts.append(text) return self._embed(texts, self.device) @@ -237,7 +231,7 @@ def _set_filters(self) -> None: # checking if a name has at least 1 cui related to it. Might as well do this cheeck here. 
_has_cuis_all = torch.tensor( - [bool(self.cdb.name2info[name]["per_cui_status"]) for name in self.name_keys], + [bool(self.cdb.name2info[name]["per_cui_status"]) for name in self._name_keys], device=self.device ) self._valid_names = (_has_cuis_all & allowed_mask) @@ -266,7 +260,7 @@ def _disambiguate_by_cui(self, similarity = candidate_scores[candidate_idx].item() return predicted_cui, similarity - def _inference_by_names( + def _inference( self, doc: MutableDocument, entities: list[MutableEntity]) -> Iterator[MutableEntity]: @@ -288,11 +282,11 @@ def _inference_by_names( for i, entity in enumerate(entities): link_candidates = [cui for cui in entity.link_candidates if self.cnf_l.filters.check_filters(cui)] - if self.cnf_l.use_ner_link_candidates and len(link_candidates) == 1: + if len(link_candidates) == 1: best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] similarity = names_scores[i, best_idx].item() - elif self.cnf_l.use_ner_link_candidates and len(link_candidates) > 1: + elif len(link_candidates) > 1: name_to_cuis = defaultdict(list) for cui in link_candidates: for name in self.cdb.cui2info[cui]["names"]: @@ -304,13 +298,13 @@ def _inference_by_names( best_local_pos = torch.argmax(indexed_scores).item() best_global_idx = name_idxs[best_local_pos] similarity = names_scores[i, best_global_idx].item() - best_name = self.name_keys[best_global_idx] - best_cuis = name_to_cuis[best_name] - if (len(best_cuis) == 1): - predicted_cui = best_cuis[0] + best_name = self._name_keys[best_global_idx] + cuis = name_to_cuis[best_name] + if (len(cuis) == 1): + predicted_cui = cuis[0] else: predicted_cui, _ = self._disambiguate_by_cui( - best_cuis, + cuis, cui_scores[i,:] ) else: @@ -322,73 +316,25 @@ def _inference_by_names( # Get global index + name top_name_idx = row_sorted[first_true_pos].item() similarity = names_scores[i, top_name_idx].item() - detected_name = self.name_keys[top_name_idx] - cuis = self.cdb.name2info[detected_name]["per_cui_status"] + detected_name = self._name_keys[top_name_idx] + cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) - # Disambiguate by CUI predicted_cui, _ = self._disambiguate_by_cui( - cuis, cui_scores[i,:] + cuis, + cui_scores[i,:] ) - entity.cui = predicted_cui - entity.context_similarity = similarity - - yield entity - - def _inference_by_cui( - self, - doc: MutableDocument, - entities: list[MutableEntity] - ) -> Iterator[MutableEntity]: - """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary. - Args: - doc (BaseDocument): The document look in. 
- name_keys (list[str]): list of all names2info - cui_keys (list[str]): list of all cuis2info - context_matrix: Tensor of context matrix we're planning to use - embedded from names in cui2info[cui]["preferred_name"] - Yields: - entity (MutableEntity): Entity with a relevant cui prediction - or skip if it's not suitable.""" - # 14 is a nice average between contexts in the context based linker - detected_context_vectors = self._get_context_vectors(doc, entities, self.cnf_l.context_window_size) - cui_to_idx = {cui: idx for idx, cui in enumerate(self.cui_keys)} - # score all detected contexts vs all cui preferred names, handle in the loop each individual case - scores = detected_context_vectors @ self.cui_context_matrix.T - sorted_indices = torch.argsort(scores, dim=1, descending=True) - for i, entity in enumerate(entities): - # might as well filter here rather than later - link_candidates = [cui for cui in entity.link_candidates if self.cnf_l.filters.check_filters(cui)] - if self.cnf_l.use_ner_link_candidates and len(link_candidates) == 1: - best_idx = cui_to_idx[link_candidates[0]] - entity.cui = link_candidates[0] - - similarity = scores[i, best_idx].item() - entity.context_similarity = similarity - elif self.cnf_l.use_ner_link_candidates and len(link_candidates) > 1: - predicted_cui, similarity = self._disambiguate_by_cui( - link_candidates, - scores[i,:] - ) - entity.cui = predicted_cui + if self.cnf_l.use_similarity_threshold and self._check_similarity(similarity): + entity.cui = predicted_cui entity.context_similarity = similarity - else: - # no link candidates -> i.e. filtered or none from NER - # therefore: score vs all cui preffered names! - top_cui_idx = sorted_indices[i, 0].item() - entity.cui = self.cui_keys[top_cui_idx] - entity.context_similarity = scores[i, top_cui_idx].item() - - yield entity + yield entity - def _check_similarity(self, cui: str, context_similarity: float) -> bool: - th_type = self.cnf_l.similarity_threshold_type - threshold = self.cnf_l.similarity_threshold - if th_type == 'static': + def _check_similarity(self, context_similarity: float) -> bool: + if self.cnf_l.use_similarity_threshold: + threshold = self.cnf_l.similarity_threshold return context_similarity >= threshold - if th_type == 'dynamic': - conf = self.cdb.cui2info[cui]['average_confidence'] - return context_similarity >= conf * threshold - return False + else: + return True def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: @@ -400,14 +346,79 @@ def _last_token_pool(self, last_hidden_states: Tensor, batch_size = last_hidden_states.shape[0] return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] - def _build_context_matrices(self): - self._name_keys = list(self.cdb.name2info) - self._cui_keys = list(self.cdb.cui2info) - + def _build_context_matrices(self) -> None: if "name_embeddings" in self.cdb.addl_info: self._names_context_matrix = self.cdb.addl_info["name_embeddings"].half().to(self.device) if "cui_embeddings" in self.cdb.addl_info: self._cui_context_matrix = self.cdb.addl_info["cui_embeddings"].half().to(self.device) + + + def _generate_link_candidates(self, + doc: MutableDocument, + entities: list[MutableEntity] + ) -> None: + """Generate link candidates for each detected entity based on context vectors with size 0. 
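# (A sketch of the argsort-then-first-valid scan used for candidate
# generation, with toy scores; `valid` plays the role of `self._valid_names`.)
import torch

scores = torch.tensor([[0.2, 0.9, 0.5]])
valid = torch.tensor([True, False, True])   # names passing the CUI filters
order = torch.argsort(scores, dim=1, descending=True)[0]       # [1, 2, 0]
first_valid = order[valid[order].nonzero(as_tuple=True)[0][0]]
# first_valid -> 2: index 1 scores highest but is filtered out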
+ Compare to names to get the most similar name in the cdb to the detected concept.""" + detected_context_vectors = self._get_context_vectors(doc, entities, 0) + + # score all detected contexts vs all names, handle in the loop each individual case + names_scores = detected_context_vectors @ self.names_context_matrix.T + sorted_indices = torch.argsort(names_scores, dim=1, descending=True) + + for i, entity in enumerate(entities): + row_sorted = sorted_indices[i] # sorted candidate indices for entity i + + valid_mask = self._valid_names[row_sorted] + # TODO: potentially choose multiple names that are all within a certain range of the top scoring. + # for now just choose the highest scoring name + valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0][:1] + + cuis = set() + for pos in valid_positions.tolist(): + top_name_idx = row_sorted[pos].item() + detected_name = self._name_keys[top_name_idx] + cuis.update(self.cdb.name2info[detected_name]["per_cui_status"].keys()) + + entity.link_candidates = list(cuis) + + + def _pre_inference(self, + doc: MutableDocument) -> tuple[list, list]: + """Checking all entities for entites with only a single link candidate and to avoid full inference step. + If we want to calculate similarities, or not use link candidates then just return the entities""" + all_ents = doc.ner_ents + if not self.cnf_l.use_ner_link_candidates: + to_generate_link_candidates = all_ents + else: + to_generate_link_candidates = [entity for entity in all_ents if not entity.link_candidates] + + # generate our own link candidates if it's required, or wanted + for entities in self._batch_data(to_generate_link_candidates, self.cnf_l.linking_batch_size): + self._generate_link_candidates(doc, entities) + + if self.cnf_l.always_calculate_similarity: + return [], all_ents + + le = [] + to_infer = [] + for entity in all_ents: + if len(entity.link_candidates) == 1: + # if the include filter exists and the only cui is in it + if self.cnf_l.filters.cuis and entity.link_candidates[0] in self.cnf_l.filters.cuis: + entity.cui = entity.link_candidates[0] + entity.context_similarity = 1 + le.append(entity) + continue + # if only the exclude filter exists and the only cui is NOT in it + elif self.cnf_l.filters.cuis_exclude and entity.link_candidates[0] not in self.cnf_l.filters.cuis_exclude: + entity.cui = entity.link_candidates[0] + entity.context_similarity = 1 + le.append(entity) + continue + # if it has to be inferred due to filters or number of link candidates then add it to the infer list + to_infer.append(entity) + return le, to_infer + def __call__(self, doc: MutableDocument) -> MutableDocument: # Reset main entities, will be recreated later @@ -416,36 +427,21 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: self._load_transformers(self.embedding_model_name) if self.cnf_l.train: logger.warning("Attemping to train an embedding linker. This is not required.") + if self.cnf_l.filters.cuis and self.cnf_l.filters.cuis_exclude: + logger.warning("You have both include and exclude filters for CUIs set. 
This will result in only include CUIs being filtered.") - inference = self._inference_by_cui - if self.cnf_l.linking_strategy == "names": - inference = self._inference_by_names - # filters are only done this way when infering by names - self._set_filters() + self._set_filters() - all_ents = doc.ner_ents - le = [] with torch.no_grad(): - for entities in self._batch_data(all_ents, self.cnf_l.linking_batch_size): - le.extend(list(inference(doc, entities))) + le, to_infer = self._pre_inference(doc) + for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size): + le.extend(list(self._inference(doc, entities))) doc.ner_ents.clear() doc.ner_ents.extend(le) create_main_ann(doc) return doc - - @property - def name_keys(self): - if self._name_keys is None: - self._build_context_matrices() - return self._name_keys - - @property - def cui_keys(self): - if self._cui_keys is None: - self._build_context_matrices() - return self._cui_keys @property def names_context_matrix(self): diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index d9fbf4a9d..89a342a5e 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -394,9 +394,6 @@ class EmbeddingLinking(Linking): embedding_batch_size: int = 4096 """How many entities to be linked at once""" linking_batch_size: int = 512 - """Choose the linking method, via all names or a single name - representing a cui. Defaults to cuis if this is changed""" - linking_strategy: str = "names" """Choose a device for the linking model to be stored. If None then an appropriate GPU device that is available will be chosen""" gpu_device: Optional[Any] = None @@ -405,6 +402,9 @@ class EmbeddingLinking(Linking): """Link candidates are provided by some NER steps. This will flag if you want to trust them or not.""" use_ner_link_candidates: bool = True + """""" + """Do we have a similarity threshold we care about?""" + use_similarity_threshold: bool = True class Preprocessing(SerialisableBaseModel): """The preprocessing part of the config""" From 5a57b9af189be99bafa7c5608ddeb057e103a74c Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Thu, 28 Aug 2025 19:00:55 +0100 Subject: [PATCH 04/12] added testing and fixes from testing --- .../components/linking/embedding_linker.py | 73 ++++++++++++------- medcat-v2/medcat/components/types.py | 5 +- medcat-v2/medcat/config/config.py | 4 +- .../linking/test_embedding_linker.py | 66 +++++++++++++++++ 4 files changed, 116 insertions(+), 32 deletions(-) create mode 100644 medcat-v2/tests/components/linking/test_embedding_linker.py diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index ea986d593..ed4aa3084 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -1,9 +1,9 @@ from medcat.cdb import CDB from medcat.config.config import Config, ComponentConfig, EmbeddingLinking from medcat.components.types import CoreComponentType, AbstractCoreComponent -from medcat.tokenizing.tokens import MutableEntity, MutableDocument, MutableToken +from medcat.tokenizing.tokens import MutableEntity, MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer -from typing import Optional, Iterator, cast, Iterable +from typing import Optional, Iterator from medcat.vocab import Vocab from torch import Tensor from transformers import AutoTokenizer, AutoModel @@ -13,7 +13,12 @@ import torch.nn.functional as F import torch import 
logging -import math +import math +import re +import string +import re +from nltk.corpus import stopwords +stop_words = set(stopwords.words('english')) logger = logging.getLogger(__name__) class Linker(AbstractCoreComponent): @@ -35,7 +40,6 @@ def __init__(self, raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking self.max_length = self.cnf_l.max_token_length - self.embedding_model_name = self.cnf_l.embedding_model_name self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._name_keys = list(self.cdb.name2info) @@ -51,6 +55,12 @@ def __init__(self, self._allowed_mask = None self._name_has_allowed_cui = None + # checking for config settings that aren't used in this linker + if self.cnf_l.prefer_frequent_concepts: + logger.warning(f"linker_config.prefer_frequent_concepts is not used in the embedding linker. But it is currently set to {self.cnf_l.prefer_frequent_concepts}.") + if self.cnf_l.prefer_primary_name: + logger.warning(f"linker_config.prefer_primary_name is not used in the embedding linker. But it is currently set to {self.cnf_l.prefer_primary_name}.") + self._cui_to_idx = {cui: idx for idx, cui in enumerate(self._cui_keys)} self._name_to_idx = {name: idx for idx, name in enumerate(self._name_keys)} self._name_to_cui_idxs = [ @@ -60,7 +70,17 @@ def __init__(self, for name in self._name_keys ] - def embed_cui_names(self, + def create_embeddings(self, + embedding_model_name: str = None): + """"Create embeddings for names and cuis longest names in the CDB.""" + if embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info: + logger.warning("Using the same model for embedding names.") + else: + self.cnf_l.embedding_model_name = embedding_model_name + self._load_transformers(embedding_model_name) + + + def _embed_cui_names(self, embedding_model_name: str, ) -> None: """Obtain embeddings for all prefered_names in the CDB using the specified @@ -69,11 +89,11 @@ def embed_cui_names(self, embedding_model_name (str): The name of the embedding model to use. batch_size (int): The size of the batches to use when embedding names. Default 4096 """ - if embedding_model_name == self.embedding_model_name and "cui_embeddings" in self.cdb.addl_info: - logger.warning("Using the same model for embedding names.") + if embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info and "name_embeddings" in self.cdb.addl_info: + logger.warning("Using the same model for embedding.") else: - self.embedding_model_name = embedding_model_name - self._load_transformers(embedding_model_name) + self.cnf_l.embedding_model_name = embedding_model_name + # Use the longest name cui_names = [max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys] # embed each name in batches. 
Because there can be 3+ million names @@ -82,7 +102,7 @@ def embed_cui_names(self, for names in tqdm(self._batch_data(cui_names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding cuis' preferred names"): with torch.no_grad(): # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [name.replace("~", " ") for name in names] + names_to_embed = [name.replace(self.config.general.separator, " ") for name in names] embeddings= self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) # cat all batches into one tensor @@ -90,7 +110,7 @@ def embed_cui_names(self, self.cdb.addl_info["cui_embeddings"] = all_embeddings logger.debug("Embedding cui names done, total: %d", len(names)) - def embed_names(self, + def _embed_names(self, embedding_model_name: str) -> None: """Obtain embeddings for all names in the CDB using the specified embedding model and store them in the name2info.context_vectors @@ -98,11 +118,10 @@ def embed_names(self, embedding_model_name (str): The name of the embedding model to use. batch_size (int): The size of the batches to use when embedding names. Default 4096 """ - if embedding_model_name == self.embedding_model_name: + if embedding_model_name == self.cnf_l.embedding_model_name: logger.debug("Using the same model for embedding names.") else: - self.embedding_model_name = embedding_model_name - self._load_transformers(embedding_model_name) + self.cnf_l.embedding_model_name = embedding_model_name names = list(self.cdb.name2info.keys()) # embed each name in batches. Because there can be 3+ million names total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) @@ -154,10 +173,10 @@ def _embed(self, return outputs.half() def _get_context(self, - entity: MutableEntity, - doc: MutableDocument, - size: int - ) -> str: + entity: MutableEntity, + doc: MutableDocument, + size: int + ) -> str: """Get context tokens for an entity Args: @@ -177,10 +196,10 @@ def _get_context(self, right_most_token = doc[min(len(doc) - 1, end_ind + size)] right_index = right_most_token.base.char_index + len(right_most_token.base.text) - + snippet = doc.base.text[left_index:right_index] return snippet - + def _get_context_vectors(self, doc: MutableDocument, entities: list[MutableEntity], @@ -404,13 +423,7 @@ def _pre_inference(self, for entity in all_ents: if len(entity.link_candidates) == 1: # if the include filter exists and the only cui is in it - if self.cnf_l.filters.cuis and entity.link_candidates[0] in self.cnf_l.filters.cuis: - entity.cui = entity.link_candidates[0] - entity.context_similarity = 1 - le.append(entity) - continue - # if only the exclude filter exists and the only cui is NOT in it - elif self.cnf_l.filters.cuis_exclude and entity.link_candidates[0] not in self.cnf_l.filters.cuis_exclude: + if self.cnf_l.filters.check_filters(entity.link_candidates[0]): entity.cui = entity.link_candidates[0] entity.context_similarity = 1 le.append(entity) @@ -423,8 +436,12 @@ def _pre_inference(self, def __call__(self, doc: MutableDocument) -> MutableDocument: # Reset main entities, will be recreated later doc.linked_ents.clear() + + if self.cdb.is_dirty: + logging.warning("CDB has been modified since last save/load. 
This might significantly affect linking performance.") + logging.warning("If you have added new concepts or changes, please re-embed the CDB names and cuis before linking.") - self._load_transformers(self.embedding_model_name) + self._load_transformers(self.cnf_l.embedding_model_name) if self.cnf_l.train: logger.warning("Attemping to train an embedding linker. This is not required.") if self.cnf_l.filters.cuis and self.cnf_l.filters.cuis_exclude: diff --git a/medcat-v2/medcat/components/types.py b/medcat-v2/medcat/components/types.py index 77c53c8ed..e2c28706d 100644 --- a/medcat-v2/medcat/components/types.py +++ b/medcat-v2/medcat/components/types.py @@ -135,7 +135,10 @@ def train(self, cui: str, "NoActionLinker.create_new_component"), "medcat2_two_step_linker": ( "medcat.components.linking.two_step_context_based_linker", - "TwoStepLinker.create_new_component") + "TwoStepLinker.create_new_component"), + "medcat2_embedding_linker": ( + "medcat.components.linking.embedding_linker", + "Linker.create_new_component"), } diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 89a342a5e..01048a564 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -377,7 +377,6 @@ class Config: extra = 'allow' class EmbeddingLinking(Linking): - """The embedding linker never needs to be trained in its current implementation.""" train: bool = False @@ -398,11 +397,10 @@ class EmbeddingLinking(Linking): then an appropriate GPU device that is available will be chosen""" gpu_device: Optional[Any] = None """Choose the window size to get context vectors.""" - context_window_size: int = 11 + context_window_size: int = 14 """Link candidates are provided by some NER steps. This will flag if you want to trust them or not.""" use_ner_link_candidates: bool = True - """""" """Do we have a similarity threshold we care about?""" use_similarity_threshold: bool = True diff --git a/medcat-v2/tests/components/linking/test_embedding_linker.py b/medcat-v2/tests/components/linking/test_embedding_linker.py new file mode 100644 index 000000000..049e31e7d --- /dev/null +++ b/medcat-v2/tests/components/linking/test_embedding_linker.py @@ -0,0 +1,66 @@ +from medcat.components.linking import embedding_linker +from medcat.components import types +from medcat.config import Config +from medcat.vocab import Vocab +from medcat.cdb.concepts import CUIInfo, NameInfo +from medcat.components.types import TrainableComponent +import unittest +from ..helper import ComponentInitTests + +class FakeDocument: + def __init__(self, text): + self.text = text + +class FakeTokenizer: + def __call__(self, text: str) -> FakeDocument: + return FakeDocument(text) + +class FakeCDB: + def __init__(self, config: Config): + self.config = config + self.cui2info: dict[str, CUIInfo] = dict() + self.name2info: dict[str, NameInfo] = dict() + self.name_separator: str + + def weighted_average_function(self, nr: int) -> float: + return nr // 2.0 + + +class EmbeddingLinkerInitTests(ComponentInitTests, unittest.TestCase): + expected_def_components = 4 + comp_type = types.CoreComponentType.linking + default_cls = embedding_linker.Linker + default_creator = embedding_linker.Linker.create_new_component + module = embedding_linker + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + cls.cnf.components.linking = embedding_linker.EmbeddingLinking() + cls.cnf.components.linking.comp_name = embedding_linker.Linker.name + cls.fcdb = FakeCDB(cls.cnf) + cls.fvocab = Vocab() + cls.vtokenizer = FakeTokenizer() 
+ cls.comp_cnf = getattr(cls.cnf.components, cls.comp_type.name) + + def test_can_create_def_component(self): + component = types.create_core_component( + self.comp_type, + "medcat2_embedding_linker", # explicitly request embedding linker + self.cnf, self.vtokenizer, self.fcdb, self.fvocab, None + ) + self.assertIsInstance(component, self.default_cls) + + def test_has_default(self): + avail_components = types.get_registered_components(self.comp_type) + registered_names = [name for name, _ in avail_components] + self.assertIn("medcat2_embedding_linker", registered_names) + +class TrainableEmbeddingLinkerTests(unittest.TestCase): + cnf = Config() + cnf.components.linking = embedding_linker.EmbeddingLinking() + cnf.components.linking.comp_name = embedding_linker.Linker.name + linker = embedding_linker.Linker(FakeCDB(cnf), cnf) + + def test_linker_is_trainable(self): + self.assertNotIsInstance(self.linker, TrainableComponent) From 5ae1604a093149e288dbe8145cdf66cec2b6ca9e Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Mon, 1 Sep 2025 14:47:26 +0100 Subject: [PATCH 05/12] mypy fixes --- .../components/linking/embedding_linker.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index ed4aa3084..3346eb8f0 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -50,8 +50,8 @@ def __init__(self, self._cui_context_matrix = None # used for filters and name embedding, and if the name contains a valid cui see: _set_filters - self._last_include_set = None - self._last_exclude_set = None + self._last_include_set: set[str] = set() + self._last_exclude_set: set[str] = set() self._allowed_mask = None self._name_has_allowed_cui = None @@ -71,13 +71,17 @@ def __init__(self, ] def create_embeddings(self, - embedding_model_name: str = None): + embedding_model_name: Optional[str] = None): + if embedding_model_name is None: + embedding_model_name = self.cnf_l.embedding_model_name # fallback """"Create embeddings for names and cuis longest names in the CDB.""" if embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info: logger.warning("Using the same model for embedding names.") else: self.cnf_l.embedding_model_name = embedding_model_name self._load_transformers(embedding_model_name) + self._embed_cui_names(embedding_model_name) + self._embed_names(embedding_model_name) def _embed_cui_names(self, @@ -106,8 +110,8 @@ def _embed_cui_names(self, embeddings= self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) # cat all batches into one tensor - all_embeddings = torch.cat(all_embeddings, dim=0) - self.cdb.addl_info["cui_embeddings"] = all_embeddings + all_embeddings_matrix = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix logger.debug("Embedding cui names done, total: %d", len(names)) def _embed_names(self, @@ -132,8 +136,8 @@ def _embed_names(self, names_to_embed = [name.replace(self.config.general.separator, " ") for name in names] embeddings = self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) - all_embeddings = torch.cat(all_embeddings, dim=0) - self.cdb.addl_info["name_embeddings"] = all_embeddings + all_embeddings_matrix = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["name_embeddings"] = all_embeddings_matrix 
logger.debug("Embedding names done, total: %d", len(names)) @@ -259,7 +263,7 @@ def _set_filters(self) -> None: def _disambiguate_by_cui(self, cui_candidates: list[str], - scores: Tensor): + scores: Tensor) -> tuple[str, float]: """Disambiguate a detected concept by a list of potential cuis Args: cuis (list[str]): Potential cuis @@ -267,16 +271,16 @@ def _disambiguate_by_cui(self, scores (Tensor): Scores for the detected cui2info concepts similarity cui_keys (list[str]): idx_to_cui inverse Returns: - tuple[str, int]: + tuple[str, float]: The CUI and its similarity """ cui_idxs = [self._cui_to_idx[cui] for cui in cui_candidates] candidate_scores = scores[cui_idxs] - candidate_idx = torch.argmax(candidate_scores).item() + candidate_idx = int(torch.argmax(candidate_scores).item()) best_idx = cui_idxs[candidate_idx] predicted_cui = self._cui_keys[best_idx] - similarity = candidate_scores[candidate_idx].item() + similarity = float(candidate_scores[candidate_idx].item()) return predicted_cui, similarity def _inference( @@ -314,7 +318,7 @@ def _inference( name_idxs = [self._name_to_idx[name] for name in name_to_cuis] indexed_scores = names_scores[i, name_idxs] - best_local_pos = torch.argmax(indexed_scores).item() + best_local_pos = int(torch.argmax(indexed_scores).item()) best_global_idx = name_idxs[best_local_pos] similarity = names_scores[i, best_global_idx].item() best_name = self._name_keys[best_global_idx] @@ -330,10 +334,10 @@ def _inference( row_sorted = sorted_indices[i] # sorted candidate indices for entity i # Find the first candidate in this row with CUIs - first_true_pos = torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][0].item() + first_true_pos = int(torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][0].item()) # Get global index + name - top_name_idx = row_sorted[first_true_pos].item() + top_name_idx = int(row_sorted[first_true_pos].item()) similarity = names_scores[i, top_name_idx].item() detected_name = self._name_keys[top_name_idx] cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) @@ -392,9 +396,9 @@ def _generate_link_candidates(self, # for now just choose the highest scoring name valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0][:1] - cuis = set() + cuis: set[str] = set() for pos in valid_positions.tolist(): - top_name_idx = row_sorted[pos].item() + top_name_idx = int(row_sorted[pos].item()) detected_name = self._name_keys[top_name_idx] cuis.update(self.cdb.name2info[detected_name]["per_cui_status"].keys()) From 8bf0c3b4ecdd53aa8c3cbbf353ce2bf92702b3c6 Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Mon, 1 Sep 2025 15:29:07 +0100 Subject: [PATCH 06/12] fixed linting (hopefully) --- .../components/linking/embedding_linker.py | 373 +++++++++++------- 1 file changed, 229 insertions(+), 144 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 3346eb8f0..d66871db6 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -14,32 +14,25 @@ import torch import logging import math -import re -import string -import re -from nltk.corpus import stopwords -stop_words = set(stopwords.words('english')) + logger = logging.getLogger(__name__) + class Linker(AbstractCoreComponent): name = "embedding_linker" - def __init__(self, - cdb: CDB, - config: Config) -> None: + def __init__(self, cdb: CDB, config: Config) -> None: """Initializes the embedding linker with a CDB 
and configuration. Args: cdb (CDB): The concept database to use. config (Config): The base config. - embedding_model_name (Optional[str]): The name of the embedding model to use. Default is "sentence-transformers/all-MiniLM-L6-v2" - max_length (int): The maximum length of the input sequences for the embedding model. Default is 64. """ self.cdb = cdb self.config = config if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking - self.max_length = self.cnf_l.max_token_length + self.max_length = self.cnf_l.max_token_length self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._name_keys = list(self.cdb.name2info) @@ -49,7 +42,8 @@ def __init__(self, self._names_context_matrix = None self._cui_context_matrix = None - # used for filters and name embedding, and if the name contains a valid cui see: _set_filters + # used for filters and name embedding, and if the name contains a valid cui + # see: _set_filters self._last_include_set: set[str] = set() self._last_exclude_set: set[str] = set() self._allowed_mask = None @@ -57,25 +51,38 @@ def __init__(self, # checking for config settings that aren't used in this linker if self.cnf_l.prefer_frequent_concepts: - logger.warning(f"linker_config.prefer_frequent_concepts is not used in the embedding linker. But it is currently set to {self.cnf_l.prefer_frequent_concepts}.") + logger.warning( + "linker_config.prefer_frequent_concepts is not used " + "in the embedding linker. It is currently set to " + f"{self.cnf_l.prefer_frequent_concepts}." + ) + if self.cnf_l.prefer_primary_name: - logger.warning(f"linker_config.prefer_primary_name is not used in the embedding linker. But it is currently set to {self.cnf_l.prefer_primary_name}.") + logger.warning( + "linker_config.prefer_primary_name is not used " + "in the embedding linker. It is currently set to " + f"{self.cnf_l.prefer_primary_name}." 
+ ) self._cui_to_idx = {cui: idx for idx, cui in enumerate(self._cui_keys)} self._name_to_idx = {name: idx for idx, name in enumerate(self._name_keys)} self._name_to_cui_idxs = [ - [ self._cui_to_idx[cui] - for cui in self.cdb.name2info[name].get("per_cui_status", {}).keys() - if cui in self._cui_to_idx ] + [ + self._cui_to_idx[cui] + for cui in self.cdb.name2info[name].get("per_cui_status", {}).keys() + if cui in self._cui_to_idx + ] for name in self._name_keys ] - def create_embeddings(self, - embedding_model_name: Optional[str] = None): + def create_embeddings(self, embedding_model_name: Optional[str] = None): if embedding_model_name is None: embedding_model_name = self.cnf_l.embedding_model_name # fallback """"Create embeddings for names and cuis longest names in the CDB.""" - if embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info: + if ( + embedding_model_name == self.cnf_l.embedding_model_name + and "cui_embeddings" in self.cdb.addl_info + ): logger.warning("Using the same model for embedding names.") else: self.cnf_l.embedding_model_name = embedding_model_name @@ -83,44 +90,57 @@ def create_embeddings(self, self._embed_cui_names(embedding_model_name) self._embed_names(embedding_model_name) - - def _embed_cui_names(self, - embedding_model_name: str, - ) -> None: + def _embed_cui_names( + self, + embedding_model_name: str, + ) -> None: """Obtain embeddings for all prefered_names in the CDB using the specified embedding model and store them in the name2info.context_vectors Args: embedding_model_name (str): The name of the embedding model to use. - batch_size (int): The size of the batches to use when embedding names. Default 4096 + batch_size (int): The size of the batches to use when embedding names. + Default 4096 """ - if embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info and "name_embeddings" in self.cdb.addl_info: + if ( + embedding_model_name == self.cnf_l.embedding_model_name + and "cui_embeddings" in self.cdb.addl_info + and "name_embeddings" in self.cdb.addl_info + ): logger.warning("Using the same model for embedding.") else: self.cnf_l.embedding_model_name = embedding_model_name - + # Use the longest name - cui_names = [max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys] + cui_names = [ + max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys + ] # embed each name in batches. 
Because there can be 3+ million names total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size) all_embeddings = [] - for names in tqdm(self._batch_data(cui_names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding cuis' preferred names"): + for names in tqdm( + self._batch_data(cui_names, self.cnf_l.embedding_batch_size), + total=total_batches, + desc="Embedding cuis' preferred names", + ): with torch.no_grad(): # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [name.replace(self.config.general.separator, " ") for name in names] - embeddings= self._embed(names_to_embed, self.device) + names_to_embed = [ + name.replace(self.config.general.separator, " ") for name in names + ] + embeddings = self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) # cat all batches into one tensor all_embeddings_matrix = torch.cat(all_embeddings, dim=0) self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix logger.debug("Embedding cui names done, total: %d", len(names)) - def _embed_names(self, - embedding_model_name: str) -> None: + def _embed_names(self, embedding_model_name: str) -> None: """Obtain embeddings for all names in the CDB using the specified embedding model and store them in the name2info.context_vectors Args: embedding_model_name (str): The name of the embedding model to use. - batch_size (int): The size of the batches to use when embedding names. Default 4096 + batch_size (int): The size of the batches to use when embedding names + Default 4096 """ if embedding_model_name == self.cnf_l.embedding_model_name: logger.debug("Using the same model for embedding names.") @@ -130,57 +150,74 @@ def _embed_names(self, # embed each name in batches. Because there can be 3+ million names total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) all_embeddings = [] - for names in tqdm(self._batch_data(names, self.cnf_l.embedding_batch_size), total=total_batches, desc="Embedding names"): + for names in tqdm( + self._batch_data(names, self.cnf_l.embedding_batch_size), + total=total_batches, + desc="Embedding names", + ): with torch.no_grad(): # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [name.replace(self.config.general.separator, " ") for name in names] + names_to_embed = [ + name.replace(self.config.general.separator, " ") for name in names + ] embeddings = self._embed(names_to_embed, self.device) all_embeddings.append(embeddings.cpu()) all_embeddings_matrix = torch.cat(all_embeddings, dim=0) self.cdb.addl_info["name_embeddings"] = all_embeddings_matrix logger.debug("Embedding names done, total: %d", len(names)) - def get_type(self) -> CoreComponentType: return CoreComponentType.linking - + def _batch_data(self, data, batch_size=512) -> Iterator[list]: for i in range(0, len(data), batch_size): - yield data[i:i + batch_size] + yield data[i : i + batch_size] - def _load_transformers(self, - embedding_model_name) -> None: + def _load_transformers(self, embedding_model_name: str) -> None: """Load the transformers model and tokenizer. No need to load a transformer model until it's required. Args: - embedding_model_name (str): The name of the embedding model to load. Default is "sentence-transformers/all-MiniLM-L6-v2" + embedding_model_name (str): The name of the embedding model to load. 
+ Default is "sentence-transformers/all-MiniLM-L6-v2" """ - if not hasattr(self, "model") or not hasattr(self, "tokenizer") or embedding_model_name != self.cnf_l.embedding_model_name: + if ( + not hasattr(self, "model") + or not hasattr(self, "tokenizer") + or embedding_model_name != self.cnf_l.embedding_model_name + ): self.cnf_l.embedding_model_name = embedding_model_name self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name) self.model = AutoModel.from_pretrained(embedding_model_name) self.model.eval() gpu_device = self.cnf_l.gpu_device - self.device = torch.device(gpu_device or ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = torch.device( + gpu_device or ("cuda" if torch.cuda.is_available() else "cpu") + ) self.model.to(self.device) - logger.debug(f"Loaded embedding model: {embedding_model_name} on device: {self.device}") - - def _embed(self, - to_embed: list[str], - device) -> Tensor: - """Embeds a list of strings - """ - batch_dict = self.tokenizer(to_embed, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt').to(device) + logger.debug( + f"""Loaded embedding model: {embedding_model_name} + on device: {self.device}""" + ) + + def _embed(self, to_embed: list[str], device) -> Tensor: + """Embeds a list of strings""" + batch_dict = self.tokenizer( + to_embed, + max_length=self.max_length, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) outputs = self.model(**batch_dict) - outputs = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) + outputs = self._last_token_pool( + outputs.last_hidden_state, batch_dict["attention_mask"] + ) outputs = F.normalize(outputs, p=2, dim=1) return outputs.half() - def _get_context(self, - entity: MutableEntity, - doc: MutableDocument, - size: int - ) -> str: + def _get_context( + self, entity: MutableEntity, doc: MutableDocument, size: int + ) -> str: """Get context tokens for an entity Args: @@ -200,15 +237,15 @@ def _get_context(self, right_most_token = doc[min(len(doc) - 1, end_ind + size)] right_index = right_most_token.base.char_index + len(right_most_token.base.text) - + snippet = doc.base.text[left_index:right_index] return snippet - - def _get_context_vectors(self, - doc: MutableDocument, - entities: list[MutableEntity], - size: int) -> Tensor: - """Get context vectors for all detected concepts based on their raw text or detected names. + + def _get_context_vectors( + self, doc: MutableDocument, entities: list[MutableEntity], size: int + ) -> Tensor: + """Get context vectors for all detected concepts based on their + surrounding text. Args: doc (BaseDocument): The document look in. @@ -221,15 +258,17 @@ def _get_context_vectors(self, text = self._get_context(entity, doc, size) texts.append(text) return self._embed(texts, self.device) - + def _set_filters(self) -> None: include_set = self.cnf_l.filters.cuis exclude_set = self.cnf_l.filters.cuis_exclude # Check if sets changed (avoid recomputation if same) - if (include_set == self._last_include_set and - exclude_set == self._last_exclude_set): - return + if ( + include_set == self._last_include_set + and exclude_set == self._last_exclude_set + ): + return n = len(self._name_keys) allowed_mask = torch.empty(n, dtype=torch.bool, device=self.device) @@ -237,33 +276,50 @@ def _set_filters(self) -> None: if include_set: # if in include set, ignore exclude set. 
             allowed_mask[:] = False
-            include_cui_idxs = {self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx}
+            include_cui_idxs = {
+                self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx
+            }
             include_idxs = [
                 name_idx
                 for name_idx, name_cui_idxs in enumerate(self._name_to_cui_idxs)
                 if any(cui in include_cui_idxs for cui in name_cui_idxs)
             ]
-            allowed_mask[torch.tensor(include_idxs, dtype=torch.long, device=self.device)] = True
+            allowed_mask[
+                torch.tensor(include_idxs, dtype=torch.long, device=self.device)
+            ] = True
         else:
             # only look at exclude if there's no include set
             allowed_mask[:] = True
             if exclude_set:
-                exclude_cui_idxs = {self._cui_to_idx[cui] for cui in exclude_set if cui in self._cui_to_idx}
-                exclude_idxs = [i for i, name_cui_idxs in enumerate(self._name_to_cui_idxs) if any(ci in exclude_cui_idxs for ci in name_cui_idxs)]
-                allowed_mask[torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)] = False
+                exclude_cui_idxs = {
+                    self._cui_to_idx[cui]
+                    for cui in exclude_set
+                    if cui in self._cui_to_idx
+                }
+                exclude_idxs = [
+                    i
+                    for i, name_cui_idxs in enumerate(self._name_to_cui_idxs)
+                    if any(ci in exclude_cui_idxs for ci in name_cui_idxs)
+                ]
+                allowed_mask[
+                    torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)
+                ] = False
 
-        # checking if a name has at least 1 cui related to it. Might as well do this cheeck here.
+        # checking if a name has at least 1 cui related to it.
         _has_cuis_all = torch.tensor(
-            [bool(self.cdb.name2info[name]["per_cui_status"]) for name in self._name_keys],
-            device=self.device
+            [
+                bool(self.cdb.name2info[name]["per_cui_status"])
+                for name in self._name_keys
+            ],
+            device=self.device,
         )
-        self._valid_names = (_has_cuis_all & allowed_mask)
+        self._valid_names = _has_cuis_all & allowed_mask
         self._last_include_set = include_set
         self._last_exclude_set = exclude_set
 
-    def _disambiguate_by_cui(self,
-                             cui_candidates: list[str],
-                             scores: Tensor) -> tuple[str, float]:
+    def _disambiguate_by_cui(
+        self, cui_candidates: list[str], scores: Tensor
+    ) -> tuple[str, float]:
         """Disambiguate a detected concept by a list of potential cuis
         Args:
             cuis (list[str]): Potential cuis
             scores (Tensor): Scores for the detected cui2info concepts similarity
             cui_keys (list[str]): idx_to_cui inverse
         Returns:
             tuple[str, float]:
                 The CUI and its similarity
         """
         cui_idxs = [self._cui_to_idx[cui] for cui in cui_candidates]
         candidate_scores = scores[cui_idxs]
         candidate_idx = int(torch.argmax(candidate_scores).item())
         best_idx = cui_idxs[candidate_idx]
         predicted_cui = self._cui_keys[best_idx]
         similarity = float(candidate_scores[candidate_idx].item())
         return predicted_cui, similarity
 
     def _inference(
-        self,
-        doc: MutableDocument,
-        entities: list[MutableEntity]) -> Iterator[MutableEntity]:
-        """Infer all entities at once (or in batches), to avoid multiple gpu calls when it isn't nessescary.
+        self, doc: MutableDocument, entities: list[MutableEntity]
+    ) -> Iterator[MutableEntity]:
+        """Infer all entities at once (or in batches), to avoid multiple gpu calls
+        when it isn't necessary.
         Args:
             doc (BaseDocument): The document look in.
-            name_keys (list[str]): list of all names2info
-            cui_keys (list[str]): list of all cuis2info
-            context_matrix: Tensor of context matrix we're planning to use could be all names from name2info,
-            or prefered names from cui2info[cui]["preferred_name"]
+            entities (list[BaseEntity]): The entities to infer. 
Yields: - entity (MutableEntity): Entity with a relevant cui prediction - or skip if it's not suitable.""" - detected_context_vectors = self._get_context_vectors(doc, entities, self.cnf_l.context_window_size) + entity (MutableEntity): Entity with a relevant cui prediction - + or skip if it's not suitable.""" + detected_context_vectors = self._get_context_vectors( + doc, entities, self.cnf_l.context_window_size + ) - # score all detected contexts vs all names, handle in the loop each individual case + # score all detected contexts vs all names names_scores = detected_context_vectors @ self.names_context_matrix.T cui_scores = detected_context_vectors @ self.cui_context_matrix.T sorted_indices = torch.argsort(names_scores, dim=1, descending=True) for i, entity in enumerate(entities): - link_candidates = [cui for cui in entity.link_candidates if self.cnf_l.filters.check_filters(cui)] + link_candidates = [ + cui + for cui in entity.link_candidates + if self.cnf_l.filters.check_filters(cui) + ] if len(link_candidates) == 1: best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] @@ -323,18 +383,19 @@ def _inference( similarity = names_scores[i, best_global_idx].item() best_name = self._name_keys[best_global_idx] cuis = name_to_cuis[best_name] - if (len(cuis) == 1): + if len(cuis) == 1: predicted_cui = cuis[0] else: - predicted_cui, _ = self._disambiguate_by_cui( - cuis, - cui_scores[i,:] - ) + predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) else: row_sorted = sorted_indices[i] # sorted candidate indices for entity i # Find the first candidate in this row with CUIs - first_true_pos = int(torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][0].item()) + first_true_pos = int( + torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][ + 0 + ].item() + ) # Get global index + name top_name_idx = int(row_sorted[first_true_pos].item()) @@ -342,49 +403,55 @@ def _inference( detected_name = self._name_keys[top_name_idx] cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) - predicted_cui, _ = self._disambiguate_by_cui( - cuis, - cui_scores[i,:] - ) + predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - if self.cnf_l.use_similarity_threshold and self._check_similarity(similarity): + if self.cnf_l.use_similarity_threshold and self._check_similarity( + similarity + ): entity.cui = predicted_cui entity.context_similarity = similarity yield entity - + def _check_similarity(self, context_similarity: float) -> bool: if self.cnf_l.use_similarity_threshold: threshold = self.cnf_l.similarity_threshold return context_similarity >= threshold else: return True - - def _last_token_pool(self, last_hidden_states: Tensor, - attention_mask: Tensor) -> Tensor: - left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + + def _last_token_pool( + self, last_hidden_states: Tensor, attention_mask: Tensor + ) -> Tensor: + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] if left_padding: return last_hidden_states[:, -1] else: sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = last_hidden_states.shape[0] - return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] - + return last_hidden_states[ + torch.arange(batch_size, device=last_hidden_states.device), + sequence_lengths, + ] + def _build_context_matrices(self) -> None: if "name_embeddings" in self.cdb.addl_info: - self._names_context_matrix = 
self.cdb.addl_info["name_embeddings"].half().to(self.device)
+            self._names_context_matrix = (
+                self.cdb.addl_info["name_embeddings"].half().to(self.device)
+            )
         if "cui_embeddings" in self.cdb.addl_info:
-            self._cui_context_matrix = self.cdb.addl_info["cui_embeddings"].half().to(self.device)
+            self._cui_context_matrix = (
+                self.cdb.addl_info["cui_embeddings"].half().to(self.device)
+            )
 
-    def _generate_link_candidates(self,
-                                  doc: MutableDocument,
-                                  entities: list[MutableEntity]
-                                  ) -> None:
-        """Generate link candidates for each detected entity based on context vectors with size 0.
-        Compare to names to get the most similar name in the cdb to the detected concept."""
+    def _generate_link_candidates(
+        self, doc: MutableDocument, entities: list[MutableEntity]
+    ) -> None:
+        """Generate link candidates for each detected entity based
+        on context vectors with size 0. Compare to names to get the most
+        similar name in the cdb to the detected concept."""
         detected_context_vectors = self._get_context_vectors(doc, entities, 0)
 
-        # score all detected contexts vs all names, handle in the loop each individual case
+        # score all detected contexts vs all names
         names_scores = detected_context_vectors @ self.names_context_matrix.T
         sorted_indices = torch.argsort(names_scores, dim=1, descending=True)
 
@@ -392,7 +459,8 @@ def _generate_link_candidates(self,
             row_sorted = sorted_indices[i]  # sorted candidate indices for entity i
 
             valid_mask = self._valid_names[row_sorted]
-            # TODO: potentially choose multiple names that are all within a certain range of the top scoring.
+            # TODO: potentially choose multiple names that
+            # are all within a certain range of the top scoring.
             # for now just choose the highest scoring name
             valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0][:1]
 
@@ -404,19 +472,22 @@ def _generate_link_candidates(self,
 
             entity.link_candidates = list(cuis)
 
-    def _pre_inference(self,
-                       doc: MutableDocument) -> tuple[list, list]:
-        """Checking all entities for entities with only a single link candidate, to avoid the full inference step.
-        If we want to calculate similarities, or not use link candidates, then just return the entities"""
+    def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
+        """Checking all entities for entities with only a single link candidate,
+        to avoid the full inference step. If we want to calculate similarities,
+        or not use link candidates, then just return the entities"""
         all_ents = doc.ner_ents
         if not self.cnf_l.use_ner_link_candidates:
             to_generate_link_candidates = all_ents
         else:
-            to_generate_link_candidates = [entity for entity in all_ents if not entity.link_candidates]
+            to_generate_link_candidates = [
+                entity for entity in all_ents if not entity.link_candidates
+            ]
 
         # generate our own link candidates if it's required, or wanted
-        for entities in self._batch_data(to_generate_link_candidates, self.cnf_l.linking_batch_size):
+        for entities in self._batch_data(
+            to_generate_link_candidates, self.cnf_l.linking_batch_size
+        ):
             self._generate_link_candidates(doc, entities)
 
         if self.cnf_l.always_calculate_similarity:
@@ -432,25 +503,35 @@ def _pre_inference(self,
                 entity.context_similarity = 1
                 le.append(entity)
                 continue
-            # if it has to be inferred due to filters or number of link candidates then add it to the infer list
+            # it has to be inferred due to filters or number of link candidates
            to_infer.append(entity)
         return le, to_infer
 
     def __call__(self, doc: MutableDocument) -> MutableDocument:
         # Reset main entities, will be recreated later
         doc.linked_ents.clear()
 
         if self.cdb.is_dirty:
-            logging.warning("CDB has been modified since last save/load. This might significantly affect linking performance.")
-            logging.warning("If you have added new concepts or changes, please re-embed the CDB names and cuis before linking.")
+            logger.warning(
+                "CDB has been modified since last save/load. "
+                "This might significantly affect linking performance."
+            )
+            logger.warning(
+                "If you have added new concepts or changes, "
+                "please re-embed the CDB names and cuis before linking."
+            )
 
         self._load_transformers(self.cnf_l.embedding_model_name)
         if self.cnf_l.train:
-            logger.warning("Attemping to train an embedding linker. This is not required.")
+            logger.warning(
+                "Attempting to train an embedding linker. This is not required."
+            )
         if self.cnf_l.filters.cuis and self.cnf_l.filters.cuis_exclude:
-            logger.warning("You have both include and exclude filters for CUIs set. This will result in only include CUIs being filtered.")
+            logger.warning(
+                "You have both include and exclude filters for CUIs set. "
+                "This will result in only include CUIs being filtered."
+            )
 
         self._set_filters()
 
         with torch.no_grad():
@@ -475,10 +556,14 @@ def cui_context_matrix(self):
         if self._cui_context_matrix is None:
             self._build_context_matrices()
         return self._cui_context_matrix
-
+
     @classmethod
     def create_new_component(
-        cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
-        cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
-    ) -> 'Linker':
-        return cls(cdb, cdb.config)
\ No newline at end of file
+        cls,
+        cnf: ComponentConfig,
+        tokenizer: BaseTokenizer,
+        cdb: CDB,
+        vocab: Vocab,
+        model_load_path: Optional[str],
+    ) -> "Linker":
+        return cls(cdb, cdb.config)
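For anyone trying the component out at this point in the series, the minimal
wiring mirrors what the new tests set up. A rough usage sketch (the `cdb` and
`doc` objects below are placeholders for a loaded concept database and a
tokenised document, not a polished public recipe):

    from medcat.config import Config
    from medcat.components.linking import embedding_linker

    cnf = Config()
    cnf.components.linking = embedding_linker.EmbeddingLinking()
    cnf.components.linking.comp_name = embedding_linker.Linker.name

    linker = embedding_linker.Linker(cdb, cnf)  # cdb: a prepared medcat CDB
    linker.create_embeddings()  # embed all names and cui preferred names once
    doc = linker(doc)           # doc: a MutableDocument with NER entities

`create_embeddings` only needs to run again when the CDB changes, which is
exactly what the `is_dirty` warning in `__call__` nudges users about.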
From 87a709bf4f94f40d913ef0e6eb7efbad3b6bbe41 Mon Sep 17 00:00:00 2001
From: Adam Sutton
Date: Mon, 8 Sep 2025 16:08:07 +0100
Subject: [PATCH 07/12] added thresholds for short and long contexts

---
 .../components/linking/embedding_linker.py | 40 +++++++++++++------
 medcat-v2/medcat/config/config.py          | 14 +++++--
 2 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py
index d66871db6..79d2d54e4 100644
--- a/medcat-v2/medcat/components/linking/embedding_linker.py
+++ b/medcat-v2/medcat/components/linking/embedding_linker.py
@@ -44,8 +44,8 @@ def __init__(self, cdb: CDB, config: Config) -> None:
 
         # used for filters and name embedding, and if the name contains a valid cui
         # see: _set_filters
-        self._last_include_set: set[str] = set()
-        self._last_exclude_set: set[str] = set()
+        self._last_include_set: set[str] | None = None
+        self._last_exclude_set: set[str] | None = None
         self._allowed_mask = None
         self._name_has_allowed_cui = None
 
@@ -265,7 +265,9 @@ def _set_filters(self) -> None:
 
         # Check if sets changed (avoid recomputation if same)
         if (
-            include_set == self._last_include_set
+            self._last_include_set is not None
+            and self._last_exclude_set is not None
+            and include_set == self._last_include_set
             and exclude_set == self._last_exclude_set
         ):
             return
@@ -405,16 +407,16 @@ def _inference(
 
             predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :])
 
-        if self.cnf_l.use_similarity_threshold and self._check_similarity(
+        if self._check_similarity(
             similarity
         ):
             entity.cui = predicted_cui
             entity.context_similarity = similarity
             yield entity
 
     def _check_similarity(self, context_similarity: float) -> bool:
-        if self.cnf_l.use_similarity_threshold:
-            threshold = self.cnf_l.similarity_threshold
+        if self.cnf_l.long_similarity_threshold:
+            threshold = self.cnf_l.long_similarity_threshold
             return context_similarity >= threshold
         else:
             return True
@@ -456,15 +458,24 @@ def _generate_link_candidates(
         sorted_indices = torch.argsort(names_scores, dim=1, descending=True)
 
         for i, entity in enumerate(entities):
-            row_sorted = sorted_indices[i]  # sorted candidate indices for entity i
+            row_sorted = sorted_indices[i]
+            cuis: set[str] = set()
+            # scores for this entity row
+            row_scores = names_scores[i, row_sorted]
 
+            # valid names via filtering and contain at least 1 cui
             valid_mask = self._valid_names[row_sorted]
-            # TODO: potentially choose multiple names that
-            # are all within a certain range of the top scoring.
-            # for now just choose the highest scoring name
-            valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0][:1]
 
-            cuis: set[str] = set()
+            if self.cnf_l.short_similarity_threshold > 0:
+                # thresholded selection
+                above_thresh_mask = row_scores >= self.cnf_l.short_similarity_threshold
+                selected_mask = valid_mask & above_thresh_mask
+                valid_positions = torch.nonzero(selected_mask, as_tuple=True)[0]
+            else:
+                # just take the single best valid candidate
+                first_valid = torch.nonzero(valid_mask, as_tuple=True)[0][:1]
+                valid_positions = first_valid
+
             for pos in valid_positions.tolist():
                 top_name_idx = int(row_sorted[pos].item())
                 detected_name = self._name_keys[top_name_idx]
@@ -490,8 +501,11 @@ def _pre_inference(self,
         ):
             self._generate_link_candidates(doc, entities)
 
+        # filter out entities with no link candidates after thresholding
+        filtered_ents = [ent for ent in all_ents if ent.link_candidates]
+
         if self.cnf_l.always_calculate_similarity:
-            return [], all_ents
+            return [], filtered_ents
 
         le = []
         to_infer = []
diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py
index 01048a564..00722bf68 100644
--- a/medcat-v2/medcat/config/config.py
+++ b/medcat-v2/medcat/config/config.py
@@ -380,9 +380,17 @@ class EmbeddingLinking(Linking):
     """The embedding linker never needs to be trained in its
     current implementation."""
     train: bool = False
-    """Similarity between context bert-like vector and names or
-    cui preferred names"""
-    similarity_threshold: float = 0.25
+    """Used in the inference step to choose the best CUI given the
+    link candidates. Testing shows a threshold of 0.7 increases precision
+    with minimal impact on recall. Default is 0.0, which assumes
+    all entities detected by the NER step are true."""
+    long_similarity_threshold: float = 0.0
+    """Used for generating cui candidates. If a threshold of 0.0
+    is selected then only the highest scoring name will provide cuis
+    to be link candidates. Use a threshold of 0.95 or higher, as this is
+    essentially string matching and accounts for spelling errors. Lower
+    thresholds will provide too many candidates and slow down the inference."""
+    short_similarity_threshold: float = 0.0
+    """Name of the embedding model. 
It must be downloadable from huggingface linked from an appropriate file directory""" embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2" From ba520022764d848ba094c776979c8e4e2abd4f24 Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Wed, 17 Sep 2025 11:59:55 +0100 Subject: [PATCH 08/12] fixed mypy issues --- medcat-v2/medcat/components/linking/embedding_linker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 79d2d54e4..f09c60d2a 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -3,7 +3,7 @@ from medcat.components.types import CoreComponentType, AbstractCoreComponent from medcat.tokenizing.tokens import MutableEntity, MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer -from typing import Optional, Iterator +from typing import Optional, Iterator, Set from medcat.vocab import Vocab from torch import Tensor from transformers import AutoTokenizer, AutoModel @@ -44,8 +44,8 @@ def __init__(self, cdb: CDB, config: Config) -> None: # used for filters and name embedding, and if the name contains a valid cui # see: _set_filters - self._last_include_set: set[str] | None = None - self._last_exclude_set: set[str] | None = None + self._last_include_set: Optional[Set[str]] = None + self._last_exclude_set: Optional[Set[str]] = None self._allowed_mask = None self._name_has_allowed_cui = None From 8109d21f6373bb17435bbd74ada626b618bfa5cd Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Thu, 25 Sep 2025 23:31:29 +0100 Subject: [PATCH 09/12] Added filter before disambig, and various tests --- .../components/linking/embedding_linker.py | 22 +++++++++++-------- medcat-v2/medcat/config/config.py | 4 ++++ .../linking/test_embedding_linker.py | 14 ++++-------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index f09c60d2a..600cc50fa 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -146,7 +146,7 @@ def _embed_names(self, embedding_model_name: str) -> None: logger.debug("Using the same model for embedding names.") else: self.cnf_l.embedding_model_name = embedding_model_name - names = list(self.cdb.name2info.keys()) + names = self._name_keys # embed each name in batches. 
Because there can be 3+ million names total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) all_embeddings = [] @@ -316,8 +316,9 @@ def _set_filters(self) -> None: device=self.device, ) self._valid_names = _has_cuis_all & allowed_mask - self._last_include_set = include_set - self._last_exclude_set = exclude_set + self._last_include_set = set(include_set) if include_set is not None else None + self._last_exclude_set = set(exclude_set) if exclude_set is not None else None + def _disambiguate_by_cui( self, cui_candidates: list[str], scores: Tensor @@ -362,11 +363,13 @@ def _inference( sorted_indices = torch.argsort(names_scores, dim=1, descending=True) for i, entity in enumerate(entities): - link_candidates = [ - cui - for cui in entity.link_candidates - if self.cnf_l.filters.check_filters(cui) - ] + link_candidates = entity.link_candidates + if self.config.components.linking.filter_before_disamb: + link_candidates = [ + cui + for cui in link_candidates + if self.cnf_l.filters.check_filters(cui) + ] if len(link_candidates) == 1: best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] @@ -406,7 +409,8 @@ def _inference( cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - + if not self.cnf_l.filters.check_filters(predicted_cui): + continue if self._check_similarity( similarity ): diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 00722bf68..a5e54089e 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -377,6 +377,10 @@ class Config: extra = 'allow' class EmbeddingLinking(Linking): + """Changing compoenent name""" + comp_name: str = "embedding_linker" + """All concepts below this will always be disambiguated""" + filter_before_disamb: bool = True """The embedding linker never needs to be trained in its current implementation.""" train: bool = False diff --git a/medcat-v2/tests/components/linking/test_embedding_linker.py b/medcat-v2/tests/components/linking/test_embedding_linker.py index 049e31e7d..b35dcc746 100644 --- a/medcat-v2/tests/components/linking/test_embedding_linker.py +++ b/medcat-v2/tests/components/linking/test_embedding_linker.py @@ -4,6 +4,7 @@ from medcat.vocab import Vocab from medcat.cdb.concepts import CUIInfo, NameInfo from medcat.components.types import TrainableComponent +from medcat.components.types import _DEFAULT_LINKING as DEF_LINKING import unittest from ..helper import ComponentInitTests @@ -27,8 +28,9 @@ def weighted_average_function(self, nr: int) -> float: class EmbeddingLinkerInitTests(ComponentInitTests, unittest.TestCase): - expected_def_components = 4 + expected_def_components = len(DEF_LINKING) comp_type = types.CoreComponentType.linking + default = 'medcat2_embedding_linker' default_cls = embedding_linker.Linker default_creator = embedding_linker.Linker.create_new_component module = embedding_linker @@ -43,14 +45,6 @@ def setUpClass(cls): cls.vtokenizer = FakeTokenizer() cls.comp_cnf = getattr(cls.cnf.components, cls.comp_type.name) - def test_can_create_def_component(self): - component = types.create_core_component( - self.comp_type, - "medcat2_embedding_linker", # explicitly request embedding linker - self.cnf, self.vtokenizer, self.fcdb, self.fvocab, None - ) - self.assertIsInstance(component, self.default_cls) - def test_has_default(self): avail_components = types.get_registered_components(self.comp_type) registered_names = [name for name, _ in 
avail_components] @@ -62,5 +56,5 @@ class TrainableEmbeddingLinkerTests(unittest.TestCase): cnf.components.linking.comp_name = embedding_linker.Linker.name linker = embedding_linker.Linker(FakeCDB(cnf), cnf) - def test_linker_is_trainable(self): + def test_linker_is_not_trainable(self): self.assertNotIsInstance(self.linker, TrainableComponent) From b871eef6565e1092f1fdd19cff93eb6e69b902cb Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Thu, 25 Sep 2025 23:34:20 +0100 Subject: [PATCH 10/12] handling cases with 1 candidate that's filtered out --- medcat-v2/medcat/components/linking/embedding_linker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 600cc50fa..958473e65 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -521,6 +521,8 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]: entity.context_similarity = 1 le.append(entity) continue + elif self.cnf_l.use_ner_link_candidates: + continue # it has to be inferred due to filters or number of link candidates to_infer.append(entity) return le, to_infer From d9bf801f6a17552c3a0c2f44be281b6badc917e1 Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Mon, 29 Sep 2025 14:38:15 +0100 Subject: [PATCH 11/12] added max length logic and finals suggested changes --- .../components/linking/embedding_linker.py | 18 +++++++++++++++--- medcat-v2/medcat/config/config.py | 5 ++++- .../linking/test_embedding_linker.py | 9 ++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 958473e65..1d67f0faa 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -75,10 +75,21 @@ def __init__(self, cdb: CDB, config: Config) -> None: for name in self._name_keys ] - def create_embeddings(self, embedding_model_name: Optional[str] = None): + def create_embeddings(self, + embedding_model_name: Optional[str] = None, + max_length: Optional[int] = None, + ): + """Create embeddings for all names and cuis longest names in the CDB + using the chosen embedding model.""" if embedding_model_name is None: embedding_model_name = self.cnf_l.embedding_model_name # fallback - """"Create embeddings for names and cuis longest names in the CDB.""" + + if max_length is not None and max_length != self.max_length: + logger.info( + "Updating max_length from %s to %s", self.max_length, max_length + ) + self.max_length = max_length + self.cnf_l.max_token_length = max_length if ( embedding_model_name == self.cnf_l.embedding_model_name and "cui_embeddings" in self.cdb.addl_info @@ -94,7 +105,7 @@ def _embed_cui_names( self, embedding_model_name: str, ) -> None: - """Obtain embeddings for all prefered_names in the CDB using the specified + """Obtain embeddings for all cuis longest names in the CDB using the specified embedding model and store them in the name2info.context_vectors Args: embedding_model_name (str): The name of the embedding model to use. 
@@ -314,6 +325,7 @@ def _set_filters(self) -> None: for name in self._name_keys ], device=self.device, + dtype=torch.bool, ) self._valid_names = _has_cuis_all & allowed_mask self._last_include_set = set(include_set) if include_set is not None else None diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index a5e54089e..b362db35a 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -398,7 +398,10 @@ class EmbeddingLinking(Linking): """Name of the embedding model. It must be downloadable from huggingface linked from an appropriate file directory""" embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2" - """Max number of tokens to be embedded from a name.""" + """Max number of tokens to be embedded from a name. + If the max token length is changed then the linker will need to be created + with a new config. + """ max_token_length: int = 64 """How many pieces names can be embedded at once, useful when embedding name2info names, cui2info names""" diff --git a/medcat-v2/tests/components/linking/test_embedding_linker.py b/medcat-v2/tests/components/linking/test_embedding_linker.py index b35dcc746..187bc2189 100644 --- a/medcat-v2/tests/components/linking/test_embedding_linker.py +++ b/medcat-v2/tests/components/linking/test_embedding_linker.py @@ -9,6 +9,8 @@ from ..helper import ComponentInitTests class FakeDocument: + linked_ents = [] + ner_ents = [] def __init__(self, text): self.text = text @@ -18,6 +20,7 @@ def __call__(self, text: str) -> FakeDocument: class FakeCDB: def __init__(self, config: Config): + self.is_dirty = False self.config = config self.cui2info: dict[str, CUIInfo] = dict() self.name2info: dict[str, NameInfo] = dict() @@ -50,7 +53,7 @@ def test_has_default(self): registered_names = [name for name, _ in avail_components] self.assertIn("medcat2_embedding_linker", registered_names) -class TrainableEmbeddingLinkerTests(unittest.TestCase): +class NonTrainableEmbeddingLinkerTests(unittest.TestCase): cnf = Config() cnf.components.linking = embedding_linker.EmbeddingLinking() cnf.components.linking.comp_name = embedding_linker.Linker.name @@ -58,3 +61,7 @@ class TrainableEmbeddingLinkerTests(unittest.TestCase): def test_linker_is_not_trainable(self): self.assertNotIsInstance(self.linker, TrainableComponent) + + def test_linker_processes_document(self): + doc = FakeDocument("Test Document") + self.linker(doc) \ No newline at end of file From a03a62f76c4bf8c3facd9b9cd5c6a4f637610c6c Mon Sep 17 00:00:00 2001 From: Adam Sutton Date: Wed, 1 Oct 2025 12:49:34 +0100 Subject: [PATCH 12/12] changes to config documentation and filter before disambig --- medcat-v2/medcat/config/config.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index b362db35a..2f494d1b5 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -377,47 +377,48 @@ class Config: extra = 'allow' class EmbeddingLinking(Linking): - """Changing compoenent name""" + """The config exclusively used for the embedding linker""" comp_name: str = "embedding_linker" - """All concepts below this will always be disambiguated""" - filter_before_disamb: bool = True + """Changing compoenent name""" + filter_before_disamb: bool = False + """Filtering CUIs before disambiguation""" + train: bool = False """The embedding linker never needs to be trained in its current implementation.""" - train: bool = False + 
long_similarity_threshold: float = 0.0
     """Used in the inference step to choose the best CUI given the
     link candidates. Testing shows a threshold of 0.7 increases precision
     with minimal impact on recall. Default is 0.0, which assumes
     all entities detected by the NER step are true."""
-    long_similarity_threshold: float = 0.0
+    short_similarity_threshold: float = 0.0
     """Used for generating cui candidates. If a threshold of 0.0
     is selected then only the highest scoring name will provide cuis
     to be link candidates. Use a threshold of 0.95 or higher, as this is
     essentially string matching and accounts for spelling errors. Lower
     thresholds will provide too many candidates and slow down the inference."""
-    short_similarity_threshold: float = 0.0
+    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
     """Name of the embedding model.
     It must be downloadable from huggingface linked from an appropriate file directory"""
-    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
+    max_token_length: int = 64
     """Max number of tokens to be embedded from a name.
     If the max token length is changed then the linker will need to be created
     with a new config.
     """
-    max_token_length: int = 64
+    embedding_batch_size: int = 4096
     """How many pieces names can be embedded at once, useful when
     embedding name2info names, cui2info names"""
-    embedding_batch_size: int = 4096
-    """How many entities to be linked at once"""
     linking_batch_size: int = 512
+    """How many entities to be linked at once"""
+    gpu_device: Optional[Any] = None
     """Choose a device for the linking model to be stored. If None
     then an appropriate GPU device that is available will be chosen"""
-    gpu_device: Optional[Any] = None
+    context_window_size: int = 14
     """Choose the window size to get context vectors."""
-    context_window_size: int = 14
+    use_ner_link_candidates: bool = True
     """Link candidates are provided by some NER steps.
     This will flag if you want to trust them or not."""
-    use_ner_link_candidates: bool = True
+    use_similarity_threshold: bool = True
     """Do we have a similarity threshold we care about?"""
-    use_similarity_threshold: bool = True
 
 class Preprocessing(SerialisableBaseModel):
     """The preprocessing part of the config"""
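A closing note on the two thresholds this series introduces: the short
threshold gates candidate generation (effectively a near-string-match of the
entity text against CDB names), while the long threshold gates the final
context-based acceptance of a CUI. A sketch of the settings the config
docstrings above recommend (the values are the documented suggestions, not
tuned defaults):

    cnf.components.linking.short_similarity_threshold = 0.95  # candidate generation
    cnf.components.linking.long_similarity_threshold = 0.7    # inference acceptance
    cnf.components.linking.use_similarity_threshold = True

With both left at 0.0, the linker keeps the earlier behaviour: only the single
highest-scoring name contributes candidates, and every NER entity is accepted.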