1818)
1919
2020from graphgen .bases import BaseSearcher
21- from graphgen .utils import logger
21+ from graphgen .utils import logger , load_json
2222
2323
2424@lru_cache (maxsize = None )
2525def _get_pool ():
26- return ThreadPoolExecutor (max_workers = 10 )
26+ return ThreadPoolExecutor (max_workers = 20 ) # NOTE:can increase for better parallelism
2727
2828class RNACentralSearch (BaseSearcher ):
2929 """
@@ -35,12 +35,28 @@ class RNACentralSearch(BaseSearcher):
3535 API Documentation: https://rnacentral.org/api/v1
3636 """
3737
38- def __init__ (self , use_local_blast : bool = False , local_blast_db : str = "rna_db" ):
38+ def __init__ (
39+ self ,
40+ use_local_blast : bool = False ,
41+ local_blast_db : str = "rna_db" ,
42+ api_timeout : int = 5 ,
43+ metadata_db_file : Optional [str ] = None ,
44+ blast_num_threads : int = 4
45+ ):
3946 super ().__init__ ()
4047 self .base_url = "https://rnacentral.org/api/v1"
4148 self .headers = {"Accept" : "application/json" }
4249 self .use_local_blast = use_local_blast
4350 self .local_blast_db = local_blast_db
51+ self .api_timeout = api_timeout
52+ self .metadata_db_file = metadata_db_file
53+ self .blast_num_threads = blast_num_threads # Number of threads for BLAST search
54+
55+ # Load pre-built metadata database if provided
56+ self ._metadata_db : Optional [Dict [str , Optional [dict ]]] = None
57+ if self .metadata_db_file :
58+ self ._load_metadata_db ()
59+
4460 if self .use_local_blast and not os .path .isfile (f"{ self .local_blast_db } .nhr" ):
4561 logger .error ("Local BLAST database files not found. Please check the path." )
4662 self .use_local_blast = False
@@ -142,22 +158,60 @@ def _calculate_md5(sequence: str) -> str:
142158
143159 return hashlib .md5 (normalized_seq .encode ("ascii" )).hexdigest ()
144160
161+ def _load_metadata_db (self ) -> None :
162+ """Load pre-built metadata database from file."""
163+ if not self .metadata_db_file :
164+ return
165+
166+ try :
167+ if os .path .isfile (self .metadata_db_file ):
168+ self ._metadata_db = load_json (self .metadata_db_file )
169+ if self ._metadata_db and isinstance (self ._metadata_db , dict ):
170+ logger .info ("Loaded %d RNA ID entries from metadata database: %s" ,
171+ len (self ._metadata_db ), self .metadata_db_file )
172+ else :
173+ logger .warning ("Metadata database file %s exists but contains invalid data" ,
174+ self .metadata_db_file )
175+ self ._metadata_db = None
176+ else :
177+ logger .warning ("Metadata database file not found: %s" , self .metadata_db_file )
178+ logger .info ("To build the database, run: python -m graphgen.models.searcher.db.build_rna_metadata_db" )
179+ except Exception as e :
180+ logger .warning ("Failed to load metadata database from %s: %s" , self .metadata_db_file , e )
181+ self ._metadata_db = None
182+
145183 def get_by_rna_id (self , rna_id : str ) -> Optional [dict ]:
146184 """
147185 Get RNA information by RNAcentral ID.
186+ First checks pre-built metadata database if available, then falls back to API.
148187 :param rna_id: RNAcentral ID (e.g., URS0000000001).
149188 :return: A dictionary containing RNA information or None if not found.
150189 """
190+ # Check pre-built metadata database first
191+ if self ._metadata_db is not None :
192+ if rna_id in self ._metadata_db :
193+ result = self ._metadata_db [rna_id ]
194+ logger .debug ("Found RNA ID %s in metadata database" , rna_id )
195+ return result
196+ else :
197+ logger .debug ("RNA ID %s not found in metadata database, skipping API call" , rna_id )
198+ return None
199+
200+ # Fall back to API if metadata database not available
151201 try :
152202 url = f"{ self .base_url } /rna/{ rna_id } "
153203 url += "?flat=true"
154204
155- resp = requests .get (url , headers = self .headers , timeout = 30 )
205+ resp = requests .get (url , headers = self .headers , timeout = self . api_timeout )
156206 resp .raise_for_status ()
157207
158208 rna_data = resp .json ()
159209 xrefs_data = rna_data .get ("xrefs" , [])
160- return self ._rna_data_to_dict (rna_id , rna_data , xrefs_data )
210+ result = self ._rna_data_to_dict (rna_id , rna_data , xrefs_data )
211+ return result
212+ except requests .Timeout as e :
213+ logger .warning ("Timeout getting RNA ID %s (timeout=%ds): %s" , rna_id , self .api_timeout , e )
214+ return None
161215 except requests .RequestException as e :
162216 logger .error ("Network error getting RNA ID %s: %s" , rna_id , e )
163217 return None
@@ -179,7 +233,7 @@ def get_best_hit(self, keyword: str) -> Optional[dict]:
179233 try :
180234 url = f"{ self .base_url } /rna"
181235 params = {"search" : keyword , "format" : "json" }
182- resp = requests .get (url , params = params , headers = self .headers , timeout = 30 )
236+ resp = requests .get (url , params = params , headers = self .headers , timeout = self . api_timeout )
183237 resp .raise_for_status ()
184238
185239 data = resp .json ()
@@ -207,22 +261,54 @@ def get_best_hit(self, keyword: str) -> Optional[dict]:
207261 return None
208262
209263 def _local_blast (self , seq : str , threshold : float ) -> Optional [str ]:
210- """Perform local BLAST search using local BLAST database."""
264+ """
265+ Perform local BLAST search using local BLAST database.
266+ Optimized with multi-threading and faster output format.
267+ """
211268 try :
269+ # Use temporary file for query sequence
212270 with tempfile .NamedTemporaryFile (mode = "w+" , suffix = ".fa" , delete = False ) as tmp :
213271 tmp .write (f">query\n { seq } \n " )
214272 tmp_name = tmp .name
215273
274+ # Optimized BLAST command with:
275+ # - num_threads: Use multiple threads for faster search
276+ # - outfmt 6 sacc: Only return accession (minimal output)
277+ # - max_target_seqs 1: Only need the best hit
278+ # - evalue: Threshold for significance
216279 cmd = [
217280 "blastn" , "-db" , self .local_blast_db , "-query" , tmp_name ,
218- "-evalue" , str (threshold ), "-max_target_seqs" , "1" , "-outfmt" , "6 sacc"
281+ "-evalue" , str (threshold ),
282+ "-max_target_seqs" , "1" ,
283+ "-num_threads" , str (self .blast_num_threads ),
284+ "-outfmt" , "6 sacc" # Only accession, tab-separated
219285 ]
220- logger .debug ("Running local blastn for RNA: %s" , " " .join (cmd ))
221- out = subprocess .check_output (cmd , text = True ).strip ()
286+ logger .debug ("Running local blastn for RNA (threads=%d): %s" ,
287+ self .blast_num_threads , " " .join (cmd ))
288+
289+ # Run BLAST with timeout to avoid hanging
290+ try :
291+ out = subprocess .check_output (
292+ cmd ,
293+ text = True ,
294+ timeout = 300 , # 5 minute timeout for BLAST search
295+ stderr = subprocess .DEVNULL # Suppress BLAST warnings to reduce I/O
296+ ).strip ()
297+ except subprocess .TimeoutExpired :
298+ logger .warning ("BLAST search timed out after 5 minutes for sequence" )
299+ os .remove (tmp_name )
300+ return None
301+
222302 os .remove (tmp_name )
223303 return out .split ("\n " , maxsplit = 1 )[0 ] if out else None
224304 except Exception as exc :
225305 logger .error ("Local blastn failed: %s" , exc )
306+ # Clean up temp file if it still exists
307+ try :
308+ if 'tmp_name' in locals ():
309+ os .remove (tmp_name )
310+ except Exception :
311+ pass
226312 return None
227313
228314 def get_by_fasta (self , sequence : str , threshold : float = 0.01 ) -> Optional [dict ]:
@@ -254,7 +340,15 @@ def _extract_sequence(sequence: str) -> Optional[str]:
254340 accession = self ._local_blast (seq , threshold )
255341 if accession :
256342 logger .debug ("Local BLAST found accession: %s" , accession )
257- return self .get_by_rna_id (accession )
343+ detailed = self .get_by_rna_id (accession )
344+ if detailed :
345+ return detailed
346+ logger .info (
347+ "Local BLAST found accession %s but metadata not available in database. "
348+ "API fallback disabled when using local database." ,
349+ accession
350+ )
351+ return None
258352 logger .info (
259353 "Local BLAST found no match for sequence. "
260354 "API fallback disabled when using local database."
@@ -280,7 +374,12 @@ def _extract_sequence(sequence: str) -> Optional[str]:
280374
281375 rna_id = results [0 ].get ("rnacentral_id" )
282376 if rna_id :
283- return self .get_by_rna_id (rna_id )
377+ detailed = self .get_by_rna_id (rna_id )
378+ if detailed :
379+ return detailed
380+ # Fallback: use search result data if get_by_rna_id returns None
381+ logger .debug ("Using search result data for %s (get_by_rna_id returned None)" , rna_id )
382+ return self ._rna_data_to_dict (rna_id , results [0 ])
284383
285384 logger .error ("No RNAcentral ID found in search results." )
286385 return None
0 commit comments