Skip to content

Commit 9382660

Browse files
committed
feat: add DNA RNA local blast
1 parent 40ef49e commit 9382660

File tree

8 files changed

+529
-10
lines changed

8 files changed

+529
-10
lines changed

graphgen/configs/search_dna_config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ pipeline:
1212
ncbi_params:
1313
email: [email protected] # NCBI requires an email address
1414
tool: GraphGen # tool name for NCBI API
15+
use_local_blast: true # whether to use local blast for DNA search
16+
local_blast_db: /your_path/refseq_241 # path to local BLAST database (without .nhr extension)
1517

graphgen/configs/search_protein_config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ pipeline:
1111
data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
1212
uniprot_params:
1313
use_local_blast: true # whether to use local blast for uniprot search
14-
local_blast_db: /your_path/uniprot_sprot
14+
local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
15+
# options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database)

graphgen/configs/search_rna_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,7 @@ pipeline:
1010
params:
1111
data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
1212
rnacentral_params:
13-
{} # RNAcentral doesn't require additional parameters currently
13+
use_local_blast: true # whether to use local blast for RNA search
14+
local_blast_db: /your_path/refseq_rna_241 # format: /path/to/refseq_rna_${RELEASE}
15+
# can also use DNA database with RNA sequences (if already built)
1416

graphgen/models/searcher/db/ncbi_searcher.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
import logging
3+
import os
34
import re
5+
import subprocess
6+
import tempfile
47
import time
58
from concurrent.futures import ThreadPoolExecutor
69
from functools import lru_cache
@@ -38,11 +41,22 @@ class NCBISearch(BaseSearcher):
3841
Note: NCBI has rate limits (max 3 requests per second), delays are required between requests.
3942
"""
4043

41-
def __init__(self, email: str = "[email protected]", tool: str = "GraphGen"):
44+
def __init__(
45+
self,
46+
email: str = "[email protected]",
47+
tool: str = "GraphGen",
48+
use_local_blast: bool = False,
49+
local_blast_db: str = "nt_db",
50+
):
4251
super().__init__()
4352
Entrez.email = email
4453
Entrez.tool = tool
4554
Entrez.timeout = 60 # 60 seconds timeout
55+
self.use_local_blast = use_local_blast
56+
self.local_blast_db = local_blast_db
57+
if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
58+
logger.error("Local BLAST database files not found. Please check the path.")
59+
self.use_local_blast = False
4660

4761
@staticmethod
4862
def _safe_get(obj, key, default=None):
@@ -518,10 +532,47 @@ def get_best_hit(self, keyword: str) -> Optional[dict]:
518532
logger.error("Keyword %s not found: %s", keyword, e)
519533
return None
520534

535+
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
536+
"""
537+
Perform local BLAST search using local BLAST database.
538+
:param seq: The DNA sequence.
539+
:param threshold: E-value threshold for BLAST search.
540+
:return: The accession number of the best hit or None if not found.
541+
"""
542+
try:
543+
with tempfile.NamedTemporaryFile(
544+
mode="w+", suffix=".fa", delete=False
545+
) as tmp:
546+
tmp.write(f">query\n{seq}\n")
547+
tmp_name = tmp.name
548+
549+
cmd = [
550+
"blastn",
551+
"-db",
552+
self.local_blast_db,
553+
"-query",
554+
tmp_name,
555+
"-evalue",
556+
str(threshold),
557+
"-max_target_seqs",
558+
"1",
559+
"-outfmt",
560+
"6 sacc", # only return accession
561+
]
562+
logger.debug("Running local blastn: %s", " ".join(cmd))
563+
out = subprocess.check_output(cmd, text=True).strip()
564+
os.remove(tmp_name)
565+
if out:
566+
return out.split("\n", maxsplit=1)[0]
567+
return None
568+
except Exception as exc: # pylint: disable=broad-except
569+
logger.error("Local blastn failed: %s", exc)
570+
return None
571+
521572
def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
522573
"""
523574
Search NCBI with a DNA sequence using BLAST.
524-
Note: This is a simplified version. For production, consider using local BLAST.
575+
Tries local BLAST first if enabled, falls back to network BLAST.
525576
:param sequence: DNA sequence (FASTA format or raw sequence).
526577
:param threshold: E-value threshold for BLAST search.
527578
:return: A dictionary containing the best hit information or None if not found.
@@ -542,7 +593,16 @@ def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional
542593
logger.error("Invalid DNA sequence provided.")
543594
return None
544595

545-
# Use BLAST search (Note: requires network connection, may be slow)
596+
# Try local BLAST first if enabled
597+
accession = None
598+
if self.use_local_blast:
599+
accession = self._local_blast(seq, threshold)
600+
if accession:
601+
logger.debug("Local BLAST found accession: %s", accession)
602+
return self.get_by_accession(accession)
603+
604+
# Fall back to network BLAST
605+
logger.debug("Falling back to NCBIWWW.qblast.")
546606
logger.debug("Performing BLAST search for DNA sequence...")
547607
time.sleep(0.35)
548608

