diff --git a/graphgen/models/search/db/uniprot_search.py b/graphgen/models/search/db/uniprot_search.py
index daf42246..6e7d2bfb 100644
--- a/graphgen/models/search/db/uniprot_search.py
+++ b/graphgen/models/search/db/uniprot_search.py
@@ -1,61 +1,137 @@
-import requests
-from fastapi import HTTPException
+from io import StringIO
+from typing import Dict, Optional
 
-from graphgen.utils import logger
+from Bio import ExPASy, SeqIO, SwissProt, UniProt
+from Bio.Blast import NCBIWWW, NCBIXML
 
-UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
+from graphgen.utils import logger
 
 
 class UniProtSearch:
     """
     UniProt Search client to search with UniProt.
     1) Get the protein by accession number.
-    2) Search with keywords or protein names.
+    2) Search with keywords or protein names (fuzzy search).
+    3) Search with FASTA sequence (BLAST search).
     """
 
-    def get_entry(self, accession: str) -> dict:
+    def get_by_accession(self, accession: str) -> Optional[dict]:
+        """
+        Get the Swiss-Prot entry by accession number (e.g., P04637).
+        :param accession: The UniProt accession number.
+        :return: A dictionary describing the entry or None if the lookup fails.
+        """
+        try:
+            handle = ExPASy.get_sprot_raw(accession)
+            try:
+                record = SwissProt.read(handle)
+            finally:
+                # Release the connection even when parsing fails.
+                handle.close()
+            return self._swissprot_to_dict(record)
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.error("Accession %s not found: %s", accession, exc)
+            return None
+
+    @staticmethod
+    def _swissprot_to_dict(record: SwissProt.Record) -> dict:
+        """
+        Convert a SwissProt.Record to a dictionary.
+        :param record: The parsed Swiss-Prot record.
+        :return: A dictionary with the fields used downstream (id, names, sequence, ...).
         """
-        Get the UniProt entry by accession number(e.g., P04637).
+        prefix = "FUNCTION:"
+        functions = [
+            line[len(prefix):].strip()
+            for line in record.comments
+            if line.startswith(prefix)
+        ]
+        accession = record.accessions[0]
+
+        return {
+            "molecule_type": "protein",
+            "database": "UniProt",
+            "id": accession,
+            "entry_name": record.entry_name,
+            "gene_names": record.gene_name,
+            "protein_name": record.description.split(";")[0].split("=")[-1],
+            "organism": record.organism.split(" (")[0],
+            "sequence": str(record.sequence),
+            "function": functions,
+            "url": f"https://www.uniprot.org/uniprot/{accession}",
+        }
+
+    def get_best_hit(self, keyword: str) -> Optional[Dict]:
         """
-        url = f"{UNIPROT_BASE}/{accession}.json"
-        return self._safe_get(url).json()
-
-    def search(
-        self,
-        query: str,
-        *,
-        size: int = 10,
-        cursor: str = None,
-        fields: list[str] = None,
-    ) -> dict:
+        Search UniProt with a keyword and return the best hit.
+        :param keyword: The search keyword.
+        :return: A dictionary containing the best hit information or None if not found.
         """
-        Search UniProt with a query string.
-        :param query: The search query.
-        :param size: The number of results to return.
-        :param cursor: The cursor for pagination.
-        :param fields: The fields to return in the response.
-        :return: A dictionary containing the search results.
+        if not keyword or not keyword.strip():
+            return None
+
+        try:
+            iterator = UniProt.search(keyword, fields=None, batch_size=1)
+            hit = next(iterator, None)
+            if hit is None:
+                return None
+            return self.get_by_accession(hit["primaryAccession"])
+
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Keyword %s not found: %s", keyword, e)
+            return None
+
+    def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
         """
-        params = {
-            "query": query,
-            "size": size,
-        }
-        if cursor:
-            params["cursor"] = cursor
-        if fields:
-            params["fields"] = ",".join(fields)
-        url = UNIPROT_BASE
-        return self._safe_get(url, params=params).json()
+        Search UniProt with a FASTA sequence and return the best hit.
+        :param fasta_sequence: The FASTA sequence.
+        :param threshold: E-value threshold for BLAST search.
+        :return: A dictionary containing the best hit information or None if not found.
+        """
+        try:
+            text = fasta_sequence.strip()
+            if text.startswith(">"):
+                # Only the first record is searched.
+                first = next(SeqIO.parse(StringIO(text), "fasta"), None)
+                seq = str(first.seq) if first is not None else ""
+            else:
+                seq = text
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Invalid FASTA sequence: %s", e)
+            return None
 
-    @staticmethod
-    def _safe_get(url: str, params: dict = None) -> requests.Response:
-        r = requests.get(
-            url,
-            params=params,
-            headers={"Accept": "application/json"},
-            timeout=10,
-        )
-        if not r.ok:
-            logger.error("Search engine error: %s", r.text)
-            raise HTTPException(r.status_code, "Search engine error.")
-        return r
+        if not seq:
+            logger.error("Empty FASTA sequence provided.")
+            return None
+
+        # NCBI BLAST (blastp) against the Swiss-Prot database
+        try:
+            result_handle = NCBIWWW.qblast(
+                program="blastp",
+                database="swissprot",
+                sequence=seq,
+                hitlist_size=1,
+                expect=threshold,
+            )
+            try:
+                blast_record = NCBIXML.read(result_handle)
+            finally:
+                result_handle.close()
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("BLAST search failed: %s", e)
+            return None
+
+        if not blast_record.alignments:
+            logger.info("No BLAST hits found for the given sequence.")
+            return None
+
+        best_alignment = blast_record.alignments[0]
+        # An alignment without HSPs has no E-value to compare.
+        if not best_alignment.hsps or best_alignment.hsps[0].expect > threshold:
+            logger.info("No BLAST hits below the threshold E-value.")
+            return None
+        hit_id = best_alignment.hit_id
+
+        # like sp|P01308.1|INS_HUMAN
+        accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
+        return self.get_by_accession(accession)
diff --git a/graphgen/operators/search/db/__init__.py b/graphgen/operators/search/db/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/graphgen/operators/search/db/search_uniprot.py b/graphgen/operators/search/db/search_uniprot.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/requirements.txt b/requirements.txt
index 82740f03..ce223306 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,5 +24,8 @@ leidenalg
 igraph
 python-louvain
 
+# Bioinformatics
+biopython
+
 # For visualization
 matplotlib