Skip to content

Commit ed2f0a0

Browse files
refactor: refactor search_all to support uniprot_search
1 parent 6a93d3c commit ed2f0a0

File tree

11 files changed

+143
-125
lines changed

11 files changed

+143
-125
lines changed

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .base_llm_wrapper import BaseLLMWrapper
55
from .base_partitioner import BasePartitioner
66
from .base_reader import BaseReader
7+
from .base_searcher import BaseSearcher
78
from .base_splitter import BaseSplitter
89
from .base_storage import (
910
BaseGraphStorage,

graphgen/bases/base_searcher.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, List
3+
4+
5+
class BaseSearcher(ABC):
6+
"""
7+
Abstract base class for searching and retrieving data.
8+
"""
9+
10+
@abstractmethod
11+
async def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
12+
"""
13+
Search for data based on the given query.
14+
15+
:param query: The search query.
16+
:param kwargs: Additional keyword arguments for the searcher.
17+
:return: List of dictionaries containing the search results.
18+
"""
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
pipeline:
2+
- name: read
3+
params:
4+
input_file: resources/input_examples/search_demo.jsonl # input file path; supports json, jsonl, txt, pdf. See resources/input_examples for examples
5+
6+
- name: search
7+
params:
8+
data_sources: [uniprot] # data sources for the searcher; supported: wikipedia, google, uniprot

graphgen/graphgen.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __init__(
5858
self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
5959
self.working_dir, namespace="_meta"
6060
)
61-
6261
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
6362
self.working_dir, namespace="full_docs"
6463
)
@@ -69,9 +68,8 @@ def __init__(
6968
self.working_dir, namespace="graph"
7069
)
7170
self.search_storage: JsonKVStorage = JsonKVStorage(
72-
self.working_dir, namespace="search"
71+
self.working_dir, namespace="searcher"
7372
)
74-
7573
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
7674
self.working_dir, namespace="rephrase"
7775
)
@@ -181,41 +179,33 @@ async def build_kg(self):
181179

182180
return _add_entities_and_relations
183181

184-
@op("search", deps=["chunk"])
182+
@op("search", deps=["read"])
185183
@async_to_sync_method
186184
async def search(self, search_config: Dict):
187-
logger.info(
188-
"Search is %s", "enabled" if search_config["enabled"] else "disabled"
185+
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
186+
187+
seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
188+
if len(seeds) == 0:
189+
logger.warning("All documents are already been searched")
190+
return
191+
search_results = await search_all(
192+
seed_data=seeds,
193+
**search_config,
194+
)
195+
196+
_add_search_keys = await self.search_storage.filter_keys(
197+
list(search_results.keys())
189198
)
190-
if search_config["enabled"]:
191-
logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
192-
all_nodes = await self.graph_storage.get_all_nodes()
193-
all_nodes_names = [node[0] for node in all_nodes]
194-
new_search_entities = await self.full_docs_storage.filter_keys(
195-
all_nodes_names
196-
)
197-
logger.info(
198-
"[Search] Found %d entities to search", len(new_search_entities)
199-
)
200-
_add_search_data = await search_all(
201-
search_types=search_config["search_types"],
202-
search_entities=new_search_entities,
203-
)
204-
if _add_search_data:
205-
await self.search_storage.upsert(_add_search_data)
206-
logger.info("[Search] %d entities searched", len(_add_search_data))
207-
208-
# Format search results for inserting
209-
search_results = []
210-
for _, search_data in _add_search_data.items():
211-
search_results.extend(
212-
[
213-
{"content": search_data[key]}
214-
for key in list(search_data.keys())
215-
]
216-
)
217-
# TODO: fix insert after search
218-
# await self.insert()
199+
search_results = {
200+
k: v for k, v in search_results.items() if k in _add_search_keys
201+
}
202+
if len(search_results) == 0:
203+
logger.warning("All search results are already in the storage")
204+
return
205+
await self.search_storage.upsert(search_results)
206+
await self.search_storage.index_done_callback()
207+
await self.meta_storage.mark_done(self.full_docs_storage)
208+
await self.meta_storage.index_done_callback()
219209

220210
@op("quiz_and_judge", deps=["build_kg"])
221211
@async_to_sync_method

graphgen/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
RDFReader,
2626
TXTReader,
2727
)
28-
from .search.db.uniprot_search import UniProtSearch
29-
from .search.kg.wiki_search import WikiSearch
30-
from .search.web.bing_search import BingSearch
31-
from .search.web.google_search import GoogleSearch
28+
from .searcher.db.uniprot_searcher import UniProtSearch
29+
from .searcher.kg.wiki_search import WikiSearch
30+
from .searcher.web.bing_search import BingSearch
31+
from .searcher.web.google_search import GoogleSearch
3232
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
3333
from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
3434
from .tokenizer import Tokenizer

