Skip to content

Commit 2a715de

Browse files
committed
style: reduce return statements and branches in searcher methods
- Refactor search_by_sequence in ncbi_searcher.py to reduce return statements from 7 to 1 - Refactor search_by_sequence in rnacentral_searcher.py to reduce return statements from 8 to 1 and branches from 16 to 12 - Extract helper methods to improve code readability and maintainability - Fix pylint errors R0911 (too-many-return-statements) and R0912 (too-many-branches)
1 parent 9382660 commit 2a715de

File tree

2 files changed

+131
-125
lines changed

2 files changed

+131
-125
lines changed

graphgen/models/searcher/db/ncbi_searcher.py

Lines changed: 61 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,45 @@ def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
569569
logger.error("Local blastn failed: %s", exc)
570570
return None
571571

572+
def _extract_and_normalize_sequence(self, sequence: str) -> Optional[str]:
573+
"""Extract and normalize DNA sequence from input."""
574+
if sequence.startswith(">"):
575+
seq_lines = sequence.strip().split("\n")
576+
seq = "".join(seq_lines[1:])
577+
else:
578+
seq = sequence.strip().replace(" ", "").replace("\n", "")
579+
return seq if seq and re.fullmatch(r"[ATCGN\s]+", seq, re.I) else None
580+
581+
def _process_network_blast_result(self, blast_record, seq: str, threshold: float) -> Optional[dict]:
582+
"""Process network BLAST result and return dictionary or None."""
583+
if not blast_record.alignments:
584+
logger.info("No BLAST hits found for the given sequence.")
585+
return None
586+
587+
best_alignment = blast_record.alignments[0]
588+
best_hsp = best_alignment.hsps[0]
589+
if best_hsp.expect > threshold:
590+
logger.info("No BLAST hits below the threshold E-value.")
591+
return None
592+
593+
hit_id = best_alignment.hit_id
594+
accession_match = re.search(r"ref\|([^|]+)", hit_id)
595+
if accession_match:
596+
accession = accession_match.group(1).split(".")[0]
597+
return self.get_by_accession(accession)
598+
599+
# If unable to extract accession, return basic information
600+
return {
601+
"molecule_type": "DNA",
602+
"database": "NCBI",
603+
"id": hit_id,
604+
"title": best_alignment.title,
605+
"sequence_length": len(seq),
606+
"e_value": best_hsp.expect,
607+
"identity": best_hsp.identities / best_hsp.align_length if best_hsp.align_length > 0 else 0,
608+
"url": f"https://www.ncbi.nlm.nih.gov/nuccore/{hit_id}",
609+
}
610+
572611
def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
573612
"""
574613
Search NCBI with a DNA sequence using BLAST.
@@ -577,77 +616,40 @@ def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional
577616
:param threshold: E-value threshold for BLAST search.
578617
:return: A dictionary containing the best hit information or None if not found.
579618
"""
619+
result = None
580620
try:
581-
# Extract sequence (if in FASTA format)
582-
if sequence.startswith(">"):
583-
seq_lines = sequence.strip().split("\n")
584-
seq = "".join(seq_lines[1:])
585-
else:
586-
seq = sequence.strip().replace(" ", "").replace("\n", "")
587-
588-
# Validate sequence
589-
if not seq or not re.fullmatch(r"[ATCGN\s]+", seq, re.I):
590-
if not seq:
591-
logger.error("Empty DNA sequence provided.")
592-
else:
593-
logger.error("Invalid DNA sequence provided.")
621+
seq = self._extract_and_normalize_sequence(sequence)
622+
if not seq:
623+
logger.error("Empty or invalid DNA sequence provided.")
594624
return None
595625

