Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions graphgen/bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base_llm_wrapper import BaseLLMWrapper
from .base_partitioner import BasePartitioner
from .base_reader import BaseReader
from .base_searcher import BaseSearcher
from .base_splitter import BaseSplitter
from .base_storage import (
BaseGraphStorage,
Expand Down
18 changes: 18 additions & 0 deletions graphgen/bases/base_searcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseSearcher(ABC):
    """
    Abstract base class for searching and retrieving data.

    Concrete subclasses (e.g. web, knowledge-graph, or database searchers)
    implement `search` to fetch results for a query string.
    """

    @abstractmethod
    async def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
        """
        Search for data based on the given query.

        :param query: The search query.
        :param kwargs: Additional keyword arguments for the search backend.
        :return: List of dictionaries containing the search results.
        """
8 changes: 8 additions & 0 deletions graphgen/configs/search_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
pipeline:
- name: read
params:
input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples

- name: search
params:
data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
60 changes: 25 additions & 35 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(
self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
self.working_dir, namespace="_meta"
)

self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
Expand All @@ -69,9 +68,8 @@ def __init__(
self.working_dir, namespace="graph"
)
self.search_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="search"
self.working_dir, namespace="searcher"
)

self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
Expand Down Expand Up @@ -181,41 +179,33 @@ async def build_kg(self):

return _add_entities_and_relations

@op("search", deps=["chunk"])
@op("search", deps=["read"])
@async_to_sync_method
async def search(self, search_config: Dict):
logger.info(
"Search is %s", "enabled" if search_config["enabled"] else "disabled"
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))

seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
if len(seeds) == 0:
logger.warning("All documents have already been searched")
return
search_results = await search_all(
seed_data=seeds,
**search_config,
)

_add_search_keys = await self.search_storage.filter_keys(
list(search_results.keys())
)
if search_config["enabled"]:
logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
all_nodes = await self.graph_storage.get_all_nodes()
all_nodes_names = [node[0] for node in all_nodes]
new_search_entities = await self.full_docs_storage.filter_keys(
all_nodes_names
)
logger.info(
"[Search] Found %d entities to search", len(new_search_entities)
)
_add_search_data = await search_all(
search_types=search_config["search_types"],
search_entities=new_search_entities,
)
if _add_search_data:
await self.search_storage.upsert(_add_search_data)
logger.info("[Search] %d entities searched", len(_add_search_data))

# Format search results for inserting
search_results = []
for _, search_data in _add_search_data.items():
search_results.extend(
[
{"content": search_data[key]}
for key in list(search_data.keys())
]
)
# TODO: fix insert after search
# await self.insert()
search_results = {
k: v for k, v in search_results.items() if k in _add_search_keys
}
if len(search_results) == 0:
logger.warning("All search results are already in the storage")
return
await self.search_storage.upsert(search_results)
await self.search_storage.index_done_callback()
await self.meta_storage.mark_done(self.full_docs_storage)
await self.meta_storage.index_done_callback()

@op("quiz_and_judge", deps=["build_kg"])
@async_to_sync_method
Expand Down
8 changes: 4 additions & 4 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
RDFReader,
TXTReader,
)
from .search.db.uniprot_search import UniProtSearch
from .search.kg.wiki_search import WikiSearch
from .search.web.bing_search import BingSearch
from .search.web.google_search import GoogleSearch
from .searcher.db.uniprot_searcher import UniProtSearch
from .searcher.kg.wiki_search import WikiSearch
from .searcher.web.bing_search import BingSearch
from .searcher.web.google_search import GoogleSearch
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
from .tokenizer import Tokenizer
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import re
from io import StringIO
from typing import Dict, Optional

from Bio import ExPASy, SeqIO, SwissProt, UniProt
from Bio.Blast import NCBIWWW, NCBIXML

from requests.exceptions import RequestException
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from graphgen.bases import BaseSearcher
from graphgen.utils import logger


class UniProtSearch:
class UniProtSearch(BaseSearcher):
"""
UniProt Search client to search with UniProt.
UniProt Search client to search with UniProt.
1) Get the protein by accession number.
2) Search with keywords or protein names (fuzzy search).
3) Search with FASTA sequence (BLAST search).
2) Search with keywords or protein names (fuzzy search).
3) Search with FASTA sequence (BLAST search).
"""

def get_by_accession(self, accession: str) -> Optional[dict]:
Expand All @@ -21,6 +30,8 @@ def get_by_accession(self, accession: str) -> Optional[dict]:
record = SwissProt.read(handle)
handle.close()
return self._swissprot_to_dict(record)
except RequestException: # network-related errors
raise
except Exception as exc: # pylint: disable=broad-except
logger.error("Accession %s not found: %s", accession, exc)
return None
Expand Down Expand Up @@ -51,7 +62,7 @@ def _swissprot_to_dict(record: SwissProt.Record) -> dict:
def get_best_hit(self, keyword: str) -> Optional[Dict]:
"""
Search UniProt with a keyword and return the best hit.
:param keyword: The search keyword.
:param keyword: The search keyword.
:return: A dictionary containing the best hit information or None if not found.
"""
if not keyword.strip():
Expand All @@ -64,15 +75,17 @@ def get_best_hit(self, keyword: str) -> Optional[Dict]:
return None
return self.get_by_accession(hit["primaryAccession"])