graphgen/models/searcher/db/uniprot_searcher.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,24 @@
44

55
from Bio import ExPASy, SeqIO, SwissProt, UniProt
66
from Bio.Blast import NCBIWWW, NCBIXML
7-
7+
from requests.exceptions import RequestException
8+
from tenacity import (
9+
retry,
10+
retry_if_exception_type,
11+
stop_after_attempt,
12+
wait_exponential,
13+
)
14+
15+
from graphgen.bases import BaseSearcher
816
from graphgen.utils import logger
917

1018

11-
class UniProtSearch:
19+
class UniProtSearch(BaseSearcher):
1220
"""
13-
UniProt Search client to search with UniProt.
21+
UniProt Search client to search with UniProt.
1422
1) Get the protein by accession number.
15-
2) Search with keywords or protein names (fuzzy search).
16-
3) Search with FASTA sequence (BLAST search).
23+
2) Search with keywords or protein names (fuzzy search).
24+
3) Search with FASTA sequence (BLAST search).
1725
"""
1826

1927
def get_by_accession(self, accession: str) -> Optional[dict]:
@@ -22,6 +30,8 @@ def get_by_accession(self, accession: str) -> Optional[dict]:
2230
record = SwissProt.read(handle)
2331
handle.close()
2432
return self._swissprot_to_dict(record)
33+
except RequestException: # network-related errors
34+
raise
2535
except Exception as exc: # pylint: disable=broad-except
2636
logger.error("Accession %s not found: %s", accession, exc)
2737
return None
@@ -52,7 +62,7 @@ def _swissprot_to_dict(record: SwissProt.Record) -> dict:
5262
def get_best_hit(self, keyword: str) -> Optional[Dict]:
5363
"""
5464
Search UniProt with a keyword and return the best hit.
55-
:param keyword: The search keyword.
65+
:param keyword: The search keyword.
5666
:return: A dictionary containing the best hit information or None if not found.
5767
"""
5868
if not keyword.strip():
@@ -65,15 +75,17 @@ def get_best_hit(self, keyword: str) -> Optional[Dict]:
6575
return None
6676
return self.get_by_accession(hit["primaryAccession"])
6777

78+
except RequestException:
79+
raise
6880
except Exception as e: # pylint: disable=broad-except
6981
logger.error("Keyword %s not found: %s", keyword, e)
70-
return None
82+
return None
7183

7284
def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
7385
"""
7486
Search UniProt with a FASTA sequence and return the best hit.
7587
:param fasta_sequence: The FASTA sequence.
76-
:param threshold: E-value threshold for BLAST search.
88+
:param threshold: E-value threshold for BLAST search.
7789
:return: A dictionary containing the best hit information or None if not found.
7890
"""
7991
try:
@@ -91,6 +103,7 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
91103