596626
# Try local BLAST first if enabled
597-
accession = None
598627
if self.use_local_blast:
599628
accession = self._local_blast(seq, threshold)
600629
if accession:
601630
logger.debug("Local BLAST found accession: %s", accession)
602-
return self.get_by_accession(accession)
603-
604-
# Fall back to network BLAST
605-
logger.debug("Falling back to NCBIWWW.qblast.")
606-
logger.debug("Performing BLAST search for DNA sequence...")
607-
time.sleep(0.35)
608-
609-
result_handle = NCBIWWW.qblast(
610-
program="blastn",
611-
database="nr",
612-
sequence=seq,
613-
hitlist_size=1,
614-
expect=threshold,
615-
)
616-
blast_record = NCBIXML.read(result_handle)
617-
618-
if not blast_record.alignments:
619-
logger.info("No BLAST hits found for the given sequence.")
620-
return None
621-
622-
best_alignment = blast_record.alignments[0]
623-
best_hsp = best_alignment.hsps[0]
624-
if best_hsp.expect > threshold:
625-
logger.info("No BLAST hits below the threshold E-value.")
626-
return None
627-
hit_id = best_alignment.hit_id
628-
629-
# Extract accession number
630-
# Format may be: gi|123456|ref|NM_000546.5|
631-
accession_match = re.search(r"ref\|([^|]+)", hit_id)
632-
if accession_match:
633-
accession = accession_match.group(1).split(".")[0]
634-
return self.get_by_accession(accession)
635-
# If unable to extract accession, return basic information
636-
return {
637-
"molecule_type": "DNA",
638-
"database": "NCBI",
639-
"id": hit_id,
640-
"title": best_alignment.title,
641-
"sequence_length": len(seq),
642-
"e_value": best_hsp.expect,
643-
"identity": best_hsp.identities / best_hsp.align_length if best_hsp.align_length > 0 else 0,
644-
"url": f"https://www.ncbi.nlm.nih.gov/nuccore/{hit_id}",
645-
}
631+
result = self.get_by_accession(accession)
632+
633+
# Fall back to network BLAST if local BLAST didn't find result
634+
if not result:
635+
logger.debug("Falling back to NCBIWWW.qblast.")
636+
logger.debug("Performing BLAST search for DNA sequence...")
637+
time.sleep(0.35)
638+
639+
result_handle = NCBIWWW.qblast(
640+
program="blastn",
641+
database="nr",
642+
sequence=seq,
643+
hitlist_size=1,
644+
expect=threshold,
645+
)
646+
blast_record = NCBIXML.read(result_handle)
647+
result = self._process_network_blast_result(blast_record, seq, threshold)
646648
except RequestException:
647649
raise
648650
except Exception as e: # pylint: disable=broad-except
649651
logger.error("BLAST search failed: %s", e)
650-
return None
652+
return result
651653

