Skip to content

Commit ef270b8

Browse files
committed
refactor: unify searcher interfaces and improve error handling
- Extract utility functions (_gene_record_to_dict, _accession_to_dict, _rna_data_to_dict) - Unify method naming: search_by_keyword -> get_best_hit - Add threshold parameter to NCBI and RNAcentral searchers for interface consistency - Improve error handling with network error detection and fallback strategies - Fix RNAcentral sequence search to prioritize exact matches - Add search_rna_demo.jsonl example file
1 parent ea2214c commit ef270b8

3 files changed

Lines changed: 150 additions & 60 deletions

File tree

graphgen/models/searcher/db/ncbi_searcher.py

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,38 @@ def __init__(self, email: str = "test@example.com", tool: str = "GraphGen"):
4141
Entrez.tool = tool
4242
Entrez.timeout = 60 # 60 seconds timeout
4343

44+
@staticmethod
45+
def _gene_record_to_dict(gene_record, gene_id: str) -> dict:
46+
"""
47+
Convert an Entrez gene record to a dictionary.
48+
:param gene_record: The Entrez gene record (list from Entrez.read).
49+
:param gene_id: The gene ID.
50+
:return: A dictionary containing gene information.
51+
"""
52+
if not gene_record:
53+
raise ValueError("Empty gene record")
54+
55+
gene_data = gene_record[0]
56+
gene_ref = gene_data.get("Entrezgene_gene", {}).get("Gene-ref", {})
57+
58+
organism = (
59+
gene_data.get("Entrezgene_source", {})
60+
.get("BioSource", {})
61+
.get("BioSource_org", {})
62+
.get("Org-ref", {})
63+
.get("Org-ref_taxname", "N/A")
64+
)
65+
66+
return {
67+
"molecule_type": "DNA",
68+
"database": "NCBI",
69+
"id": gene_id,
70+
"gene_name": gene_ref.get("Gene-ref_locus", "N/A"),
71+
"gene_description": gene_ref.get("Gene-ref_desc", "N/A"),
72+
"organism": organism,
73+
"url": f"https://www.ncbi.nlm.nih.gov/gene/{gene_id}",
74+
}
75+
4476
def get_by_gene_id(self, gene_id: str) -> Optional[dict]:
4577
"""
4678
Get gene information by Gene ID.
@@ -54,26 +86,7 @@ def get_by_gene_id(self, gene_id: str) -> Optional[dict]:
5486
gene_record = Entrez.read(handle)
5587
if not gene_record:
5688
return None
57-
58-
gene_data = gene_record[0]
59-
gene_ref = gene_data.get("Entrezgene_gene", {}).get("Gene-ref", {})
60-
61-
organism = (
62-
gene_data.get("Entrezgene_source", {})
63-
.get("BioSource", {})
64-
.get("BioSource_org", {})
65-
.get("Org-ref", {})
66-
.get("Org-ref_taxname", "N/A")
67-
)
68-
return {
69-
"molecule_type": "DNA",
70-
"database": "NCBI",
71-
"id": gene_id,
72-
"gene_name": gene_ref.get("Gene-ref_locus", "N/A"),
73-
"gene_description": gene_ref.get("Gene-ref_desc", "N/A"),
74-
"organism": organism,
75-
"url": f"https://www.ncbi.nlm.nih.gov/gene/{gene_id}",
76-
}
89+
return self._gene_record_to_dict(gene_record, gene_id)
7790
finally:
7891
handle.close()
7992
except RequestException:
@@ -82,14 +95,36 @@ def get_by_gene_id(self, gene_id: str) -> Optional[dict]:
8295
logger.error("Gene ID %s not found: %s", gene_id, exc)
8396
return None
8497

98+
@staticmethod
99+
def _accession_to_dict(accession: str, sequence: str, header: str, title: str, organism: str) -> dict:
100+
"""
101+
Convert accession information to a dictionary.
102+
:param accession: NCBI accession number.
103+
:param sequence: DNA sequence.
104+
:param header: FASTA header.
105+
:param title: Sequence title.
106+
:param organism: Organism name.
107+
:return: A dictionary containing sequence information.
108+
"""
109+
return {
110+
"molecule_type": "DNA",
111+
"database": "NCBI",
112+
"id": accession,
113+
"title": title,
114+
"organism": organism,
115+
"sequence": sequence,
116+
"sequence_length": len(sequence),
117+
"url": f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}",
118+
}
119+
85120
def get_by_accession(self, accession: str) -> Optional[dict]:
86121
"""
87122
Get sequence information by accession number.
88123
:param accession: NCBI accession number (e.g., NM_000546).
89124
:return: A dictionary containing sequence information or None if not found.
90125
"""
91126
try:
92-
time.sleep(0.35) # 遵守速率限制
127+
time.sleep(0.35) # Comply with rate limit
93128
handle = Entrez.efetch(
94129
db="nuccore",
95130
id=accession,
@@ -120,16 +155,7 @@ def get_by_accession(self, accession: str) -> Optional[dict]:
120155
finally:
121156
summary_handle.close()
122157

123-
return {
124-
"molecule_type": "DNA",
125-
"database": "NCBI",
126-
"id": accession,
127-
"title": title,
128-
"organism": organism,
129-
"sequence": sequence,
130-
"sequence_length": len(sequence),
131-
"url": f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}",
132-
}
158+
return self._accession_to_dict(accession, sequence, header, title, organism)
133159
finally:
134160
handle.close()
135161
except RequestException:
@@ -138,7 +164,7 @@ def get_by_accession(self, accession: str) -> Optional[dict]:
138164
logger.error("Accession %s not found: %s", accession, exc)
139165
return None
140166

141-
def search_by_keyword(self, keyword: str) -> Optional[dict]:
167+
def get_best_hit(self, keyword: str) -> Optional[dict]:
142168
"""
143169
Search NCBI Gene database with a keyword and return the best hit.
144170
:param keyword: The search keyword (e.g., gene name).
@@ -148,7 +174,7 @@ def search_by_keyword(self, keyword: str) -> Optional[dict]:
148174
return None
149175

