Skip to content

Commit 8143fff

Browse files
committed
feat: enable faster search
1 parent aa76650 commit 8143fff

3 files changed

Lines changed: 274 additions & 29 deletions

File tree

graphgen/models/searcher/db/ncbi_searcher.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
@lru_cache(maxsize=None)
2626
def _get_pool():
27-
return ThreadPoolExecutor(max_workers=10)
27+
return ThreadPoolExecutor(max_workers=20) # NOTE:can increase for better parallelism
2828

2929

3030
# ensure only one NCBI request at a time
@@ -432,16 +432,29 @@ async def search(self, query: str, threshold: float = 0.01, **kwargs) -> Optiona
432432

433433
loop = asyncio.get_running_loop()
434434

435-
# limit concurrent requests (NCBI rate limit: max 3 requests per second)
436-
async with _ncbi_lock:
437-
# Auto-detect query type and execute in thread pool
438-
if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
435+
# Auto-detect query type and execute in thread pool
436+
# Only use lock for network API calls (NCBI rate limit: max 3 requests per second)
437+
# Local BLAST can run in parallel
438+
if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
439+
# FASTA sequence: use lock only if using network BLAST
440+
if self.use_local_blast:
441+
# Local BLAST can run in parallel, no lock needed
439442
result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
440-
elif re.fullmatch(r"^\d+$", query):
443+
else:
444+
# Network BLAST needs lock to respect rate limits
445+
async with _ncbi_lock:
446+
result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
447+
elif re.fullmatch(r"^\d+$", query):
448+
# Gene ID: always use lock (network API call)
449+
async with _ncbi_lock:
441450
result = await loop.run_in_executor(_get_pool(), self.get_by_gene_id, query)
442-
elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I):
451+
elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I):
452+
# Accession: always use lock (network API call)
453+
async with _ncbi_lock:
443454
result = await loop.run_in_executor(_get_pool(), self.get_by_accession, query)
444-
else:
455+
else:
456+
# Keyword: always use lock (network API call)
457+
async with _ncbi_lock:
445458
result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
446459

447460
if result:

graphgen/models/searcher/db/rnacentral_searcher.py

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
)
1919

2020
from graphgen.bases import BaseSearcher
21-
from graphgen.utils import logger
21+
from graphgen.utils import logger, load_json
2222

2323

2424
@lru_cache(maxsize=None)
2525
def _get_pool():
26-
return ThreadPoolExecutor(max_workers=10)
26+
return ThreadPoolExecutor(max_workers=20) # NOTE:can increase for better parallelism
2727

2828
class RNACentralSearch(BaseSearcher):
2929
"""
@@ -35,12 +35,28 @@ class RNACentralSearch(BaseSearcher):
3535
API Documentation: https://rnacentral.org/api/v1
3636
"""
3737

38-
def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
38+
def __init__(
39+
self,
40+
use_local_blast: bool = False,
41+
local_blast_db: str = "rna_db",
42+
api_timeout: int = 5,
43+
metadata_db_file: Optional[str] = None,
44+
blast_num_threads: int = 4
45+
):
3946
super().__init__()
4047
self.base_url = "https://rnacentral.org/api/v1"
4148
self.headers = {"Accept": "application/json"}
4249
self.use_local_blast = use_local_blast
4350
self.local_blast_db = local_blast_db
51+
self.api_timeout = api_timeout
52+
self.metadata_db_file = metadata_db_file
53+
self.blast_num_threads = blast_num_threads # Number of threads for BLAST search
54+
55+
# Load pre-built metadata database if provided
56+
self._metadata_db: Optional[Dict[str, Optional[dict]]] = None
57+
if self.metadata_db_file:
58+
self._load_metadata_db()
59+
4460
if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
4561
logger.error("Local BLAST database files not found. Please check the path.")
4662
self.use_local_blast = False
@@ -142,22 +158,60 @@ def _calculate_md5(sequence: str) -> str:
142158

143159
return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()
144160

161+
def _load_metadata_db(self) -> None:
162+
"""Load pre-built metadata database from file."""
163+
if not self.metadata_db_file:
164+
return
165+
166+
try:
167+
if os.path.isfile(self.metadata_db_file):
168+
self._metadata_db = load_json(self.metadata_db_file)
169+
if self._metadata_db and isinstance(self._metadata_db, dict):
170+
logger.info("Loaded %d RNA ID entries from metadata database: %s",
171+
len(self._metadata_db), self.metadata_db_file)
172+
else:
173+
logger.warning("Metadata database file %s exists but contains invalid data",
174+
self.metadata_db_file)
175+
self._metadata_db = None
176+
else:
177+
logger.warning("Metadata database file not found: %s", self.metadata_db_file)
178+
logger.info("To build the database, run: python -m graphgen.models.searcher.db.build_rna_metadata_db")
179+
except Exception as e:
180+
logger.warning("Failed to load metadata database from %s: %s", self.metadata_db_file, e)
181+
self._metadata_db = None
182+
145183
def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
146184
"""
147185
Get RNA information by RNAcentral ID.
186+
First checks pre-built metadata database if available, then falls back to API.
148187
:param rna_id: RNAcentral ID (e.g., URS0000000001).
149188
:return: A dictionary containing RNA information or None if not found.
150189
"""
190+
# Check pre-built metadata database first
191+
if self._metadata_db is not None:
192+
if rna_id in self._metadata_db:
193+
result = self._metadata_db[rna_id]
194+
logger.debug("Found RNA ID %s in metadata database", rna_id)
195+
return result
196+
else:
197+
logger.debug("RNA ID %s not found in metadata database, skipping API call", rna_id)
198+
return None
199+
200+
# Fall back to API if metadata database not available
151201
try:
152202
url = f"{self.base_url}/rna/{rna_id}"
153203
url += "?flat=true"
154204

