Skip to content

Commit de4da5f

Browse files
refactor: use xml format prompt in Generators
1 parent: e49c34d · commit: de4da5f

File tree

11 files changed

+196
-125
lines changed

11 files changed

+196
-125
lines changed

graphgen/models/generator/aggregated_generator.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
from typing import Any
1+
import re
2+
from typing import Any, Optional
23

34
from graphgen.bases import BaseGenerator
45
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
@@ -56,19 +57,21 @@ def build_prompt(
5657
return prompt
5758

5859
@staticmethod
59-
def parse_rephrased_text(response: str) -> str:
60+
def parse_rephrased_text(response: str) -> Optional[str]:
6061
"""
6162
Parse the rephrased text from the response.
6263
:param response:
6364
:return: rephrased text
6465
"""
65-
if "Rephrased Text:" in response:
66-
rephrased_text = response.split("Rephrased Text:")[1].strip()
67-
elif "重述文本:" in response:
68-
rephrased_text = response.split("重述文本:")[1].strip()
66+
rephrased_match = re.search(
67+
r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
68+
)
69+
if rephrased_match:
70+
rephrased_text = rephrased_match.group(1).strip()
6971
else:
70-
rephrased_text = response.strip()
71-
return rephrased_text.strip('"')
72+
logger.warning("Failed to parse rephrased text from response: %s", response)
73+
return None
74+
return rephrased_text.strip('"').strip("'")
7275

7376
@staticmethod
7477
def _build_prompt_for_question_generation(answer: str) -> str:
@@ -85,15 +88,13 @@ def _build_prompt_for_question_generation(answer: str) -> str:
8588

8689
@staticmethod
8790
def parse_response(response: str) -> dict:
88-
if response.startswith("Question:"):
89-
question = response[len("Question:") :].strip()
90-
elif response.startswith("问题:"):
91-
question = response[len("问题:") :].strip()
91+
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
92+
if question_match:
93+
question = question_match.group(1).strip()
9294
else:
93-
question = response.strip()
94-
return {
95-
"question": question,
96-
}
95+
logger.warning("Failed to parse question from response: %s", response)
96+
return {"question": ""}
97+
return {"question": question.strip('"').strip("'")}
9798

9899
async def generate(
99100
self,
@@ -110,9 +111,13 @@ async def generate(
110111
rephrasing_prompt = self.build_prompt(batch)
111112
response = await self.llm_client.generate_answer(rephrasing_prompt)
112113
context = self.parse_rephrased_text(response)
114+
if not context:
115+
return result
113116
question_generation_prompt = self._build_prompt_for_question_generation(context)
114117
response = await self.llm_client.generate_answer(question_generation_prompt)
115118
question = self.parse_response(response)["question"]
119+
if not question:
120+
return result
116121
logger.debug("Question: %s", question)
117122
logger.debug("Answer: %s", context)
118123
qa_pairs = {

graphgen/models/generator/atomic_generator.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
from graphgen.bases import BaseGenerator
@@ -29,17 +30,18 @@ def parse_response(response: str) -> dict:
2930
:param response:
3031
:return:
3132
"""
32-
if "Question:" in response and "Answer:" in response:
33-
question = response.split("Question:")[1].split("Answer:")[0].strip()
34-
answer = response.split("Answer:")[1].strip()
35-
elif "问题:" in response and "答案:" in response:
36-
question = response.split("问题:")[1].split("答案:")[0].strip()
37-
answer = response.split("答案:")[1].strip()
33+
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
34+
answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
35+
36+
if question_match and answer_match:
37+
question = question_match.group(1).strip()
38+
answer = answer_match.group(1).strip()
3839
else:
3940
logger.warning("Failed to parse response: %s", response)
4041
return {}
41-
question = question.strip('"')
42-
answer = answer.strip('"')
42+
43+
question = question.strip('"').strip("'")
44+
answer = answer.strip('"').strip("'")
4345
logger.debug("Question: %s", question)
4446
logger.debug("Answer: %s", answer)
4547
return {

graphgen/models/generator/cot_generator.py

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
from graphgen.bases import BaseGenerator
@@ -67,22 +68,26 @@ def build_prompt_for_cot_generation(
6768

6869
@staticmethod
6970
def parse_response(response: str) -> dict:
70-
if "Question:" in response and "Reasoning-Path Design:" in response:
71-
question = (
72-
response.split("Question:")[1]
73-
.split("Reasoning-Path Design:")[0]
74-
.strip()
75-
)
76-
reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
77-
elif "问题:" in response and "推理路径设计:" in response:
78-
question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
79-
reasoning_path = response.split("推理路径设计:")[1].strip()
71+
"""
72+
Parse CoT template from response.
73+
:param response:
74+
:return: dict with question and reasoning_path
75+
"""
76+
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
77+
reasoning_path_match = re.search(
78+
r"<reasoning_path>(.*?)</reasoning_path>", response, re.DOTALL
79+
)
80+
81+
if question_match and reasoning_path_match:
82+
question = question_match.group(1).strip()
83+
reasoning_path = reasoning_path_match.group(1).strip()
8084
else:
81-
logger.warning("Failed to parse CoT template: %s", response)
85+
logger.warning("Failed to parse response: %s", response)
8286
return {}
8387

84-
question = question.strip('"')
85-
reasoning_path = reasoning_path.strip('"')
88+
question = question.strip('"').strip("'")
89+
reasoning_path = reasoning_path.strip('"').strip("'")
90+
8691
logger.debug("CoT Question: %s", question)
8792
logger.debug("CoT Reasoning Path: %s", reasoning_path)
8893
return {
@@ -105,6 +110,8 @@ async def generate(
105110
prompt = self.build_prompt(batch)
106111
response = await self.llm_client.generate_answer(prompt)
107112
response = self.parse_response(response)
113+
if not response:
114+
return result
108115
question, reasoning_path = response["question"], response["reasoning_path"]
109116
prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
110117
cot_answer = await self.llm_client.generate_answer(prompt)

graphgen/models/generator/multi_hop_generator.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
from graphgen.bases import BaseGenerator
@@ -32,17 +33,18 @@ def build_prompt(
3233

3334
@staticmethod
3435
def parse_response(response: str) -> dict:
35-
if "Question:" in response and "Answer:" in response:
36-
question = response.split("Question:")[1].split("Answer:")[0].strip()
37-
answer = response.split("Answer:")[1].strip()
38-
elif "问题:" in response and "答案:" in response:
39-
question = response.split("问题:")[1].split("答案:")[0].strip()
40-
answer = response.split("答案:")[1].strip()
36+
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
37+
answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
38+
39+
if question_match and answer_match:
40+
question = question_match.group(1).strip()
41+
answer = answer_match.group(1).strip()
4142
else:
4243
logger.warning("Failed to parse response: %s", response)
4344
return {}
44-
question = question.strip('"')
45-
answer = answer.strip('"')
45+
46+
question = question.strip('"').strip("'")
47+
answer = answer.strip('"').strip("'")
4648
logger.debug("Question: %s", question)
4749
logger.debug("Answer: %s", answer)
4850
return {

graphgen/models/generator/vqa_generator.py

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
from graphgen.bases import BaseGenerator
@@ -38,25 +39,21 @@ def parse_response(response: str) -> Any:
3839
:return: QA pairs
3940
"""
4041
qa_pairs = {}
41-
qa_list = response.strip().split("\n\n")
42-
for qa in qa_list:
43-
if "Question:" in qa and "Answer:" in qa:
44-
question = qa.split("Question:")[1].split("Answer:")[0].strip()
45-
answer = qa.split("Answer:")[1].strip()
46-
elif "问题:" in qa and "答案:" in qa:
47-
question = qa.split("问题:")[1].split("答案:")[0].strip()
48-
answer = qa.split("答案:")[1].strip()
49-
else:
50-
logger.error("Failed to parse QA pair: %s", qa)
51-
continue
52-
question = question.strip('"')
53-
answer = answer.strip('"')
54-
logger.debug("Question: %s", question)
55-
logger.debug("Answer: %s", answer)
56-
qa_pairs[compute_content_hash(question)] = {
57-
"question": question,
58-
"answer": answer,
59-
}
42+
pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
43+
matches = re.findall(pattern, response, re.DOTALL)
44+
45+
if matches:
46+
for question, answer in matches:
47+
question = question.strip().strip('"').strip("'")
48+
answer = answer.strip().strip('"').strip("'")
49+
logger.debug("Question: %s", question)
50+
logger.debug("Answer: %s", answer)
51+
qa_pairs[compute_content_hash(question)] = {
52+
"question": question,
53+
"answer": answer,
54+
}
55+
return qa_pairs
56+
logger.warning("Error parsing the response %s", response)
6057
return qa_pairs
6158

6259
async def generate(

graphgen/operators/generate/generate_service.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,9 @@ def generate(self, items: list[dict]) -> list[dict]:
6161
unit="batch",
6262
)
6363

64+
# Filter out empty results
65+
results = [res for res in results if res]
66+
6467
results = self.generator.format_generation_results(
6568
results, output_data_format=self.data_format
6669
)

graphgen/templates/generation/aggregated_generation.py

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,8 @@
132132
- Logical consistency throughout
133133
- Clear cause-and-effect relationships
134134
135+
**Attention: Please directly provide the rephrased text without any additional content or analysis.**
136+
135137
################
136138
-ENTITIES-
137139
################
@@ -175,6 +177,8 @@
175177
- 整体逻辑一致性
176178
- 清晰的因果关系
177179
180+
**注意: 请你直接给出重述文本,不要输出任何额外的内容,也不要进行任何分析。**
181+
178182
################
179183
-实体-
180184
################
@@ -191,32 +195,52 @@
191195
################
192196
请在下方直接输出连贯的重述文本,不要输出任何额外的内容。
193197
198+
输出格式:
199+
<rephrased_text>rephrased_text_here</rephrased_text>
200+
194201
重述文本:
195202
"""
196203

197204
REQUIREMENT_EN = """
198205
################
199206
Please directly output the coherent rephrased text below, without any additional content.
200207
208+
Output format:
209+
<rephrased_text>rephrased_text_here</rephrased_text>
210+
201211
Rephrased Text:
202212
"""
203213

204214
QUESTION_GENERATION_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer.
205215
206-
################
207-
Answer:
208-
{answer}
209-
################
216+
The answer for which a question needs to be generated is as follows:
217+
<answer>{answer}</answer>
218+
219+
Please note the following requirements:
220+
1. Only output one question text without any additional explanations or analysis.
221+
2. Do not repeat the content of the answer or any fragments of it.
222+
3. The question must be independently understandable and fully match the answer.
223+
224+
Output format:
225+
<question>question_text</question>
226+
210227
Question:
211228
"""
212229

213230
QUESTION_GENERATION_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。
214231
215-
################
216-
答案:
217-
{answer}
218-
################
219-
问题:
232+
需要生成问题的答案如下:
233+
<answer>{answer}</answer>
234+
235+
请注意下列要求:
236+
1. 仅输出一个问题文本,不得包含任何额外说明或分析
237+
2. 不得重复答案内容或其中任何片段
238+
3. 问题必须可独立理解且与答案完全匹配
239+
240+
输出格式:
241+
<question>question_text</question>
242+
243+
问题:
220244
"""
221245

222246
AGGREGATED_GENERATION_PROMPT = {

graphgen/templates/generation/atomic_generation.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,44 @@
11
# pylint: disable=C0301
22
TEMPLATE_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text.
3-
The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text.
4-
For example:
5-
Question: What is the effect of overexpressing the BG1 gene on grain size and development?
6-
Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development.
73
8-
Question: What role does TAC4 play in the gravitropism of rice shoots?
9-
Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector.
4+
Please note the following requirements:
5+
1. Output only one QA pair without any additional explanations or analysis.
6+
2. Do not repeat the content of the answer or any part of it.
7+
3. The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text.
8+
9+
Output format:
10+
<question>question_text</question>
11+
<answer>answer_text</answer>
12+
13+
For example:
14+
<question>What is the effect of overexpressing the BG1 gene on grain size and development?</question>
15+
<answer>Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development.</answer>
1016
1117
Here is the text passage you need to generate a QA pair for:
1218
{context}
19+
20+
Output:
1321
"""
1422

1523
TEMPLATE_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。
16-
答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。
17-
例如:
18-
问题:过表达BG1基因对谷粒大小和发育有什么影响?
19-
答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。
2024
21-
问题:TAC4在水稻茎的重力性状中扮演什么角色?
22-
答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。
25+
请注意下列要求:
26+
1. 仅输出一个问答(QA)对,不得包含任何额外说明或分析
27+
2. 不得重复答案内容或其中任何片段
28+
3. 答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。
29+
30+
输出格式如下:
31+
<question>question_text</question>
32+
<answer>answer_text</answer>
33+
34+
例如:
35+
<question>过表达BG1基因对谷粒大小和发育有什么影响?</question>
36+
<answer>BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。</answer>
2337
2438
以下是你需要为其生成QA对的文本段落:
2539
{context}
40+
41+
输出:
2642
"""
2743

2844

0 commit comments

Comments
 (0)