diff --git a/graphgen/configs/search_config.yaml b/graphgen/configs/search_config.yaml
index 69b3b9c0..37e65818 100644
--- a/graphgen/configs/search_config.yaml
+++ b/graphgen/configs/search_config.yaml
@@ -1,8 +1,11 @@
 pipeline:
   - name: read
     params:
-      input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
 
   - name: search
     params:
       data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
+      uniprot_params:
+        use_local_blast: true # whether to use local blast for uniprot search
+        local_blast_db: /your_path/uniprot_sprot
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 827f57fe..d63ca555 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -68,7 +68,7 @@ def __init__(
             self.working_dir, namespace="graph"
         )
         self.search_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="searcher"
+            self.working_dir, namespace="search"
         )
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
@@ -190,7 +190,7 @@ async def search(self, search_config: Dict):
             return
         search_results = await search_all(
             seed_data=seeds,
-            **search_config,
+            search_config=search_config,
         )
 
         _add_search_keys = await self.search_storage.filter_keys(
diff --git a/graphgen/models/searcher/db/uniprot_searcher.py b/graphgen/models/searcher/db/uniprot_searcher.py
index 4856ea90..a74b623e 100644
--- a/graphgen/models/searcher/db/uniprot_searcher.py
+++ b/graphgen/models/searcher/db/uniprot_searcher.py
@@ -1,4 +1,10 @@
+import asyncio
+import os
 import re
+import subprocess
+import tempfile
+from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from io import StringIO
 from typing import Dict, Optional
 
@@ -16,6 +22,12 @@ from graphgen.utils import logger
 
 
 
+@lru_cache(maxsize=None)
+def _get_pool():
+    """Return the shared thread pool (created once, on first call) for blocking lookups."""
+    return ThreadPoolExecutor(max_workers=10)
+
+
 class UniProtSearch(BaseSearcher):
     """
     UniProt Search client to searcher with UniProt.
@@ -24,6 +36,18 @@ class UniProtSearch(BaseSearcher):
     3) Search with FASTA sequence (BLAST searcher).
     """
 
+    def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"):
+        """
+        :param use_local_blast: Whether to try a local blastp run before NCBI qblast.
+        :param local_blast_db: Path prefix of the local BLAST protein database.
+        """
+        super().__init__()
+        self.use_local_blast = use_local_blast
+        self.local_blast_db = local_blast_db
+        if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.phr"):
+            logger.error("Local BLAST database files not found. Please check the path.")
+            self.use_local_blast = False
+
     def get_by_accession(self, accession: str) -> Optional[dict]:
         try:
             handle = ExPASy.get_sprot_raw(accession)
@@ -101,38 +125,90 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
             logger.error("Empty FASTA sequence provided.")
             return None
 
-        # UniProtKB/Swiss-Prot BLAST API
-        try:
-            logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
-            result_handle = NCBIWWW.qblast(
-                program="blastp",
-                database="swissprot",
-                sequence=seq,
-                hitlist_size=1,
-                expect=threshold,
-            )
-            blast_record = NCBIXML.read(result_handle)
-        except RequestException:
-            raise
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error("BLAST searcher failed: %s", e)
-            return None
+        accession = None
+        if self.use_local_blast:
+            accession = self._local_blast(seq, threshold)
+            if accession:
+                logger.debug("Local BLAST found accession: %s", accession)
+
+        if not accession:
+            logger.debug("Falling back to NCBIWWW.qblast.")
+
+            # UniProtKB/Swiss-Prot BLAST API
+            try:
+                logger.debug(
+                    "Performing BLAST searcher for the given sequence: %s", seq
+                )
+                result_handle = NCBIWWW.qblast(
+                    program="blastp",
+                    database="swissprot",
+                    sequence=seq,
+                    hitlist_size=1,
+                    expect=threshold,
+                )
+                blast_record = NCBIXML.read(result_handle)
+            except RequestException:
+                raise
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("BLAST searcher failed: %s", e)
+                return None
 
-        if not blast_record.alignments:
-            logger.info("No BLAST hits found for the given sequence.")
-            return None
+            if not blast_record.alignments:
+                logger.info("No BLAST hits found for the given sequence.")
+                return None
 
-        best_alignment = blast_record.alignments[0]
-        best_hsp = best_alignment.hsps[0]
-        if best_hsp.expect > threshold:
-            logger.info("No BLAST hits below the threshold E-value.")
-            return None
-        hit_id = best_alignment.hit_id
+            best_alignment = blast_record.alignments[0]
+            best_hsp = best_alignment.hsps[0]
+            if best_hsp.expect > threshold:
+                logger.info("No BLAST hits below the threshold E-value.")
+                return None
+            hit_id = best_alignment.hit_id
 
-        # like sp|P01308.1|INS_HUMAN
-        accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
+            # like sp|P01308.1|INS_HUMAN
+            accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
         return self.get_by_accession(accession)
 
+    def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
+        """
+        Perform local BLAST search using local BLAST database.
+        :param seq: The protein sequence.
+        :param threshold: E-value threshold for BLAST searcher.
+        :return: The accession number of the best hit or None if not found.
+        """
+        tmp_name = None
+        try:
+            with tempfile.NamedTemporaryFile(
+                mode="w+", suffix=".fa", delete=False
+            ) as tmp:
+                tmp_name = tmp.name
+                tmp.write(f">query\n{seq}\n")
+
+            cmd = [
+                "blastp",
+                "-db",
+                self.local_blast_db,
+                "-query",
+                tmp_name,
+                "-evalue",
+                str(threshold),
+                "-max_target_seqs",
+                "1",
+                "-outfmt",
+                "6 sacc",  # only return accession
+            ]
+            logger.debug("Running local blastp: %s", " ".join(cmd))
+            out = subprocess.check_output(cmd, text=True).strip()
+            if out:
+                return out.split("\n", maxsplit=1)[0]
+            return None
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.error("Local blastp failed: %s", exc)
+            return None
+        finally:
+            # Always remove the query file, even when blastp fails or is missing.
+            if tmp_name and os.path.exists(tmp_name):
+                os.remove(tmp_name)
+
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -156,20 +232,29 @@ async def search(
         query = query.strip()
 
         logger.debug("UniProt searcher query: %s", query)
+
+        loop = asyncio.get_running_loop()
+
         # check if fasta sequence
         if query.startswith(">") or re.fullmatch(
             r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
         ):
-            result = self.get_by_fasta(query, threshold)
+            coro = loop.run_in_executor(
+                _get_pool(), self.get_by_fasta, query, threshold
+            )
 
         # check if accession number
         elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
-            result = self.get_by_accession(query)
+            coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query)
 
         else:
             # otherwise treat as keyword
-            result = self.get_best_hit(query)
+            coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query)
 
+        result = await coro
         if result:
             result["_search_query"] = query
         return result
+
+
+# TODO: use local UniProt database for large-scale searches
diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py
index 99c71a79..6c543dbf 100644
--- a/graphgen/operators/search/search_all.py
+++ b/graphgen/operators/search/search_all.py
@@ -14,22 +14,25 @@
 
 async def search_all(
     seed_data: dict,
-    data_sources: list[str],
+    search_config: dict,
 ) -> dict:
     """
     Perform searches across multiple search types and aggregate the results.
     :param seed_data: A dictionary containing seed data with entity names.
-    :param data_sources: A list of search types to perform (e.g., "wikipedia", "google", "bing", "uniprot").
+    :param search_config: A dictionary specifying which data sources to use for searching.
     :return: A dictionary with
     """
 
     results = {}
+    data_sources = search_config.get("data_sources", [])
 
     for data_source in data_sources:
         if data_source == "uniprot":
             from graphgen.models import UniProtSearch
 
-            uniprot_search_client = UniProtSearch()
+            uniprot_search_client = UniProtSearch(
+                **(search_config.get("uniprot_params") or {})
+            )
 
             data = list(seed_data.values())
             data = [d["content"] for d in data if "content" in d]