|
1 | 1 | import json |
2 | | -import re |
3 | 2 | from typing import List |
4 | 3 |
|
5 | 4 | import gradio as gr |
6 | 5 |
|
| 6 | +from graphgen.bases import BaseLLMWrapper |
7 | 7 | from graphgen.bases.base_storage import BaseGraphStorage |
8 | 8 | from graphgen.bases.datatypes import Chunk |
9 | | -from graphgen.models import OpenAIClient |
10 | 9 | from graphgen.templates import PROTEIN_ANCHOR_PROMPT, PROTEIN_KG_EXTRACTION_PROMPT |
11 | | -from graphgen.utils import ( |
12 | | - detect_main_language, |
13 | | - handle_single_entity_extraction, |
14 | | - handle_single_relationship_extraction, |
15 | | - logger, |
16 | | - run_concurrent, |
17 | | - split_string_by_multi_markers, |
18 | | -) |
| 10 | +from graphgen.utils import detect_main_language, logger, run_concurrent |
19 | 11 |
|
20 | 12 |
|
21 | 13 | async def build_mo_kg( |
22 | | - llm_client: OpenAIClient, |
| 14 | + llm_client: BaseLLMWrapper, |
23 | 15 | kg_instance: BaseGraphStorage, |
24 | 16 | chunks: List[Chunk], |
25 | 17 | progress_bar: gr.Progress = None, |
@@ -73,48 +65,13 @@ async def extract_mo_info(chunk: Chunk): |
73 | 65 | # logger.warning("Failed to search for protein info: %s", e) |
74 | 66 | # search_results = {} |
75 | 67 |
|
76 | | - # 组织成文本 |
77 | 68 | mo_text = "\n".join([f"{k}: {v}" for k, v in merged.items()]) |
78 | 69 | lang = detect_main_language(mo_text) |
79 | 70 | prompt = PROTEIN_KG_EXTRACTION_PROMPT[lang].format( |
80 | 71 | input_text=mo_text, |
81 | 72 | **PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"], |
82 | 73 | ) |
83 | 74 | kg_output = await llm_client.generate_answer(prompt) |
84 | | - |
85 | | - logger.debug("Image chunk extraction result: %s", kg_output) |
86 | | - |
87 | | - # parse the result |
88 | | - records = split_string_by_multi_markers( |
89 | | - kg_output, |
90 | | - [ |
91 | | - PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"], |
92 | | - PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"], |
93 | | - ], |
94 | | - ) |
95 | | - |
96 | | - print(records) |
97 | | - raise NotImplementedError |
98 | | - |
99 | | - nodes = defaultdict(list) |
100 | | - edges = defaultdict(list) |
101 | | - |
102 | | - for record in records: |
103 | | - match = re.search(r"\((.*)\)", record) |
104 | | - if not match: |
105 | | - continue |
106 | | - inner = match.group(1) |
107 | | - |
108 | | - attributes = split_string_by_multi_markers( |
109 | | - inner, [PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]] |
110 | | - ) |
111 | | - |
112 | | - entity = await handle_single_entity_extraction(attributes, "temp") |
113 | | - if entity is not None: |
114 | | - nodes[entity["entity_name"]].append(entity) |
115 | | - continue |
116 | | - |
117 | | - relation = await handle_single_relationship_extraction(attributes, "temp") |
118 | | - if relation is not None: |
119 | | - key = (relation["src_id"], relation["tgt_id"]) |
120 | | - edges[key].append(relation) |
| 75 | + print(kg_output) |
| 76 | + # TODO: parse kg_output and insert into kg_instance |
| 77 | + return kg_instance |
0 commit comments