155-
resp = requests.get(url, headers=self.headers, timeout=30)
205+
resp = requests.get(url, headers=self.headers, timeout=self.api_timeout)
156206
resp.raise_for_status()
157207

158208
rna_data = resp.json()
159209
xrefs_data = rna_data.get("xrefs", [])
160-
return self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
210+
result = self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
211+
return result
212+
except requests.Timeout as e:
213+
logger.warning("Timeout getting RNA ID %s (timeout=%ds): %s", rna_id, self.api_timeout, e)
214+
return None
161215
except requests.RequestException as e:
162216
logger.error("Network error getting RNA ID %s: %s", rna_id, e)
163217
return None
@@ -179,7 +233,7 @@ def get_best_hit(self, keyword: str) -> Optional[dict]:
179233
try:
180234
url = f"{self.base_url}/rna"
181235
params = {"search": keyword, "format": "json"}
182-
resp = requests.get(url, params=params, headers=self.headers, timeout=30)
236+
resp = requests.get(url, params=params, headers=self.headers, timeout=self.api_timeout)
183237
resp.raise_for_status()
184238

185239
data = resp.json()
@@ -207,22 +261,54 @@ def get_best_hit(self, keyword: str) -> Optional[dict]:
207261
return None
208262

209263
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
210-
"""Perform local BLAST search using local BLAST database."""
264+
"""
265+
Perform local BLAST search using local BLAST database.
266+
Optimized with multi-threading and faster output format.
267+
"""
211268
try:
269+
# Use temporary file for query sequence
212270
with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
213271
tmp.write(f">query\n{seq}\n")
214272
tmp_name = tmp.name
215273

274+
# Optimized BLAST command with:
275+
# - num_threads: Use multiple threads for faster search
276+
# - outfmt 6 sacc: Only return accession (minimal output)
277+
# - max_target_seqs 1: Only need the best hit
278+
# - evalue: Threshold for significance
216279
cmd = [
217280
"blastn", "-db", self.local_blast_db, "-query", tmp_name,
218-
"-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
281+
"-evalue", str(threshold),
282+
"-max_target_seqs", "1",
283+
"-num_threads", str(self.blast_num_threads),
284+
"-outfmt", "6 sacc" # Only accession, tab-separated
219285
]
220-
logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
221-
out = subprocess.check_output(cmd, text=True).strip()
286+
logger.debug("Running local blastn for RNA (threads=%d): %s",
287+
self.blast_num_threads, " ".join(cmd))
288+
289+
# Run BLAST with timeout to avoid hanging
290+
try:
291+
out = subprocess.check_output(
292+
cmd,
293+
text=True,
294+
timeout=300, # 5 minute timeout for BLAST search
295+
stderr=subprocess.DEVNULL # Suppress BLAST warnings to reduce I/O
296+
).strip()
297+
except subprocess.TimeoutExpired:
298+
logger.warning("BLAST search timed out after 5 minutes for sequence")
299+
os.remove(tmp_name)
300+
return None
301+
222302
os.remove(tmp_name)
223303
return out.split("\n", maxsplit=1)[0] if out else None
224304
except Exception as exc:
225305
logger.error("Local blastn failed: %s", exc)
306+
# Clean up temp file if it still exists
307+
try:
308+
if 'tmp_name' in locals():
309+
os.remove(tmp_name)
310+
except Exception:
311+
pass
226312
return None
227313

228314
def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
@@ -254,7 +340,15 @@ def _extract_sequence(sequence: str) -> Optional[str]:
254340
accession = self._local_blast(seq, threshold)
255341
if accession:
256342
logger.debug("Local BLAST found accession: %s", accession)
257-
return self.get_by_rna_id(accession)
343+
detailed = self.get_by_rna_id(accession)
344+
if detailed:
345+
return detailed
346+
logger.info(
347+
"Local BLAST found accession %s but metadata not available in database. "
348+
"API fallback disabled when using local database.",
349+
accession
350+
)
351+
return None
258352
logger.info(
259353
"Local BLAST found no match for sequence. "
260354
"API fallback disabled when using local database."
@@ -280,7 +374,12 @@ def _extract_sequence(sequence: str) -> Optional[str]:
280374

281375
rna_id = results[0].get("rnacentral_id")
282376
if rna_id:
283-
return self.get_by_rna_id(rna_id)
377+
detailed = self.get_by_rna_id(rna_id)
378+
if detailed:
379+
return detailed
380+
# Fallback: use search result data if get_by_rna_id returns None
381+
logger.debug("Using search result data for %s (get_by_rna_id returned None)", rna_id)
382+
return self._rna_data_to_dict(rna_id, results[0])
284383

285384
logger.error("No RNAcentral ID found in search results.")
286385
return None

0 commit comments

Comments
 (0)