Commit 256acc1

feat: add mo_kg_builder
1 parent 96be73a commit 256acc1

13 files changed: +347 -179 lines changed


graphgen/configs/protein_qa_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 read:
-  input_file: resources/input_examples/protein_qa_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+  input_file: resources/input_examples/protein_qa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
   anchor_type: protein # get protein information from chunks
 split:
   chunk_size: 1024 # chunk size for text splitting
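
The demo input moves from plain text to JSON; the `read` block itself is ordinary YAML, so its keys can be inspected directly. A minimal loading sketch (PyYAML assumed; this loader is illustrative, not GraphGen's actual config reader):

```python
# Illustrative only: load the config and read the keys shown in the diff above.
import yaml  # assumes PyYAML is installed

with open("graphgen/configs/protein_qa_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

input_file = config["read"]["input_file"]        # resources/input_examples/protein_qa_demo.json
anchor_type = config["read"].get("anchor_type")  # "protein"
chunk_size = config["split"]["chunk_size"]       # 1024
print(input_file, anchor_type, chunk_size)
```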

graphgen/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     MultiHopGenerator,
     VQAGenerator,
 )
-from .kg_builder import LightRAGKGBuilder, MMKGBuilder
+from .kg_builder import LightRAGKGBuilder, MMKGBuilder, MOKGBuilder
 from .llm import HTTPClient, OllamaClient, OpenAIClient
 from .partitioner import (
     AnchorBFSPartitioner,
graphgen/models/kg_builder/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from .light_rag_kg_builder import LightRAGKGBuilder
 from .mm_kg_builder import MMKGBuilder
+from .mo_kg_builder import MOKGBuilder
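
Together with the `graphgen/models/__init__.py` change above, this re-export makes the new builder importable from the package root:

```python
from graphgen.models import MOKGBuilder  # re-exported at the package root
# equivalently: from graphgen.models.kg_builder import MOKGBuilder
```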

graphgen/models/kg_builder/mo_kg_builder.py

Lines changed: 79 additions & 2 deletions
@@ -1,11 +1,36 @@
+import re
+from collections import defaultdict
 from typing import Dict, List, Tuple
 
 from graphgen.bases import Chunk
+from graphgen.templates import PROTEIN_KG_EXTRACTION_PROMPT
+from graphgen.utils import (
+    detect_main_language,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    split_string_by_multi_markers,
+)
 
 from .light_rag_kg_builder import LightRAGKGBuilder
 
 
 class MOKGBuilder(LightRAGKGBuilder):
+    @staticmethod
+    async def scan_document_for_schema(
+        chunk: Chunk, schema: Dict[str, List[str]]
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """
+        Scan the document chunk to extract entities and relationships based on the provided schema.
+        :param chunk: The document chunk to be scanned.
+        :param schema: A dictionary defining the entities and relationships to be extracted.
+        :return: A tuple containing two dictionaries - one for entities and one for relationships.
+        """
+        # TODO: use hard-coded PROTEIN_KG_EXTRACTION_PROMPT for protein chunks,
+        # support schema for other chunk types later
+        print(chunk.id, schema)
+        return {}, {}
+
     async def extract(
         self, chunk: Chunk
     ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
@@ -19,5 +44,57 @@ async def extract(
         :return: Tuple containing entities and relationships.
         """
         # TODO: Implement the multi-omics KG extraction logic here
-        print(chunk)
-        return {}, {}
+        chunk_id = chunk.id
+        chunk_type = chunk.type  # genome | protein | ...
+        metadata = chunk.metadata
+
+        # choose different extraction strategies based on chunk type
+        if chunk_type == "protein":
+            protein_caption = ""
+            for key, value in metadata["protein_caption"].items():
+                protein_caption += f"{key}: {value}\n"
+            logger.debug("Protein chunk caption: %s", protein_caption)
+
+            language = detect_main_language(protein_caption)
+            prompt_template = PROTEIN_KG_EXTRACTION_PROMPT[language].format(
+                **PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"],
+                input_text=protein_caption,
+            )
+            result = await self.llm_client.generate_answer(prompt_template)
+            logger.debug("Protein chunk extraction result: %s", result)
+
+            # parse the result
+            records = split_string_by_multi_markers(
+                result,
+                [
+                    PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                    PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+                ],
+            )
+
+            nodes = defaultdict(list)
+            edges = defaultdict(list)
+
+            for record in records:
+                match = re.search(r"\((.*)\)", record)
+                if not match:
+                    continue
+                inner = match.group(1)
+
+                attributes = split_string_by_multi_markers(
+                    inner, [PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+                )
+
+                entity = await handle_single_entity_extraction(attributes, chunk_id)
+                if entity is not None:
+                    nodes[entity["entity_name"]].append(entity)
+                    continue
+
+                relation = await handle_single_relationship_extraction(
+                    attributes, chunk_id
+                )
+                if relation is not None:
+                    key = (relation["src_id"], relation["tgt_id"])
+                    edges[key].append(relation)
+
+        return dict(nodes), dict(edges)
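
`extract` expects the LLM to emit parenthesized, delimiter-separated records. A standalone sketch of that parsing step follows; the delimiter strings are assumptions (LightRAG-style), since the real values live in `PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]`, and `split_by_markers` is a stand-in for `graphgen.utils.split_string_by_multi_markers`:

```python
import re

RECORD_DELIMITER = "##"                # assumed; real value in PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"]
COMPLETION_DELIMITER = "<|COMPLETE|>"  # assumed; ..."completion_delimiter"
TUPLE_DELIMITER = "<|>"                # assumed; ..."tuple_delimiter"

def split_by_markers(text: str, markers: list[str]) -> list[str]:
    """Stand-in for graphgen.utils.split_string_by_multi_markers."""
    parts = re.split("|".join(map(re.escape, markers)), text)
    return [p.strip() for p in parts if p.strip()]

# hypothetical LLM output for one protein chunk
llm_output = (
    '("entity"<|>TP53<|>protein<|>tumor suppressor transcription factor)##'
    '("relationship"<|>TP53<|>MDM2<|>is negatively regulated by<|>0.9)##'
    "<|COMPLETE|>"
)

for record in split_by_markers(llm_output, [RECORD_DELIMITER, COMPLETION_DELIMITER]):
    match = re.search(r"\((.*)\)", record)  # same regex as extract() above
    if not match:
        continue
    attributes = split_by_markers(match.group(1), [TUPLE_DELIMITER])
    print(attributes)  # e.g. ['"entity"', 'TP53', 'protein', 'tumor suppressor transcription factor']
```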
Lines changed: 102 additions & 46 deletions
@@ -1,61 +1,117 @@
-import requests
-from fastapi import HTTPException
+from io import StringIO
+from typing import Dict, Optional
 
-from graphgen.utils import logger
+from Bio import ExPASy, SeqIO, SwissProt, UniProt
+from Bio.Blast import NCBIWWW, NCBIXML
 
-UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
+from graphgen.utils import logger
 
 
 class UniProtSearch:
     """
     UniProt Search client to search with UniProt.
     1) Get the protein by accession number.
-    2) Search with keywords or protein names.
+    2) Search with keywords or protein names (fuzzy search).
     """
 
-    def get_entry(self, accession: str) -> dict:
+    def get_by_accession(self, accession: str) -> Optional[dict]:
+        try:
+            handle = ExPASy.get_sprot_raw(accession)
+            record = SwissProt.read(handle)
+            handle.close()
+            return self._swissprot_to_dict(record)
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.error("Accession %s not found: %s", accession, exc)
+            return None
+
+    @staticmethod
+    def _swissprot_to_dict(record: SwissProt.Record) -> dict:
+        """
+        Convert a SwissProt.Record to a dictionary.
         """
-        Get the UniProt entry by accession number(e.g., P04637).
+        functions = []
+        for line in record.comments:
+            if line.startswith("FUNCTION:"):
+                functions.append(line[9:].strip())
+
+        return {
+            "molecule_type": "protein",
+            "database": "UniProt",
+            "id": record.accessions[0],
+            "entry_name": record.entry_name,
+            "gene_names": record.gene_name,
+            "protein_name": record.description.split(";")[0].split("=")[-1],
+            "organism": record.organism.split(" (")[0],
+            "sequence": str(record.sequence),
+            "function": functions,
+            "url": f"https://www.uniprot.org/uniprot/{record.accessions[0]}",
+        }
+
+    def get_best_hit(self, keyword: str) -> Optional[Dict]:
         """
-        url = f"{UNIPROT_BASE}/{accession}.json"
-        return self._safe_get(url).json()
-
-    def search(
-        self,
-        query: str,
-        *,
-        size: int = 10,
-        cursor: str = None,
-        fields: list[str] = None,
-    ) -> dict:
+        Search UniProt with a keyword and return the best hit.
+        :param keyword: The search keyword.
+        :return: A dictionary containing the best hit information or None if not found.
         """
-        Search UniProt with a query string.
-        :param query: The search query.
-        :param size: The number of results to return.
-        :param cursor: The cursor for pagination.
-        :param fields: The fields to return in the response.
-        :return: A dictionary containing the search results.
+        if not keyword.strip():
+            return None
+
+        try:
+            iterator = UniProt.search(keyword, fields=None, batch_size=1)
+            hit = next(iterator, None)
+            if hit is None:
+                return None
+            return self.get_by_accession(hit["primaryAccession"])
+
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Keyword %s not found: %s", keyword, e)
+            return None
+
+    def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
         """
-        params = {
-            "query": query,
-            "size": size,
-        }
-        if cursor:
-            params["cursor"] = cursor
-        if fields:
-            params["fields"] = ",".join(fields)
-        url = UNIPROT_BASE
-        return self._safe_get(url, params=params).json()
+        Search UniProt with a FASTA sequence and return the best hit.
+        :param fasta_sequence: The FASTA sequence.
+        :param threshold: E-value threshold for BLAST search.
+        :return: A dictionary containing the best hit information or None if not found.
+        """
+        try:
+            if fasta_sequence.startswith(">"):
+                seq = str(list(SeqIO.parse(StringIO(fasta_sequence), "fasta"))[0].seq)
+            else:
+                seq = fasta_sequence.strip()
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Invalid FASTA sequence: %s", e)
+            return None
 
-    @staticmethod
-    def _safe_get(url: str, params: dict = None) -> requests.Response:
-        r = requests.get(
-            url,
-            params=params,
-            headers={"Accept": "application/json"},
-            timeout=10,
-        )
-        if not r.ok:
-            logger.error("Search engine error: %s", r.text)
-            raise HTTPException(r.status_code, "Search engine error.")
-        return r
+        if not seq:
+            logger.error("Empty FASTA sequence provided.")
+            return None
+
+        # UniProtKB/Swiss-Prot BLAST API
+        try:
+            result_handle = NCBIWWW.qblast(
+                program="blastp",
+                database="swissprot",
+                sequence=seq,
+                hitlist_size=1,
+                expect=threshold,
+            )
+            blast_record = NCBIXML.read(result_handle)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("BLAST search failed: %s", e)
+            return None
+
+        if not blast_record.alignments:
+            logger.info("No BLAST hits found for the given sequence.")
+            return None
+
+        best_alignment = blast_record.alignments[0]
+        best_hsp = best_alignment.hsps[0]
+        if best_hsp.expect > threshold:
+            logger.info("No BLAST hits below the threshold E-value.")
+            return None
+        hit_id = best_alignment.hit_id
+
+        # like sp|P01308.1|INS_HUMAN
+        accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
+        return self.get_by_accession(accession)
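
A possible usage sketch of the rewritten client (Biopython required; all three lookups hit live ExPASy / NCBI services, so expect latency and rate limits — the accession, keyword, and sequence values below are just examples):

```python
# Illustrative usage; network access and the biopython package are assumed.
client = UniProtSearch()

entry = client.get_by_accession("P04637")  # human p53
if entry:
    print(entry["protein_name"], entry["organism"], entry["url"])

best = client.get_best_hit("insulin")      # fuzzy keyword search, best hit only

# BLAST the sequence against SwissProt; returns None if no hit beats the E-value cutoff
hit = client.get_by_fasta(">query\nMALWMRLLPLLALLALWGPDPAAA", threshold=1e-5)
```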

graphgen/operators/build_kg/build_kg.py

Lines changed: 17 additions & 12 deletions
@@ -30,7 +30,12 @@ async def build_kg(
     """
 
     text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
-    mm_chunks = [chunk for chunk in chunks if chunk.type != "text"]
+    mm_chunks = [
+        chunk
+        for chunk in chunks
+        if chunk.type in ("image", "video", "table", "formula")
+    ]
+    mo_chunks = [chunk for chunk in chunks if chunk.type in ("genome", "protein")]
 
     if len(text_chunks) == 0:
         logger.info("All text chunks are already in the storage")
@@ -42,6 +47,7 @@ async def build_kg(
             chunks=text_chunks,
             progress_bar=progress_bar,
         )
+
     if len(mm_chunks) == 0:
         logger.info("All multi-modal chunks are already in the storage")
     else:
@@ -53,16 +59,15 @@ async def build_kg(
             progress_bar=progress_bar,
         )
 
-    if anchor_type is not None:
-        logger.info("Anchoring data based on %s ...", anchor_type)
-        if anchor_type == "protein":
-            await build_mo_kg(
-                llm_client=llm_client,
-                kg_instance=kg_instance,
-                chunks=text_chunks,
-                progress_bar=progress_bar,
-            )
-        else:
-            logger.error("Anchor type %s is not supported yet.", anchor_type)
+    if len(mo_chunks) == 0:
+        logger.info("All multi-omics chunks are already in the storage")
+    else:
+        logger.info("[Multi-omics Entity and Relation Extraction] processing ...")
+        await build_mo_kg(
+            llm_client=llm_client,
+            kg_instance=kg_instance,
+            chunks=mo_chunks,
+            progress_bar=progress_bar,
+        )
 
     return kg_instance
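
This routing change is the core of the commit: chunks now select their pipeline by `chunk.type` rather than by a global `anchor_type` flag, so omics chunks no longer ride through the text path. A self-contained sketch of the partitioning (the `Chunk` dataclass below is a stand-in for `graphgen.bases.Chunk`):

```python
from dataclasses import dataclass, field

@dataclass
class Chunk:  # stand-in for graphgen.bases.Chunk
    id: str
    type: str
    metadata: dict = field(default_factory=dict)

chunks = [
    Chunk("c1", "text"),
    Chunk("c2", "image"),
    Chunk("c3", "protein", {"protein_caption": {"name": "TP53"}}),
]

# mirrors the three comprehensions in build_kg above
text_chunks = [c for c in chunks if c.type == "text"]
mm_chunks = [c for c in chunks if c.type in ("image", "video", "table", "formula")]
mo_chunks = [c for c in chunks if c.type in ("genome", "protein")]

assert [c.id for c in mo_chunks] == ["c3"]  # only omics chunks reach build_mo_kg
```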
