Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion graphgen/configs/search_config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
pipeline:
- name: read
params:
input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
input_file: resources/input_examples/search_demo.json # input file path; supports json, jsonl, txt, pdf. See resources/input_examples for examples

- name: search
params:
data_sources: [uniprot] # data sources for the searcher; supported: wikipedia, google, uniprot
uniprot_params:
use_local_blast: true # whether to use local blast for uniprot search
local_blast_db: /your_path/uniprot_sprot
4 changes: 2 additions & 2 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
self.working_dir, namespace="graph"
)
self.search_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="searcher"
self.working_dir, namespace="search"
)
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
Expand Down Expand Up @@ -190,7 +190,7 @@ async def search(self, search_config: Dict):
return
search_results = await search_all(
seed_data=seeds,
**search_config,
search_config=search_config,
)

_add_search_keys = await self.search_storage.filter_keys(
Expand Down
136 changes: 106 additions & 30 deletions graphgen/models/searcher/db/uniprot_searcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import asyncio
import os
import re
import subprocess
import tempfile
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from io import StringIO
from typing import Dict, Optional

Expand All @@ -16,6 +22,11 @@
from graphgen.utils import logger


@lru_cache(maxsize=None)
def _get_pool():
    """
    Return the shared thread pool used for blocking lookups.

    The executor is built on the first call; ``lru_cache`` then hands the very
    same instance back on every later call, so the module only ever owns one
    pool.
    :return: A ``ThreadPoolExecutor`` limited to 10 worker threads.
    """
    worker_count = 10
    shared_pool = ThreadPoolExecutor(max_workers=worker_count)
    return shared_pool


class UniProtSearch(BaseSearcher):
"""
UniProt search client for querying UniProt.
Expand All @@ -24,6 +35,14 @@ class UniProtSearch(BaseSearcher):
3) Search with FASTA sequence (BLAST searcher).
"""

def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"):
    """
    Set up the client and decide whether local BLAST can be used.
    :param use_local_blast: Whether to try a local BLAST database first.
    :param local_blast_db: Path prefix of the local BLAST database; the file
        ``<local_blast_db>.phr`` must exist for local BLAST to stay enabled.
    """
    super().__init__()
    self.use_local_blast = use_local_blast
    self.local_blast_db = local_blast_db

    # Nothing to validate when local BLAST was not requested.
    if not self.use_local_blast:
        return

    phr_path = f"{self.local_blast_db}.phr"
    if os.path.isfile(phr_path):
        return

    # Requested but the database is missing: report it and switch local BLAST off.
    logger.error("Local BLAST database files not found. Please check the path.")
    self.use_local_blast = False

def get_by_accession(self, accession: str) -> Optional[dict]:
try:
handle = ExPASy.get_sprot_raw(accession)
Expand Down Expand Up @@ -101,38 +120,86 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
logger.error("Empty FASTA sequence provided.")
return None

# UniProtKB/Swiss-Prot BLAST API
try:
logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
result_handle = NCBIWWW.qblast(
program="blastp",
database="swissprot",
sequence=seq,
hitlist_size=1,
expect=threshold,
)
blast_record = NCBIXML.read(result_handle)
except RequestException:
raise
except Exception as e: # pylint: disable=broad-except
logger.error("BLAST searcher failed: %s", e)
return None
accession = None
if self.use_local_blast:
accession = self._local_blast(seq, threshold)
if accession:
logger.debug("Local BLAST found accession: %s", accession)

if not accession:
logger.debug("Falling back to NCBIWWW.qblast.")

# UniProtKB/Swiss-Prot BLAST API
try:
logger.debug(
"Performing BLAST searcher for the given sequence: %s", seq
)
result_handle = NCBIWWW.qblast(
program="blastp",
database="swissprot",
sequence=seq,
hitlist_size=1,
expect=threshold,
)
blast_record = NCBIXML.read(result_handle)
except RequestException:
raise
except Exception as e: # pylint: disable=broad-except
logger.error("BLAST searcher failed: %s", e)
return None

if not blast_record.alignments:
logger.info("No BLAST hits found for the given sequence.")
return None
if not blast_record.alignments:
logger.info("No BLAST hits found for the given sequence.")
return None

best_alignment = blast_record.alignments[0]
best_hsp = best_alignment.hsps[0]
if best_hsp.expect > threshold:
logger.info("No BLAST hits below the threshold E-value.")
return None
hit_id = best_alignment.hit_id
best_alignment = blast_record.alignments[0]
best_hsp = best_alignment.hsps[0]
if best_hsp.expect > threshold:
logger.info("No BLAST hits below the threshold E-value.")
return None
hit_id = best_alignment.hit_id

# like sp|P01308.1|INS_HUMAN
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
# like sp|P01308.1|INS_HUMAN
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
return self.get_by_accession(accession)

def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
    """
    Perform local BLAST search using local BLAST database.

    The sequence is written to a temporary FASTA file, ``blastp`` is run
    against ``self.local_blast_db``, and the first line of its output is
    returned. The temporary file is always removed, whether ``blastp``
    succeeds or fails.
    :param seq: The protein sequence.
    :param threshold: E-value threshold for BLAST searcher.
    :return: The accession number of the best hit or None if not found
        (also None when running ``blastp`` fails; the error is logged).
    """
    tmp_name = None
    try:
        # delete=False: the file has to outlive the ``with`` block so that the
        # blastp subprocess can open it by name once it has been closed.
        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".fa", delete=False
        ) as tmp:
            tmp_name = tmp.name
            tmp.write(f">query\n{seq}\n")

        cmd = [
            "blastp",
            "-db",
            self.local_blast_db,
            "-query",
            tmp_name,
            "-evalue",
            str(threshold),
            "-max_target_seqs",
            "1",
            "-outfmt",
            "6 sacc",  # only return accession
        ]
        logger.debug("Running local blastp: %s", " ".join(cmd))
        out = subprocess.check_output(cmd, text=True).strip()
        if out:
            # Several lines may be reported; only the first one is used.
            return out.split("\n", maxsplit=1)[0]
        return None
    except Exception as exc:  # pylint: disable=broad-except
        logger.error("Local blastp failed: %s", exc)
        return None
    finally:
        # Clean up on every path so failed runs do not leak query files, and
        # so a cleanup problem cannot discard a hit that was already found.
        if tmp_name is not None:
            try:
                os.remove(tmp_name)
            except OSError as exc:
                logger.debug(
                    "Could not remove temporary query file %s: %s", tmp_name, exc
                )

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
Expand All @@ -156,20 +223,29 @@ async def search(
query = query.strip()

logger.debug("UniProt searcher query: %s", query)

loop = asyncio.get_running_loop()

# check if fasta sequence
if query.startswith(">") or re.fullmatch(
r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
):
result = self.get_by_fasta(query, threshold)
coro = loop.run_in_executor(
_get_pool(), self.get_by_fasta, query, threshold
)

# check if accession number
elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
result = self.get_by_accession(query)
coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query)

else:
# otherwise treat as keyword
result = self.get_best_hit(query)
coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query)

result = await coro
if result:
result["_search_query"] = query
return result


# TODO: use local UniProt database for large-scale searches
9 changes: 6 additions & 3 deletions graphgen/operators/search/search_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,25 @@

async def search_all(
seed_data: dict,
data_sources: list[str],
search_config: dict,
) -> dict:
"""
Perform searches across multiple search types and aggregate the results.
:param seed_data: A dictionary containing seed data with entity names.
:param data_sources: A list of search types to perform (e.g., "wikipedia", "google", "bing", "uniprot").
:param search_config: A dictionary specifying which data sources to use for searching.
:return: A dictionary with
"""

results = {}
data_sources = search_config.get("data_sources", [])

for data_source in data_sources:
if data_source == "uniprot":
from graphgen.models import UniProtSearch

uniprot_search_client = UniProtSearch()
uniprot_search_client = UniProtSearch(
**search_config.get("uniprot_params", {})
)

data = list(seed_data.values())
data = [d["content"] for d in data if "content" in d]
Expand Down