92104
# UniProtKB/Swiss-Prot BLAST API
93105
try:
106+
logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
94107
result_handle = NCBIWWW.qblast(
95108
program="blastp",
96109
database="swissprot",
@@ -99,8 +112,10 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
99112
expect=threshold,
100113
)
101114
blast_record = NCBIXML.read(result_handle)
115+
except RequestException:
116+
raise
102117
except Exception as e: # pylint: disable=broad-except
103-
logger.error("BLAST search failed: %s", e)
118+
logger.error("BLAST searcher failed: %s", e)
104119
return None
105120

106121
if not blast_record.alignments:
@@ -118,11 +133,19 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
118133
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
119134
return self.get_by_accession(accession)
120135

121-
def get_any(self, query: str, threshold: float = 1e-5) -> Optional[Dict]:
136+
@retry(
137+
stop=stop_after_attempt(5),
138+
wait=wait_exponential(multiplier=1, min=4, max=10),
139+
retry=retry_if_exception_type(RequestException),
140+
reraise=True,
141+
)
142+
async def search(
143+
self, query: str, threshold: float = 0.7, **kwargs
144+
) -> Optional[Dict]:
122145
"""
123146
Search UniProt with either an accession number, keyword, or FASTA sequence.
124-
:param query: The search query (accession number, keyword, or FASTA sequence).
125-
:param threshold: E-value threshold for BLAST search.
147+
:param query: The search query (accession number, keyword, or FASTA sequence).
148+
:param threshold: E-value threshold for BLAST search.
126149
:return: A dictionary containing the best hit information or None if not found.
127150
"""
128151

@@ -132,15 +155,21 @@ def get_any(self, query: str, threshold: float = 1e-5) -> Optional[Dict]:
132155
return None
133156
query = query.strip()
134157

158+
logger.debug("UniProt searcher query: %s", query)
135159
# check if fasta sequence
136160
if query.startswith(">") or re.fullmatch(
137161
r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
138162
):
139-
return self.get_by_fasta(query, threshold)
163+
result = self.get_by_fasta(query, threshold)
140164

141165
# check if accession number
142-
if re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
143-
return self.get_by_accession(query)
166+
elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
167+
result = self.get_by_accession(query)
168+
169+
else:
170+
# otherwise treat as keyword
171+
result = self.get_best_hit(query)
144172

145-
# otherwise treat as keyword
146-
return self.get_best_hit(query)
173+
if result:
174+
result["_search_query"] = query
175+
return result

graphgen/models/searcher/web/bing_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class BingSearch:
1111
"""
12-
Bing Search client to search with Bing.
12+
Bing Search client to search with Bing.
1313
"""
1414

1515
def __init__(self, subscription_key: str):
@@ -18,9 +18,9 @@ def __init__(self, subscription_key: str):
1818
def search(self, query: str, num_results: int = 1):
1919
"""
2020
Search with Bing and return the contexts.
21-
:param query: The search query.
21+
:param query: The search query.
2222
:param num_results: The number of results to return.
23-
:return: A list of search results.
23+
:return: A list of search results.
2424
"""
2525
params = {"q": query, "mkt": BING_MKT, "count": num_results}
2626
response = requests.get(

graphgen/models/searcher/web/google_search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@
99
class GoogleSearch:
1010
def __init__(self, subscription_key: str, cx: str):
1111
"""
12-
Initialize the Google Search client with the subscription key and custom search engine ID.
12+
Initialize the Google Search client with the subscription key and custom search engine ID.
1313
:param subscription_key: Your Google API subscription key.
14-
:param cx: Your custom search engine ID.
14+
:param cx: Your custom search engine ID.
1515
"""
1616
self.subscription_key = subscription_key
1717
self.cx = cx
1818

1919
def search(self, query: str, num_results: int = 1):
2020
"""
2121
Search with Google and return the contexts.
22-
:param query: The search query.
22+
:param query: The search query.
2323
:param num_results: The number of results to return.
24-
:return: A list of search results.
24+
:return: A list of search results.
2525
"""
2626
params = {
2727
"key": self.subscription_key,

0 commit comments

Comments
 (0)