Commit af49ba2

add: enable mid-auto save in searcher
1 parent 8143fff commit af49ba2

3 files changed: +120 −5 lines changed

graphgen/graphgen.py

Lines changed: 37 additions & 4 deletions
@@ -1,3 +1,4 @@
+import hashlib
 import os
 import time
 from typing import Dict
@@ -173,20 +174,52 @@ async def search(self, search_config: Dict):
         if len(seeds) == 0:
             logger.warning("All documents are already been searched")
             return
+
+        # Get save_interval from config (default: 1000, 0 to disable)
+        save_interval = search_config.get("save_interval", 1000)
+
         search_results = await search_all(
             seed_data=seeds,
             search_config=search_config,
+            search_storage=self.search_storage if save_interval > 0 else None,
+            save_interval=save_interval,
         )
 
-        _add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
+        # Convert search_results from {data_source: [results]} to {key: result}
+        # This maintains backward compatibility
+        flattened_results = {}
+        for data_source, result_list in search_results.items():
+            if not isinstance(result_list, list):
+                continue
+            for result in result_list:
+                if result is None:
+                    continue
+                # Use _search_query as key if available, otherwise generate a key
+                if isinstance(result, dict) and "_search_query" in result:
+                    query = result["_search_query"]
+                    key = f"{data_source}:{query}"
+                else:
+                    # Generate a unique key
+                    result_str = str(result)
+                    key_hash = hashlib.md5(result_str.encode()).hexdigest()[:8]
+                    key = f"{data_source}:{key_hash}"
+                flattened_results[key] = result
+
+        _add_search_keys = self.search_storage.filter_keys(list(flattened_results.keys()))
         search_results = {
-            k: v for k, v in search_results.items() if k in _add_search_keys
+            k: v for k, v in flattened_results.items() if k in _add_search_keys
        }
         if len(search_results) == 0:
             logger.warning("All search results are already in the storage")
             return
-        self.search_storage.upsert(search_results)
-        self.search_storage.index_done_callback()
+
+        # Only save if not using periodic saving (to avoid duplicate saves)
+        if save_interval == 0:
+            self.search_storage.upsert(search_results)
+            self.search_storage.index_done_callback()
+        else:
+            # Results were already saved periodically, just update index
+            self.search_storage.index_done_callback()
 
     @async_to_sync_method
     async def quiz_and_judge(self, quiz_and_judge_config: Dict):
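
The net effect of this hunk: periodic saving is controlled by a single save_interval key in the search config, and results are flattened to {data_source}:{query} keys so the existing storage dedup keeps working. A minimal caller-side sketch, assuming only what the diff shows (the graph_gen instance name is an assumption based on the method signature above; only save_interval is introduced by this commit):

# Hypothetical usage sketch. Only "save_interval" is defined by this commit;
# the instance name is an assumption from the method signature above.
search_config = {
    "save_interval": 500,  # checkpoint to search_storage every 500 results
    # "save_interval": 0 would disable checkpointing (single upsert at the end)
}
graph_gen.search(search_config)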

graphgen/operators/search/search_all.py

Lines changed: 46 additions & 1 deletion
@@ -15,12 +15,16 @@
 async def search_all(
     seed_data: dict,
     search_config: dict,
+    search_storage=None,
+    save_interval: int = 1000,
 ) -> dict:
     """
     Perform searches across multiple search types and aggregate the results.
     :param seed_data: A dictionary containing seed data with entity names.
     :param search_config: A dictionary specifying which data sources to use for searching.
-    :return: A dictionary with
+    :param search_storage: Optional storage instance for periodic saving of results.
+    :param save_interval: Number of search results to accumulate before saving (default: 1000, 0 to disable).
+    :return: A dictionary with search results
     """
 
     results = {}
@@ -31,6 +35,41 @@ async def search_all(
         data = [d["content"] for d in data if "content" in d]
         data = list(set(data))  # Remove duplicates
 
+        # Prepare save callback for this data source
+        def make_save_callback(source_name):
+            def save_callback(intermediate_results, completed_count):
+                """Save intermediate search results."""
+                if search_storage is None:
+                    return
+
+                # Convert results list to dict format
+                # Results are tuples of (query, result_dict) or just result_dict
+                batch_results = {}
+                for result in intermediate_results:
+                    if result is None:
+                        continue
+                    # Check if result is a dict with _search_query key
+                    if isinstance(result, dict) and "_search_query" in result:
+                        query = result["_search_query"]
+                        # Create a key for the result (using query as key)
+                        key = f"{source_name}:{query}"
+                        batch_results[key] = result
+                    elif isinstance(result, dict):
+                        # If no _search_query, use a generated key
+                        key = f"{source_name}:{completed_count}"
+                        batch_results[key] = result
+
+                if batch_results:
+                    # Filter out already existing keys
+                    new_keys = search_storage.filter_keys(list(batch_results.keys()))
+                    new_results = {k: v for k, v in batch_results.items() if k in new_keys}
+                    if new_results:
+                        search_storage.upsert(new_results)
+                        search_storage.index_done_callback()
+                        logger.debug("Saved %d intermediate results for %s", len(new_results), source_name)
+
+            return save_callback
+
         if data_source == "uniprot":
             from graphgen.models import UniProtSearch
 
@@ -43,6 +82,8 @@ async def search_all(
                 data,
                 desc="Searching UniProt database",
                 unit="keyword",
+                save_interval=save_interval if save_interval > 0 else 0,
+                save_callback=make_save_callback("uniprot") if search_storage and save_interval > 0 else None,
             )
             results[data_source] = uniprot_results
 
@@ -58,6 +99,8 @@ async def search_all(
                 data,
                 desc="Searching NCBI database",
                 unit="keyword",
+                save_interval=save_interval if save_interval > 0 else 0,
+                save_callback=make_save_callback("ncbi") if search_storage and save_interval > 0 else None,
             )
             results[data_source] = ncbi_results
 
@@ -73,6 +116,8 @@ async def search_all(
                 data,
                 desc="Searching RNAcentral database",
                 unit="keyword",
+                save_interval=save_interval if save_interval > 0 else 0,
+                save_callback=make_save_callback("rnacentral") if search_storage and save_interval > 0 else None,
             )
             results[data_source] = rnacentral_results
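
The nested save_callback above touches the storage only through filter_keys, upsert, and index_done_callback, and re-filters keys on every batch so repeated saves stay idempotent. The real search_storage class is not part of this diff; a toy in-memory stand-in sketching the assumed contract:

# Toy stand-in for the storage contract assumed by make_save_callback.
# The actual search_storage implementation lives elsewhere in the repo.
class InMemoryKVStorage:
    def __init__(self):
        self._data = {}

    def filter_keys(self, keys):
        # Return only the keys not yet stored (this is what makes
        # periodic re-saving idempotent).
        return {k for k in keys if k not in self._data}

    def upsert(self, data):
        self._data.update(data)

    def index_done_callback(self):
        pass  # a real backend would flush its index to disk here


storage = InMemoryKVStorage()
storage.upsert({"uniprot:P53": {"_search_query": "P53"}})
print(storage.filter_keys(["uniprot:P53", "ncbi:BRCA1"]))  # {'ncbi:BRCA1'}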

graphgen/utils/run_concurrent.py

Lines changed: 37 additions & 0 deletions
@@ -17,11 +17,26 @@ async def run_concurrent(
     desc: str = "processing",
     unit: str = "item",
     progress_bar: Optional[gr.Progress] = None,
+    save_interval: int = 0,
+    save_callback: Optional[Callable[[List[R], int], None]] = None,
 ) -> List[R]:
+    """
+    Run coroutines concurrently with optional periodic saving.
+
+    :param coro_fn: Coroutine function to run for each item
+    :param items: List of items to process
+    :param desc: Description for progress bar
+    :param unit: Unit name for progress bar
+    :param progress_bar: Optional Gradio progress bar
+    :param save_interval: Number of completed tasks before calling save_callback (0 to disable)
+    :param save_callback: Callback function to save intermediate results (results, completed_count)
+    :return: List of results
+    """
     tasks = [asyncio.create_task(coro_fn(it)) for it in items]
 
     completed_count = 0
     results = []
+    pending_save_results = []
 
     pbar = tqdm_async(total=len(items), desc=desc, unit=unit)
 
@@ -32,6 +47,8 @@ async def run_concurrent(
         try:
             result = await future
             results.append(result)
+            if save_interval > 0 and save_callback is not None:
+                pending_save_results.append(result)
         except Exception as e:  # pylint: disable=broad-except
             logger.exception("Task failed: %s", e)
             # even if failed, record it to keep results consistent with tasks
@@ -44,11 +61,31 @@ async def run_concurrent(
             progress = completed_count / len(items)
             progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})")
 
+        # Periodic save
+        if save_interval > 0 and save_callback is not None and completed_count % save_interval == 0:
+            try:
+                # Filter out exceptions before saving
+                valid_results = [res for res in pending_save_results if not isinstance(res, Exception)]
+                save_callback(valid_results, completed_count)
+                pending_save_results = []  # Clear after saving
+                logger.info("Saved intermediate results: %d/%d completed", completed_count, len(items))
+            except Exception as e:
+                logger.warning("Failed to save intermediate results: %s", e)
+
     pbar.close()
 
     if progress_bar is not None:
         progress_bar(1.0, desc=f"{desc} (completed)")
 
+    # Save remaining results if any
+    if save_interval > 0 and save_callback is not None and pending_save_results:
+        try:
+            valid_results = [res for res in pending_save_results if not isinstance(res, Exception)]
+            save_callback(valid_results, completed_count)
+            logger.info("Saved final intermediate results: %d completed", completed_count)
+        except Exception as e:
+            logger.warning("Failed to save final intermediate results: %s", e)
+
     # filter out exceptions
     results = [res for res in results if not isinstance(res, Exception)]
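
Putting the pieces together, run_concurrent now flushes a batch every save_interval completions and once more at the end for any remainder. A self-contained usage sketch (the import path is inferred from the file location in this commit; the coroutine and callback are placeholders):

import asyncio

from graphgen.utils.run_concurrent import run_concurrent  # assumed import path


async def fake_search(keyword: str) -> dict:
    await asyncio.sleep(0)  # stand-in for a real network request
    return {"_search_query": keyword}


def checkpoint(batch, completed_count):
    # In the searcher, this is where filter_keys/upsert would run.
    print(f"checkpoint at {completed_count}: {len(batch)} pending results")


async def main():
    results = await run_concurrent(
        fake_search,
        [f"keyword-{i}" for i in range(25)],
        desc="demo search",
        unit="keyword",
        save_interval=10,          # fires checkpoint() at 10 and 20 completions
        save_callback=checkpoint,  # called once more at the end for the last 5
    )
    print(len(results))  # 25


asyncio.run(main())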