graphgen/models/searcher/db/rnacentral_searcher.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import asyncio
2+
import os
23
import re
4+
import subprocess
5+
import tempfile
36
from typing import Dict, Optional, List, Any
47

58
import aiohttp
@@ -23,10 +26,15 @@ class RNACentralSearch(BaseSearcher):
2326
API Documentation: https://rnacentral.org/api/v1
2427
"""
2528

26-
def __init__(self):
29+
def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
2730
super().__init__()
2831
self.base_url = "https://rnacentral.org/api/v1"
2932
self.headers = {"Accept": "application/json"}
33+
self.use_local_blast = use_local_blast
34+
self.local_blast_db = local_blast_db
35+
if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
36+
logger.error("Local BLAST database files not found. Please check the path.")
37+
self.use_local_blast = False
3038

3139
async def _fetch_all_xrefs(self, xrefs_url: str, session: aiohttp.ClientSession) -> List[Dict]:
3240
"""
@@ -294,11 +302,50 @@ async def get_best_hit(self, keyword: str) -> Optional[dict]:
294302
logger.error("Keyword %s not found: %s", keyword, e)
295303
return None
296304

297-
async def search_by_sequence(self, sequence: str) -> Optional[dict]:
305+
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
306+
"""
307+
Perform local BLAST search using local BLAST database.
308+
:param seq: The RNA sequence.
309+
:param threshold: E-value threshold for BLAST search.
310+
:return: The accession/ID of the best hit or None if not found.
311+
"""
312+
try:
313+
with tempfile.NamedTemporaryFile(
314+
mode="w+", suffix=".fa", delete=False
315+
) as tmp:
316+
tmp.write(f">query\n{seq}\n")
317+
tmp_name = tmp.name
318+
319+
cmd = [
320+
"blastn",
321+
"-db",
322+
self.local_blast_db,
323+
"-query",
324+
tmp_name,
325+
"-evalue",
326+
str(threshold),
327+
"-max_target_seqs",
328+
"1",
329+
"-outfmt",
330+
"6 sacc", # only return accession
331+
]
332+
logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
333+
out = subprocess.check_output(cmd, text=True).strip()
334+
os.remove(tmp_name)
335+
if out:
336+
return out.split("\n", maxsplit=1)[0]
337+
return None
338+
except Exception as exc: # pylint: disable=broad-except
339+
logger.error("Local blastn failed: %s", exc)
340+
return None
341+
342+
async def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
298343
"""
299344
Search RNAcentral with an RNA sequence.
345+
Tries local BLAST first if enabled, falls back to RNAcentral API.
300346
Unified approach: Find RNA ID from sequence search, then call get_by_rna_id() for complete information.
301347
:param sequence: RNA sequence (FASTA format or raw sequence).
348+
:param threshold: E-value threshold for BLAST search.
302349
:return: A dictionary containing complete RNA information or None if not found.
303350
"""
304351
try:
@@ -318,7 +365,23 @@ async def search_by_sequence(self, sequence: str) -> Optional[dict]:
318365
logger.error("Empty RNA sequence provided.")
319366
return None
320367

321-
# RNAcentral API supports sequence search
368+
# Try local BLAST first if enabled
369+
if self.use_local_blast:
370+
accession = self._local_blast(seq, threshold)
371+
if accession:
372+
logger.debug("Local BLAST found accession: %s", accession)
373+
# Try to get RNA ID from accession (may need conversion)
374+
# For now, try using accession as RNA ID or search by it
375+
result = await self.get_by_rna_id(accession)
376+
if result:
377+
return result
378+
# If not found by ID, try keyword search
379+
result = await self.get_best_hit(accession)
380+
if result:
381+
return result
382+
383+
# Fall back to RNAcentral API
384+
logger.debug("Falling back to RNAcentral API.")
322385
async with aiohttp.ClientSession() as session:
323386
search_url = f"{self.base_url}/rna"
324387
params = {"sequence": seq, "format": "json"}
@@ -373,7 +436,7 @@ async def search_by_sequence(self, sequence: str) -> Optional[dict]:
373436
reraise=True,
374437
)
375438
async def search(
376-
self, query: str, threshold: float = 0.7, **kwargs
439+
self, query: str, threshold: float = 0.1, **kwargs
377440
) -> Optional[Dict]:
378441
"""
379442
Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence.
@@ -395,7 +458,7 @@ async def search(
395458
if query.startswith(">") or (
396459
re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper()
397460
):
398-
result = await self.search_by_sequence(query)
461+
result = await self.search_by_sequence(query, threshold)
399462
# check if RNAcentral ID (typically starts with URS)
400463
elif re.fullmatch(r"URS\d+", query, re.I):
401464
result = await self.get_by_rna_id(query)

0 commit comments

Comments
 (0)