150176
try:
151-
time.sleep(0.35) # 遵守速率限制
177+
time.sleep(0.35) # Comply with rate limit
152178
# Search gene database
153179
search_handle = Entrez.esearch(
154180
db="gene",
@@ -181,11 +207,12 @@ def search_by_keyword(self, keyword: str) -> Optional[dict]:
181207
logger.error("Keyword %s not found: %s", keyword, e)
182208
return None
183209

184-
def search_by_sequence(self, sequence: str) -> Optional[dict]:
210+
def search_by_sequence(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
185211
"""
186212
Search NCBI with a DNA sequence using BLAST.
187213
Note: This is a simplified version. For production, consider using local BLAST.
188214
:param sequence: DNA sequence (FASTA format or raw sequence).
215+
:param threshold: E-value threshold for BLAST search.
189216
:return: A dictionary containing the best hit information or None if not found.
190217
"""
191218
try:
@@ -215,7 +242,7 @@ def search_by_sequence(self, sequence: str) -> Optional[dict]:
215242
database="nr",
216243
sequence=seq,
217244
hitlist_size=1,
218-
expect=0.001,
245+
expect=threshold,
219246
)
220247
blast_record = NCBIXML.read(result_handle)
221248

@@ -225,6 +252,9 @@ def search_by_sequence(self, sequence: str) -> Optional[dict]:
225252

226253
best_alignment = blast_record.alignments[0]
227254
best_hsp = best_alignment.hsps[0]
255+
if best_hsp.expect > threshold:
256+
logger.info("No BLAST hits below the threshold E-value.")
257+
return None
228258
hit_id = best_alignment.hit_id
229259

230260
# Extract accession number
@@ -257,11 +287,12 @@ def search_by_sequence(self, sequence: str) -> Optional[dict]:
257287
reraise=True,
258288
)
259289
async def search(
260-
self, query: str, **kwargs
290+
self, query: str, threshold: float = 0.01, **kwargs
261291
) -> Optional[Dict]:
262292
"""
263293
Search NCBI with either a gene ID, accession number, keyword, or DNA sequence.
264294
:param query: The search query (gene ID, accession, keyword, or DNA sequence).
295+
:param threshold: E-value threshold for BLAST search.
265296
:param kwargs: Additional keyword arguments (not used currently).
266297
:return: A dictionary containing the search results or None if not found.
267298
"""
@@ -278,7 +309,7 @@ async def search(
278309
# check if DNA sequence (ATCG characters)
279310
if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
280311
result = await loop.run_in_executor(
281-
_get_pool(), self.search_by_sequence, query
312+
_get_pool(), self.search_by_sequence, query, threshold
282313
)
283314
# check if gene ID (numeric)
284315
elif re.fullmatch(r"^\d+$", query):
@@ -293,7 +324,7 @@ async def search(
293324
else:
294325
# otherwise treat as keyword
295326
result = await loop.run_in_executor(
296-
_get_pool(), self.search_by_keyword, query
327+
_get_pool(), self.get_best_hit, query
297328
)
298329

299330
if result:

graphgen/models/searcher/db/rnacentral_searcher.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,27 @@ def __init__(self):
3636
self.base_url = "https://rnacentral.org/api/v1"
3737
self.headers = {"Accept": "application/json"}
3838

39+
@staticmethod
40+
def _rna_data_to_dict(rna_id: str, rna_data: dict) -> dict:
41+
"""
42+
Convert RNAcentral API response to a dictionary.
43+
:param rna_id: RNAcentral ID.
44+
:param rna_data: API response data (dict or dict-like from search results).
45+
:return: A dictionary containing RNA information.
46+
"""
47+
sequence = rna_data.get("sequence", "")
48+
return {
49+
"molecule_type": "RNA",
50+
"database": "RNAcentral",
51+
"id": rna_id,
52+
"rnacentral_id": rna_data.get("rnacentral_id", rna_id),
53+
"sequence": sequence,
54+
"sequence_length": rna_data.get("length", len(sequence)),
55+
"rna_type": rna_data.get("rna_type", "N/A"),
56+
"description": rna_data.get("description", "N/A"),
57+
"url": f"https://rnacentral.org/rna/{rna_id}",
58+
}
59+
3960
async def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
4061
"""
4162
Get RNA information by RNAcentral ID.
@@ -50,26 +71,19 @@ async def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
5071
) as resp:
5172
if resp.status == 200:
5273
rna_data = await resp.json()
53-
return {
54-
"molecule_type": "RNA",
55-
"database": "RNAcentral",
56-
"id": rna_id,
57-
"rnacentral_id": rna_data.get("rnacentral_id", "N/A"),
58-
"sequence": rna_data.get("sequence", ""),
59-
"sequence_length": len(rna_data.get("sequence", "")),
60-
"rna_type": rna_data.get("rna_type", "N/A"),
61-
"description": rna_data.get("description", "N/A"),
62-
"url": f"https://rnacentral.org/rna/{rna_id}",
63-
}
74+
return self._rna_data_to_dict(rna_id, rna_data)
6475
if resp.status == 404:
6576
logger.error("RNA ID %s not found", rna_id)
6677
return None
6778
raise Exception(f"HTTP {resp.status}: {await resp.text()}")
79+
except aiohttp.ClientError as e:
80+
logger.error("Network error getting RNA ID %s: %s", rna_id, e)
81+
return None
6882
except Exception as exc: # pylint: disable=broad-except
6983
logger.error("RNA ID %s not found: %s", rna_id, exc)
7084
return None
7185

72-
async def search_by_keyword(self, keyword: str) -> Optional[dict]:
86+
async def get_best_hit(self, keyword: str) -> Optional[dict]:
7387
"""
7488
Search RNAcentral with a keyword and return the best hit.
7589
:param keyword: The search keyword (e.g., miRNA name, RNA name).
@@ -90,13 +104,26 @@ async def search_by_keyword(self, keyword: str) -> Optional[dict]:
90104
) as resp:
91105
if resp.status == 200:
92106
search_results = await resp.json()
93-
if search_results.get("results"):
94-
rna_id = search_results["results"][0].get("rnacentral_id")
107+
results = search_results.get("results", [])
108+
if results:
109+
# Use the first result directly (search API already returns enough info)
110+
first_result = results[0]
111+
rna_id = first_result.get("rnacentral_id")
95112
if rna_id:
96-
return await self.get_by_rna_id(rna_id)
113+
# Try to get detailed info, but fall back to search result if it fails
114+
detailed_info = await self.get_by_rna_id(rna_id)
115+
if detailed_info:
116+
return detailed_info
117+
# Fall back to using search result data
118+
return self._rna_data_to_dict(rna_id, first_result)
97119
logger.info("No results found for keyword: %s", keyword)
98120
return None
99-
raise Exception(f"HTTP {resp.status}: {await resp.text()}")
121+
error_text = await resp.text()
122+
logger.error("HTTP %d error for keyword %s: %s", resp.status, keyword, error_text[:200])
123+
raise Exception(f"HTTP {resp.status}: {error_text}")
124+
except aiohttp.ClientError as e:
125+
logger.error("Network error searching for keyword %s: %s", keyword, e)
126+
return None
100127
except Exception as e: # pylint: disable=broad-except
101128
logger.error("Keyword %s not found: %s", keyword, e)
102129
return None
@@ -136,13 +163,39 @@ async def search_by_sequence(self, sequence: str) -> Optional[dict]:
136163
) as resp:
137164
if resp.status == 200:
138165
search_results = await resp.json()
139-
if search_results.get("results"):
140-
rna_id = search_results["results"][0].get("rnacentral_id")
166+
results = search_results.get("results", [])
167+
if results:
168+
# First, try to find an exact sequence match
169+
exact_match = None
170+
for result in results:
171+
result_seq = result.get("sequence", "")
172+
if result_seq == seq:
173+
exact_match = result
174+
break
175+
176+
# Use exact match if found, otherwise use first result
177+
target_result = exact_match if exact_match else results[0]
178+
rna_id = target_result.get("rnacentral_id")
179+
141180
if rna_id:
142-
return await self.get_by_rna_id(rna_id)
181+
# Try to get detailed info, but fall back to search result if it fails
182+
try:
183+
detailed_info = await self.get_by_rna_id(rna_id)
184+
if detailed_info:
185+
return detailed_info
186+
except Exception as e:
187+
logger.debug("Failed to get detailed info for %s: %s, using search result", rna_id, e)
188+
189+
# Fall back to using search result data
190+
return self._rna_data_to_dict(rna_id, target_result)
143191
logger.info("No results found for sequence.")
144192
return None
145-
raise Exception(f"HTTP {resp.status}: {await resp.text()}")
193+
error_text = await resp.text()
194+
logger.error("HTTP %d error for sequence search: %s", resp.status, error_text[:200])
195+
raise Exception(f"HTTP {resp.status}: {error_text}")
196+
except aiohttp.ClientError as e:
197+
logger.error("Network error searching for sequence: %s", e)
198+
return None
146199
except Exception as e: # pylint: disable=broad-except
147200
logger.error("Sequence search failed: %s", e)
148201
return None
@@ -154,11 +207,13 @@ async def search_by_sequence(self, sequence: str) -> Optional[dict]:
154207
reraise=True,
155208
)
156209
async def search(
157-
self, query: str, **kwargs
210+
self, query: str, threshold: float = 0.7, **kwargs
158211
) -> Optional[Dict]:
159212
"""
160213
Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence.
161214
:param query: The search query (RNAcentral ID, keyword, or RNA sequence).
215+
:param threshold: E-value threshold for sequence search.
216+
Note: RNAcentral API uses its own similarity matching, this parameter is for interface consistency.
162217
:param kwargs: Additional keyword arguments (not used currently).
163218
:return: A dictionary containing the search results or None if not found.
164219
"""
@@ -180,7 +235,7 @@ async def search(
180235
result = await self.get_by_rna_id(query)
181236
else:
182237
# otherwise treat as keyword
183-
result = await self.search_by_keyword(query)
238+
result = await self.get_best_hit(query)
184239

185240
if result:
186241
result["_search_query"] = query
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{"type": "text", "content": "hsa-let-7a-1"}
2+
{"type": "text", "content": "URS0000123456"}
3+
{"type": "text", "content": "URS0000000001"}
4+
{"type": "text", "content": "CUCCUUUGACGUUAGCGGCGGACGGGUUAGUAACACGUGGGUAACCUACCUAUAAGACUGGGAUAACUUCGGGAAACCGGAGCUAAUACCGGAUAAUAUUUCGAACCGCAUGGUUCGAUAGUGAAAGAUGGUUUUGCUAUCACUUAUAGAUGGACCCGCGCCGUAUUAGCUAGUUGGUAAGGUAACGGCUUACCAAGGCGACGAUACGUAGCCGACCUGAGAGGGUGAUCGGCCACACUGGAACUGAGACACGGUCCAGACUCCUACGGGAGGCAGCAGGGG"}

0 commit comments

Comments
 (0)