Skip to content

Commit ccaa726

Browse files
fix: fix async search
1 parent b5b5450 commit ccaa726

File tree

4 files changed

+29
-9
lines changed

4 files changed

+29
-9
lines changed
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
pipeline:
22
- name: read
33
params:
4-
input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
4+
input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55

66
- name: search
77
params:
88
data_sources: [uniprot] # data sources for the searcher; supported: wikipedia, google, uniprot
9+
uniprot_params:
10+
use_local_blast: true # whether to use local blast for uniprot search
11+
local_blast_db: /your_path/uniprot_sprot

graphgen/graphgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
self.working_dir, namespace="graph"
6969
)
7070
self.search_storage: JsonKVStorage = JsonKVStorage(
71-
self.working_dir, namespace="searcher"
71+
self.working_dir, namespace="search"
7272
)
7373
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
7474
self.working_dir, namespace="rephrase"
@@ -190,7 +190,7 @@ async def search(self, search_config: Dict):
190190
return
191191
search_results = await search_all(
192192
seed_data=seeds,
193-
**search_config,
193+
search_config=search_config,
194194
)
195195

196196
_add_search_keys = await self.search_storage.filter_keys(

graphgen/models/searcher/db/uniprot_searcher.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import asyncio
12
import os
23
import re
34
import subprocess
45
import tempfile
6+
from concurrent.futures import ThreadPoolExecutor
7+
from functools import lru_cache
58
from io import StringIO
69
from typing import Dict, Optional
710

@@ -19,6 +22,11 @@
1922
from graphgen.utils import logger
2023

2124

25+
@lru_cache(maxsize=None)
26+
def _get_pool():
27+
return ThreadPoolExecutor(max_workers=10)
28+
29+
2230
class UniProtSearch(BaseSearcher):
2331
"""
2432
UniProt search client for querying UniProt.
@@ -215,20 +223,26 @@ async def search(
215223
query = query.strip()
216224

217225
logger.debug("UniProt searcher query: %s", query)
226+
227+
loop = asyncio.get_running_loop()
228+
218229
# check if fasta sequence
219230
if query.startswith(">") or re.fullmatch(
220231
r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
221232
):
222-
result = self.get_by_fasta(query, threshold)
233+
coro = loop.run_in_executor(
234+
_get_pool(), self.get_by_fasta, query, threshold
235+
)
223236

224237
# check if accession number
225238
elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
226-
result = self.get_by_accession(query)
239+
coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query)
227240

228241
else:
229242
# otherwise treat as keyword
230-
result = self.get_best_hit(query)
243+
coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query)
231244

245+
result = await coro
232246
if result:
233247
result["_search_query"] = query
234248
return result

graphgen/operators/search/search_all.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,25 @@
1414

1515
async def search_all(
1616
seed_data: dict,
17-
data_sources: list[str],
17+
search_config: dict,
1818
) -> dict:
1919
"""
2020
Perform searches across multiple search types and aggregate the results.
2121
:param seed_data: A dictionary containing seed data with entity names.
22-
:param data_sources: A list of search types to perform (e.g., "wikipedia", "google", "bing", "uniprot").
22+
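:param search_config: A dictionary of search settings; "data_sources" lists the sources to query, and optional per-source entries such as "uniprot_params" are passed to the matching search client.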
2323
:return: A dictionary with the aggregated search results.
2424
"""
2525

2626
results = {}
27+
data_sources = search_config.get("data_sources", [])
2728

2829
for data_source in data_sources:
2930
if data_source == "uniprot":
3031
from graphgen.models import UniProtSearch
3132

32-
uniprot_search_client = UniProtSearch()
33+
uniprot_search_client = UniProtSearch(
34+
**search_config.get("uniprot_params", {})
35+
)
3336

3437
data = list(seed_data.values())
3538
data = [d["content"] for d in data if "content" in d]

0 commit comments

Comments
 (0)