|
1 | | -import requests |
2 | | -from fastapi import HTTPException |
| 1 | +from io import StringIO |
| 2 | +from typing import Dict, Optional |
3 | 3 |
|
4 | | -from graphgen.utils import logger |
| 4 | +from Bio import ExPASy, SeqIO, SwissProt, UniProt |
| 5 | +from Bio.Blast import NCBIWWW, NCBIXML |
5 | 6 |
|
6 | | -UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search" |
| 7 | +from graphgen.utils import logger |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class UniProtSearch: |
10 | 11 | """ |
11 | 12 | UniProt Search client to search with UniProt. |
12 | 13 | 1) Get the protein by accession number. |
13 | | - 2) Search with keywords or protein names. |
| 14 | + 2) Search with keywords or protein names (fuzzy search). |
14 | 15 | """ |
15 | 16 |
|
16 | | - def get_entry(self, accession: str) -> dict: |
| 17 | + def get_by_accession(self, accession: str) -> Optional[dict]: |
| 18 | + try: |
| 19 | + handle = ExPASy.get_sprot_raw(accession) |
| 20 | + record = SwissProt.read(handle) |
| 21 | + handle.close() |
| 22 | + return self._swissprot_to_dict(record) |
| 23 | + except Exception as exc: # pylint: disable=broad-except |
| 24 | + logger.error("Accession %s not found: %s", accession, exc) |
| 25 | + return None |
| 26 | + |
| 27 | + @staticmethod |
| 28 | + def _swissprot_to_dict(record: SwissProt.Record) -> dict: |
| 29 | + """error |
| 30 | + Convert a SwissProt.Record to a dictionary. |
17 | 31 | """ |
18 | | - Get the UniProt entry by accession number(e.g., P04637). |
| 32 | + functions = [] |
| 33 | + for line in record.comments: |
| 34 | + if line.startswith("FUNCTION:"): |
| 35 | + functions.append(line[9:].strip()) |
| 36 | + |
| 37 | + return { |
| 38 | + "molecule_type": "protein", |
| 39 | + "database": "UniProt", |
| 40 | + "id": record.accessions[0], |
| 41 | + "entry_name": record.entry_name, |
| 42 | + "gene_names": record.gene_name, |
| 43 | + "protein_name": record.description.split(";")[0].split("=")[-1], |
| 44 | + "organism": record.organism.split(" (")[0], |
| 45 | + "sequence": str(record.sequence), |
| 46 | + "function": functions, |
| 47 | + "url": f"https://www.uniprot.org/uniprot/{record.accessions[0]}", |
| 48 | + } |
| 49 | + |
| 50 | + def get_best_hit(self, keyword: str) -> Optional[Dict]: |
19 | 51 | """ |
20 | | - url = f"{UNIPROT_BASE}/{accession}.json" |
21 | | - return self._safe_get(url).json() |
22 | | - |
23 | | - def search( |
24 | | - self, |
25 | | - query: str, |
26 | | - *, |
27 | | - size: int = 10, |
28 | | - cursor: str = None, |
29 | | - fields: list[str] = None, |
30 | | - ) -> dict: |
| 52 | + Search UniProt with a keyword and return the best hit. |
| 53 | + :param keyword: The search keyword. |
| 54 | + :return: A dictionary containing the best hit information or None if not found. |
31 | 55 | """ |
32 | | - Search UniProt with a query string. |
33 | | - :param query: The search query. |
34 | | - :param size: The number of results to return. |
35 | | - :param cursor: The cursor for pagination. |
36 | | - :param fields: The fields to return in the response. |
37 | | - :return: A dictionary containing the search results. |
| 56 | + if not keyword.strip(): |
| 57 | + return None |
| 58 | + |
| 59 | + try: |
| 60 | + iterator = UniProt.search(keyword, fields=None, batch_size=1) |
| 61 | + hit = next(iterator, None) |
| 62 | + if hit is None: |
| 63 | + return None |
| 64 | + return self.get_by_accession(hit["primaryAccession"]) |
| 65 | + |
| 66 | + except Exception as e: # pylint: disable=broad-except |
| 67 | + logger.error("Keyword %s not found: %s", keyword, e) |
| 68 | + return None |
| 69 | + |
| 70 | + def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]: |
38 | 71 | """ |
39 | | - params = { |
40 | | - "query": query, |
41 | | - "size": size, |
42 | | - } |
43 | | - if cursor: |
44 | | - params["cursor"] = cursor |
45 | | - if fields: |
46 | | - params["fields"] = ",".join(fields) |
47 | | - url = UNIPROT_BASE |
48 | | - return self._safe_get(url, params=params).json() |
| 72 | + Search UniProt with a FASTA sequence and return the best hit. |
| 73 | + :param fasta_sequence: The FASTA sequence. |
| 74 | + :param threshold: E-value threshold for BLAST search. |
| 75 | + :return: A dictionary containing the best hit information or None if not found. |
| 76 | + """ |
| 77 | + try: |
| 78 | + if fasta_sequence.startswith(">"): |
| 79 | + seq = str(list(SeqIO.parse(StringIO(fasta_sequence), "fasta"))[0].seq) |
| 80 | + else: |
| 81 | + seq = fasta_sequence.strip() |
| 82 | + except Exception as e: # pylint: disable=broad-except |
| 83 | + logger.error("Invalid FASTA sequence: %s", e) |
| 84 | + return None |
49 | 85 |
|
50 | | - @staticmethod |
51 | | - def _safe_get(url: str, params: dict = None) -> requests.Response: |
52 | | - r = requests.get( |
53 | | - url, |
54 | | - params=params, |
55 | | - headers={"Accept": "application/json"}, |
56 | | - timeout=10, |
57 | | - ) |
58 | | - if not r.ok: |
59 | | - logger.error("Search engine error: %s", r.text) |
60 | | - raise HTTPException(r.status_code, "Search engine error.") |
61 | | - return r |
| 86 | + if not seq: |
| 87 | + logger.error("Empty FASTA sequence provided.") |
| 88 | + return None |
| 89 | + |
| 90 | + # UniProtKB/Swiss-Prot BLAST API |
| 91 | + try: |
| 92 | + result_handle = NCBIWWW.qblast( |
| 93 | + program="blastp", |
| 94 | + database="swissprot", |
| 95 | + sequence=seq, |
| 96 | + hitlist_size=1, |
| 97 | + expect=threshold, |
| 98 | + ) |
| 99 | + blast_record = NCBIXML.read(result_handle) |
| 100 | + except Exception as e: # pylint: disable=broad-except |
| 101 | + logger.error("BLAST search failed: %s", e) |
| 102 | + return None |
| 103 | + |
| 104 | + if not blast_record.alignments: |
| 105 | + logger.info("No BLAST hits found for the given sequence.") |
| 106 | + return None |
| 107 | + |
| 108 | + best_alignment = blast_record.alignments[0] |
| 109 | + best_hsp = best_alignment.hsps[0] |
| 110 | + if best_hsp.expect > threshold: |
| 111 | + logger.info("No BLAST hits below the threshold E-value.") |
| 112 | + return None |
| 113 | + hit_id = best_alignment.hit_id |
| 114 | + |
| 115 | + # like sp|P01308.1|INS_HUMAN |
| 116 | + accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id |
| 117 | + return self.get_by_accession(accession) |
0 commit comments