except RequestException:
raise
except Exception as e: # pylint: disable=broad-except
logger.error("Keyword %s not found: %s", keyword, e)
return None
return None

def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
"""
Search UniProt with a FASTA sequence and return the best hit.
:param fasta_sequence: The FASTA sequence.
:param threshold: E-value threshold for BLAST search.
:param threshold: E-value threshold for BLAST search.
:return: A dictionary containing the best hit information or None if not found.
"""
try:
Expand All @@ -90,6 +103,7 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:

# UniProtKB/Swiss-Prot BLAST API
try:
logger.debug("Performing BLAST search for the given sequence: %s", seq)
result_handle = NCBIWWW.qblast(
program="blastp",
database="swissprot",
Expand All @@ -98,8 +112,10 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
expect=threshold,
)
blast_record = NCBIXML.read(result_handle)
except RequestException:
raise
except Exception as e: # pylint: disable=broad-except
logger.error("BLAST search failed: %s", e)
logger.error("BLAST search failed: %s", e)
return None

if not blast_record.alignments:
Expand All @@ -116,3 +132,44 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
# like sp|P01308.1|INS_HUMAN
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
return self.get_by_accession(accession)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(RequestException),
        reraise=True,
    )
    async def search(
        self, query: str, threshold: float = 0.7, **kwargs
    ) -> Optional[Dict]:
        """
        Search UniProt with either an accession number, keyword, or FASTA sequence.

        The query type is auto-detected: a leading '>' or an amino-acid-only
        string is treated as a FASTA sequence, a 6-10 character alphanumeric
        token as an accession number, and anything else as a keyword.
        Retries (up to 5 attempts, exponential backoff) only on network
        errors (RequestException), which the helper methods re-raise.

        :param query: The search query (accession number, keyword, or FASTA sequence).
        :param threshold: E-value threshold for BLAST search.
        :param kwargs: Accepted for interface compatibility; currently unused.
        :return: A dictionary containing the best hit information or None if not found.

        NOTE(review): BaseSearcher.search declares -> List[Dict[str, Any]],
        but this override returns Optional[Dict] — confirm callers handle both.
        """

        # auto detect query type
        if not query or not isinstance(query, str):
            logger.error("Empty or non-string input.")
            return None
        query = query.strip()

        logger.debug("UniProt searcher query: %s", query)
        # check if fasta sequence
        # NOTE(review): short all-letter keywords whose characters are all
        # amino-acid codes (e.g. "ACE") also match this pattern and will be
        # routed to the slow BLAST path — confirm this is intended.
        if query.startswith(">") or re.fullmatch(
            r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
        ):
            result = self.get_by_fasta(query, threshold)

        # check if accession number
        elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
            result = self.get_by_accession(query)

        else:
            # otherwise treat as keyword
            result = self.get_best_hit(query)

        # Tag the result with the originating query so callers can map
        # results back to their seed data.
        if result:
            result["_search_query"] = query
        return result
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class BingSearch:
"""
Bing Search client to search with Bing.
Bing Search client to search with Bing.
"""

def __init__(self, subscription_key: str):
Expand All @@ -18,9 +18,9 @@ def __init__(self, subscription_key: str):
def search(self, query: str, num_results: int = 1):
"""
Search with Bing and return the contexts.
:param query: The search query.
:param query: The search query.
:param num_results: The number of results to return.
:return: A list of search results.
:return: A list of search results.
"""
params = {"q": query, "mkt": BING_MKT, "count": num_results}
response = requests.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@
class GoogleSearch:
def __init__(self, subscription_key: str, cx: str):
"""
Initialize the Google Search client with the subscription key and custom search engine ID.
Initialize the Google Search client with the subscription key and custom search engine ID.
:param subscription_key: Your Google API subscription key.
:param cx: Your custom search engine ID.
:param cx: Your custom search engine ID.
"""
self.subscription_key = subscription_key
self.cx = cx

def search(self, query: str, num_results: int = 1):
"""
Search with Google and return the contexts.
:param query: The search query.
:param query: The search query.
:param num_results: The number of results to return.
:return: A list of search results.
:return: A list of search results.
"""
params = {
"key": self.subscription_key,
Expand Down
Empty file.
58 changes: 0 additions & 58 deletions graphgen/operators/search/kg/search_wikipedia.py

This file was deleted.

Loading