Skip to content

Commit b1e7ef8

Browse files
Merge pull request #88 from open-sciencelab/refactor/refactor-search
feat: add search_any in uniprot_search
2 parents 08bf2f0 + ed2f0a0 commit b1e7ef8

File tree

22 files changed

+163
-348
lines changed

22 files changed

+163
-348
lines changed

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .base_llm_wrapper import BaseLLMWrapper
55
from .base_partitioner import BasePartitioner
66
from .base_reader import BaseReader
7+
from .base_searcher import BaseSearcher
78
from .base_splitter import BaseSplitter
89
from .base_storage import (
910
BaseGraphStorage,

graphgen/bases/base_searcher.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, List
3+
4+
5+
class BaseSearcher(ABC):
    """
    Common interface for components that search for and retrieve data.

    Concrete searchers subclass this and provide an asynchronous ``search``.
    """

    @abstractmethod
    async def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a search for the given query.

        :param query: The search query.
        :param kwargs: Additional keyword arguments for the concrete searcher.
        :return: List of dictionaries containing the search results.
        """
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
pipeline:
2+
- name: read
3+
params:
4+
input_file: resources/input_examples/search_demo.jsonl # input file path; supports json, jsonl, txt, and pdf. See resources/input_examples for examples
5+
6+
- name: search
7+
params:
8+
data_sources: [uniprot] # data sources to search; supported: wikipedia, google, uniprot

graphgen/graphgen.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __init__(
5858
self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
5959
self.working_dir, namespace="_meta"
6060
)
61-
6261
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
6362
self.working_dir, namespace="full_docs"
6463
)
@@ -69,9 +68,8 @@ def __init__(
6968
self.working_dir, namespace="graph"
7069
)
7170
self.search_storage: JsonKVStorage = JsonKVStorage(
72-
self.working_dir, namespace="search"
71+
self.working_dir, namespace="searcher"
7372
)
74-
7573
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
7674
self.working_dir, namespace="rephrase"
7775
)
@@ -181,41 +179,33 @@ async def build_kg(self):
181179

182180
return _add_entities_and_relations
183181

184-
@op("search", deps=["chunk"])
182+
@op("search", deps=["read"])
185183
@async_to_sync_method
186184
async def search(self, search_config: Dict):
187-
logger.info(
188-
"Search is %s", "enabled" if search_config["enabled"] else "disabled"
185+
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
186+
187+
seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
188+
if len(seeds) == 0:
189+
logger.warning("All documents are already been searched")
190+
return
191+
search_results = await search_all(
192+
seed_data=seeds,
193+
**search_config,
194+
)
195+
196+
_add_search_keys = await self.search_storage.filter_keys(
197+
list(search_results.keys())
189198
)
190-
if search_config["enabled"]:
191-
logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
192-
all_nodes = await self.graph_storage.get_all_nodes()
193-
all_nodes_names = [node[0] for node in all_nodes]
194-
new_search_entities = await self.full_docs_storage.filter_keys(
195-
all_nodes_names
196-
)
197-
logger.info(
198-
"[Search] Found %d entities to search", len(new_search_entities)
199-
)
200-
_add_search_data = await search_all(
201-
search_types=search_config["search_types"],
202-
search_entities=new_search_entities,
203-
)
204-
if _add_search_data:
205-
await self.search_storage.upsert(_add_search_data)
206-
logger.info("[Search] %d entities searched", len(_add_search_data))
207-
208-
# Format search results for inserting
209-
search_results = []
210-
for _, search_data in _add_search_data.items():
211-
search_results.extend(
212-
[
213-
{"content": search_data[key]}
214-
for key in list(search_data.keys())
215-
]
216-
)
217-
# TODO: fix insert after search
218-
# await self.insert()
199+
search_results = {
200+
k: v for k, v in search_results.items() if k in _add_search_keys
201+
}
202+
if len(search_results) == 0:
203+
logger.warning("All search results are already in the storage")
204+
return
205+
await self.search_storage.upsert(search_results)
206+
await self.search_storage.index_done_callback()
207+
await self.meta_storage.mark_done(self.full_docs_storage)
208+
await self.meta_storage.index_done_callback()
219209

220210
@op("quiz_and_judge", deps=["build_kg"])
221211
@async_to_sync_method

graphgen/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
RDFReader,
2626
TXTReader,
2727
)
28-
from .search.db.uniprot_search import UniProtSearch
29-
from .search.kg.wiki_search import WikiSearch
30-
from .search.web.bing_search import BingSearch
31-
from .search.web.google_search import GoogleSearch
28+
from .searcher.db.uniprot_searcher import UniProtSearch
29+
from .searcher.kg.wiki_search import WikiSearch
30+
from .searcher.web.bing_search import BingSearch
31+
from .searcher.web.google_search import GoogleSearch
3232
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
3333
from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
3434
from .tokenizer import Tokenizer

graphgen/models/search/db/uniprot_search.py renamed to graphgen/models/searcher/db/uniprot_searcher.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1+
import re
12
from io import StringIO
23
from typing import Dict, Optional
34

45
from Bio import ExPASy, SeqIO, SwissProt, UniProt
56
from Bio.Blast import NCBIWWW, NCBIXML
6-
7+
from requests.exceptions import RequestException
8+
from tenacity import (
9+
retry,
10+
retry_if_exception_type,
11+
stop_after_attempt,
12+
wait_exponential,
13+
)
14+
15+
from graphgen.bases import BaseSearcher
716
from graphgen.utils import logger
817

918

10-
class UniProtSearch:
19+
class UniProtSearch(BaseSearcher):
1120
"""
12-
UniProt Search client to search with UniProt.
21+
UniProt Search client to searcher with UniProt.
1322
1) Get the protein by accession number.
14-
2) Search with keywords or protein names (fuzzy search).
15-
3) Search with FASTA sequence (BLAST search).
23+
2) Search with keywords or protein names (fuzzy searcher).
24+
3) Search with FASTA sequence (BLAST searcher).
1625
"""
1726

1827
def get_by_accession(self, accession: str) -> Optional[dict]:
@@ -21,6 +30,8 @@ def get_by_accession(self, accession: str) -> Optional[dict]:
2130
record = SwissProt.read(handle)
2231
handle.close()
2332
return self._swissprot_to_dict(record)
33+
except RequestException: # network-related errors
34+
raise
2435
except Exception as exc: # pylint: disable=broad-except
2536
logger.error("Accession %s not found: %s", accession, exc)
2637
return None
@@ -51,7 +62,7 @@ def _swissprot_to_dict(record: SwissProt.Record) -> dict:
5162
def get_best_hit(self, keyword: str) -> Optional[Dict]:
5263
"""
5364
Search UniProt with a keyword and return the best hit.
54-
:param keyword: The search keyword.
65+
:param keyword: The searcher keyword.
5566
:return: A dictionary containing the best hit information or None if not found.
5667
"""
5768
if not keyword.strip():
@@ -64,15 +75,17 @@ def get_best_hit(self, keyword: str) -> Optional[Dict]:
6475
return None
6576
return self.get_by_accession(hit["primaryAccession"])
6677

78+
except RequestException:
79+
raise
6780
except Exception as e: # pylint: disable=broad-except
6881
logger.error("Keyword %s not found: %s", keyword, e)
69-
return None
82+
return None
7083

7184
def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
7285
"""
7386
Search UniProt with a FASTA sequence and return the best hit.
7487
:param fasta_sequence: The FASTA sequence.
75-
:param threshold: E-value threshold for BLAST search.
88+
:param threshold: E-value threshold for BLAST searcher.
7689
:return: A dictionary containing the best hit information or None if not found.
7790
"""
7891
try:
@@ -90,6 +103,7 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
90103

91104
# UniProtKB/Swiss-Prot BLAST API
92105
try:
106+
logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
93107
result_handle = NCBIWWW.qblast(
94108
program="blastp",
95109
database="swissprot",
@@ -98,8 +112,10 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
98112
expect=threshold,
99113
)
100114
blast_record = NCBIXML.read(result_handle)
115+
except RequestException:
116+
raise
101117
except Exception as e: # pylint: disable=broad-except
102-
logger.error("BLAST search failed: %s", e)
118+
logger.error("BLAST searcher failed: %s", e)
103119
return None
104120

105121
if not blast_record.alignments:
@@ -116,3 +132,44 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
116132
# like sp|P01308.1|INS_HUMAN
117133
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
118134
return self.get_by_accession(accession)
135+
136+
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(RequestException),
    reraise=True,
)
async def search(
    self, query: str, threshold: float = 0.7, **kwargs
) -> Optional[Dict]:
    """
    Search UniProt with either an accession number, keyword, or FASTA sequence.

    The query type is auto-detected, in this order:
    1) FASTA: the query starts with ">" or consists only of amino-acid
       letters and whitespace -> ``get_by_fasta``.
    2) Accession: the query matches the UniProt accession format
       -> ``get_by_accession``.
    3) Anything else is treated as a keyword -> ``get_best_hit``.

    Network failures (``RequestException``) are retried for up to 5 attempts
    with exponential backoff and re-raised if every attempt fails.

    :param query: The search query (accession number, keyword, or FASTA sequence).
    :param threshold: Value forwarded to ``get_by_fasta`` as the BLAST threshold.
    :param kwargs: Accepted for interface compatibility; not used here.
    :return: A dictionary containing the best hit information (with the
        stripped query stored under ``"_search_query"``) or None if not found.
    :raises RequestException: If the network error persists after all retries.
    """
    if not isinstance(query, str):
        logger.error("Empty or non-string input.")
        return None

    # Strip before the emptiness check so whitespace-only input is rejected
    # here instead of being dispatched to a lookup.
    query = query.strip()
    if not query:
        logger.error("Empty or non-string input.")
        return None

    logger.debug("UniProt searcher query: %s", query)

    # NOTE(review): a bare word made only of residue letters (e.g. "kinase")
    # is also classified as a sequence by this heuristic — confirm whether a
    # minimum length or an explicit query-type hint is wanted.
    looks_like_fasta = query.startswith(">") or re.fullmatch(
        r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
    )

    # Official UniProt accession format. The previous character class
    # excluded O/P/Q (so e.g. "P01308" was never detected) and matched plain
    # 6-10 letter words such as "insulin".
    looks_like_accession = re.fullmatch(
        r"[OPQ][0-9][A-Z0-9]{3}[0-9]"
        r"|[A-NR-Z][0-9](?:[A-Z][A-Z0-9]{2}[0-9]){1,2}",
        query,
        re.I,
    )

    if looks_like_fasta:
        result = self.get_by_fasta(query, threshold)
    elif looks_like_accession:
        # Accessions are upper-case by definition; normalise user input.
        result = self.get_by_accession(query.upper())
    else:
        # otherwise treat as keyword
        result = self.get_best_hit(query)

    if result:
        result["_search_query"] = query
    return result
File renamed without changes.

0 commit comments

Comments
 (0)