652654
@retry(
653655
stop=stop_after_attempt(5),

graphgen/models/searcher/db/rnacentral_searcher.py

Lines changed: 70 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,49 @@ def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
339339
logger.error("Local blastn failed: %s", exc)
340340
return None
341341

342+
@staticmethod
343+
def _extract_and_normalize_sequence(sequence: str) -> Optional[str]:
344+
"""Extract and normalize RNA sequence from input."""
345+
if sequence.startswith(">"):
346+
seq_lines = sequence.strip().split("\n")
347+
seq = "".join(seq_lines[1:])
348+
else:
349+
seq = sequence.strip().replace(" ", "").replace("\n", "")
350+
return seq if seq and re.fullmatch(r"[AUCGN\s]+", seq, re.I) else None
351+
352+
def _find_best_match_from_results(self, results: List[Dict], seq: str) -> Optional[Dict]:
353+
"""Find best match from search results, preferring exact match."""
354+
exact_match = None
355+
for result_item in results:
356+
result_seq = result_item.get("sequence", "")
357+
if result_seq == seq:
358+
exact_match = result_item
359+
break
360+
return exact_match if exact_match else (results[0] if results else None)
361+
362+
async def _process_api_search_results(
363+
self, results: List[Dict], seq: str
364+
) -> Optional[dict]:
365+
"""Process API search results and return dictionary or None."""
366+
if not results:
367+
logger.info("No results found for sequence.")
368+
return None
369+
370+
target_result = self._find_best_match_from_results(results, seq)
371+
if not target_result:
372+
return None
373+
374+
rna_id = target_result.get("rnacentral_id")
375+
if not rna_id:
376+
return None
377+
378+
# Try to get complete information
379+
result = await self.get_by_rna_id(rna_id)
380+
if not result:
381+
logger.debug("get_by_rna_id() failed for %s, using search result data", rna_id)
382+
result = self._rna_data_to_dict(rna_id, target_result)
383+
return result
384+
342385
async def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
343386
"""
344387
Search RNAcentral with an RNA sequence.
@@ -348,86 +391,47 @@ async def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Op
348391
:param threshold: E-value threshold for BLAST search.
349392
:return: A dictionary containing complete RNA information or None if not found.
350393
"""
394+
result = None
351395
try:
352-
# Extract sequence (if in FASTA format)
353-
if sequence.startswith(">"):
354-
seq_lines = sequence.strip().split("\n")
355-
seq = "".join(seq_lines[1:])
356-
else:
357-
seq = sequence.strip().replace(" ", "").replace("\n", "")
358-
359-
# Validate if it's an RNA sequence (contains U instead of T)
360-
if not re.fullmatch(r"[AUCGN\s]+", seq, re.I):
361-
logger.error("Invalid RNA sequence provided.")
362-
return None
363-
396+
seq = self._extract_and_normalize_sequence(sequence)
364397
if not seq:
365-
logger.error("Empty RNA sequence provided.")
398+
logger.error("Empty or invalid RNA sequence provided.")
366399
return None
367400

368401
# Try local BLAST first if enabled
369402
if self.use_local_blast:
370403
accession = self._local_blast(seq, threshold)
371404
if accession:
372405
logger.debug("Local BLAST found accession: %s", accession)
373-
# Try to get RNA ID from accession (may need conversion)
374-
# For now, try using accession as RNA ID or search by it
375406
result = await self.get_by_rna_id(accession)
376-
if result:
377-
return result
378-
# If not found by ID, try keyword search
379-
result = await self.get_best_hit(accession)
380-
if result:
381-
return result
382-
383-
# Fall back to RNAcentral API
384-
logger.debug("Falling back to RNAcentral API.")
385-
async with aiohttp.ClientSession() as session:
386-
search_url = f"{self.base_url}/rna"
387-
params = {"sequence": seq, "format": "json"}
388-
async with session.get(
389-
search_url,
390-
params=params,
391-
headers=self.headers,
392-
timeout=aiohttp.ClientTimeout(total=60), # Sequence search may take longer
393-
) as resp:
394-
if resp.status == 200:
395-
search_results = await resp.json()
396-
results = search_results.get("results", [])
397-
if results:
398-
# Step 1: Find best match (prefer exact match)
399-
exact_match = None
400-
for result in results:
401-
result_seq = result.get("sequence", "")
402-
if result_seq == seq:
403-
exact_match = result
404-
break
405-
406-
# Use exact match if found, otherwise use first result
407-
target_result = exact_match if exact_match else results[0]
408-
rna_id = target_result.get("rnacentral_id")
409-
410-
if rna_id:
411-
# Step 2: Unified call to get_by_rna_id() for complete information
412-
result = await self.get_by_rna_id(rna_id)
413-
414-
# Step 3: If get_by_rna_id() failed, use search result data as fallback
415-
if not result:
416-
logger.debug("get_by_rna_id() failed for %s, using search result data", rna_id)
417-
result = self._rna_data_to_dict(rna_id, target_result)
418-
419-
return result
420-
logger.info("No results found for sequence.")
421-
return None
422-
error_text = await resp.text()
423-
logger.error("HTTP %d error for sequence search: %s", resp.status, error_text[:200])
424-
raise Exception(f"HTTP {resp.status}: {error_text}")
407+
if not result:
408+
result = await self.get_best_hit(accession)
409+
410+
# Fall back to RNAcentral API if local BLAST didn't find result
411+
if not result:
412+
logger.debug("Falling back to RNAcentral API.")
413+
async with aiohttp.ClientSession() as session:
414+
search_url = f"{self.base_url}/rna"
415+
params = {"sequence": seq, "format": "json"}
416+
async with session.get(
417+
search_url,
418+
params=params,
419+
headers=self.headers,
420+
timeout=aiohttp.ClientTimeout(total=60), # Sequence search may take longer
421+
) as resp:
422+
if resp.status == 200:
423+
search_results = await resp.json()
424+
results = search_results.get("results", [])
425+
result = await self._process_api_search_results(results, seq)
426+
else:
427+
error_text = await resp.text()
428+
logger.error("HTTP %d error for sequence search: %s", resp.status, error_text[:200])
429+
raise Exception(f"HTTP {resp.status}: {error_text}")
425430
except aiohttp.ClientError as e:
426431
logger.error("Network error searching for sequence: %s", e)
427-
return None
428432
except Exception as e: # pylint: disable=broad-except
429433
logger.error("Sequence search failed: %s", e)
430-
return None
434+
return result
431435

432436
@retry(
433437
stop=stop_after_attempt(3),

0 commit comments

Comments
 (0)