From e49c34d4bf39fb36bfcc49fad0c9444dd7cc6075 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Dec 2025 15:32:43 +0800 Subject: [PATCH 1/4] fix: change cache_dir in read operator to working_dir --- graphgen/operators/read/read.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index fbed377e..c55f3d3d 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -50,7 +50,7 @@ def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): def read( input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = "cache", + working_dir: Optional[str] = "cache", parallelism: int = 4, recursive: bool = True, **reader_kwargs: Any, @@ -60,7 +60,7 @@ def read( :param input_path: File or directory path(s) to read from :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) - :param cache_dir: Directory to cache intermediate files (PDF processing) + :param working_dir: Directory to cache intermediate files (PDF processing) :param parallelism: Number of parallel workers :param recursive: Whether to scan directories recursively :param reader_kwargs: Additional kwargs passed to readers @@ -70,7 +70,7 @@ def read( # 1. Scan all paths to discover files logger.info("[READ] Scanning paths: %s", input_path) scanner = ParallelFileScanner( - cache_dir=cache_dir, + cache_dir=working_dir, allowed_suffix=allowed_suffix, rescan=False, max_workers=parallelism if parallelism > 0 else 1, @@ -100,7 +100,7 @@ def read( # 3. Create read tasks read_tasks = [] for suffix, file_paths in files_by_suffix.items(): - reader = _build_reader(suffix, cache_dir, **reader_kwargs) + reader = _build_reader(suffix, working_dir, **reader_kwargs) ds = reader.read(file_paths) read_tasks.append(ds) From de4da5f44caa2aa4726fff2e12f684b34ae3fa40 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Dec 2025 20:39:15 +0800 Subject: [PATCH 2/4] refactor: use xml format prompt in Generators --- .../models/generator/aggregated_generator.py | 37 ++++++------ graphgen/models/generator/atomic_generator.py | 18 +++--- graphgen/models/generator/cot_generator.py | 33 ++++++----- .../models/generator/multi_hop_generator.py | 18 +++--- graphgen/models/generator/vqa_generator.py | 35 ++++++------ .../operators/generate/generate_service.py | 3 + .../generation/aggregated_generation.py | 42 +++++++++++--- .../templates/generation/atomic_generation.py | 40 +++++++++---- .../templates/generation/cot_generation.py | 12 ++-- .../generation/multi_hop_generation.py | 56 +++++++++++-------- .../templates/generation/vqa_generation.py | 27 +++++---- 11 files changed, 196 insertions(+), 125 deletions(-) diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py index 4bad8e99..3ed2078b 100644 --- a/graphgen/models/generator/aggregated_generator.py +++ b/graphgen/models/generator/aggregated_generator.py @@ -1,4 +1,5 @@ -from typing import Any +import re +from typing import Any, Optional from graphgen.bases import BaseGenerator from graphgen.templates import AGGREGATED_GENERATION_PROMPT @@ -56,19 +57,21 @@ def build_prompt( return prompt @staticmethod - def parse_rephrased_text(response: str) -> str: + def parse_rephrased_text(response: str) -> Optional[str]: """ Parse the rephrased text from the response. :param response: :return: rephrased text """ - if "Rephrased Text:" in response: - rephrased_text = response.split("Rephrased Text:")[1].strip() - elif "重述文本:" in response: - rephrased_text = response.split("重述文本:")[1].strip() + rephrased_match = re.search( + r"(.*?)", response, re.DOTALL + ) + if rephrased_match: + rephrased_text = rephrased_match.group(1).strip() else: - rephrased_text = response.strip() - return rephrased_text.strip('"') + logger.warning("Failed to parse rephrased text from response: %s", response) + return None + return rephrased_text.strip('"').strip("'") @staticmethod def _build_prompt_for_question_generation(answer: str) -> str: @@ -85,15 +88,13 @@ def _build_prompt_for_question_generation(answer: str) -> str: @staticmethod def parse_response(response: str) -> dict: - if response.startswith("Question:"): - question = response[len("Question:") :].strip() - elif response.startswith("问题:"): - question = response[len("问题:") :].strip() + question_match = re.search(r"(.*?)", response, re.DOTALL) + if question_match: + question = question_match.group(1).strip() else: - question = response.strip() - return { - "question": question, - } + logger.warning("Failed to parse question from response: %s", response) + return {"question": ""} + return {"question": question.strip('"').strip("'")} async def generate( self, @@ -110,9 +111,13 @@ async def generate( rephrasing_prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(rephrasing_prompt) context = self.parse_rephrased_text(response) + if not context: + return result question_generation_prompt = self._build_prompt_for_question_generation(context) response = await self.llm_client.generate_answer(question_generation_prompt) question = self.parse_response(response)["question"] + if not question: + return result logger.debug("Question: %s", question) logger.debug("Answer: %s", context) qa_pairs = { diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py index 713140d2..152e6389 100644 --- a/graphgen/models/generator/atomic_generator.py +++ b/graphgen/models/generator/atomic_generator.py @@ -1,3 +1,4 @@ +import re from typing import Any from graphgen.bases import BaseGenerator @@ -29,17 +30,18 @@ def parse_response(response: str) -> dict: :param response: :return: """ - if "Question:" in response and "Answer:" in response: - question = response.split("Question:")[1].split("Answer:")[0].strip() - answer = response.split("Answer:")[1].strip() - elif "问题:" in response and "答案:" in response: - question = response.split("问题:")[1].split("答案:")[0].strip() - answer = response.split("答案:")[1].strip() + question_match = re.search(r"(.*?)", response, re.DOTALL) + answer_match = re.search(r"(.*?)", response, re.DOTALL) + + if question_match and answer_match: + question = question_match.group(1).strip() + answer = answer_match.group(1).strip() else: logger.warning("Failed to parse response: %s", response) return {} - question = question.strip('"') - answer = answer.strip('"') + + question = question.strip('"').strip("'") + answer = answer.strip('"').strip("'") logger.debug("Question: %s", question) logger.debug("Answer: %s", answer) return { diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py index a111a6f6..3a893784 100644 --- a/graphgen/models/generator/cot_generator.py +++ b/graphgen/models/generator/cot_generator.py @@ -1,3 +1,4 @@ +import re from typing import Any from graphgen.bases import BaseGenerator @@ -67,22 +68,26 @@ def build_prompt_for_cot_generation( @staticmethod def parse_response(response: str) -> dict: - if "Question:" in response and "Reasoning-Path Design:" in response: - question = ( - response.split("Question:")[1] - .split("Reasoning-Path Design:")[0] - .strip() - ) - reasoning_path = response.split("Reasoning-Path Design:")[1].strip() - elif "问题:" in response and "推理路径设计:" in response: - question = response.split("问题:")[1].split("推理路径设计:")[0].strip() - reasoning_path = response.split("推理路径设计:")[1].strip() + """ + Parse CoT template from response. + :param response: + :return: dict with question and reasoning_path + """ + question_match = re.search(r"(.*?)", response, re.DOTALL) + reasoning_path_match = re.search( + r"(.*?)", response, re.DOTALL + ) + + if question_match and reasoning_path_match: + question = question_match.group(1).strip() + reasoning_path = reasoning_path_match.group(1).strip() else: - logger.warning("Failed to parse CoT template: %s", response) + logger.warning("Failed to parse response: %s", response) return {} - question = question.strip('"') - reasoning_path = reasoning_path.strip('"') + question = question.strip('"').strip("'") + reasoning_path = reasoning_path.strip('"').strip("'") + logger.debug("CoT Question: %s", question) logger.debug("CoT Reasoning Path: %s", reasoning_path) return { @@ -105,6 +110,8 @@ async def generate( prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(prompt) response = self.parse_response(response) + if not response: + return result question, reasoning_path = response["question"], response["reasoning_path"] prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path) cot_answer = await self.llm_client.generate_answer(prompt) diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py index 9098b102..896592e8 100644 --- a/graphgen/models/generator/multi_hop_generator.py +++ b/graphgen/models/generator/multi_hop_generator.py @@ -1,3 +1,4 @@ +import re from typing import Any from graphgen.bases import BaseGenerator @@ -32,17 +33,18 @@ def build_prompt( @staticmethod def parse_response(response: str) -> dict: - if "Question:" in response and "Answer:" in response: - question = response.split("Question:")[1].split("Answer:")[0].strip() - answer = response.split("Answer:")[1].strip() - elif "问题:" in response and "答案:" in response: - question = response.split("问题:")[1].split("答案:")[0].strip() - answer = response.split("答案:")[1].strip() + question_match = re.search(r"(.*?)", response, re.DOTALL) + answer_match = re.search(r"(.*?)", response, re.DOTALL) + + if question_match and answer_match: + question = question_match.group(1).strip() + answer = answer_match.group(1).strip() else: logger.warning("Failed to parse response: %s", response) return {} - question = question.strip('"') - answer = answer.strip('"') + + question = question.strip('"').strip("'") + answer = answer.strip('"').strip("'") logger.debug("Question: %s", question) logger.debug("Answer: %s", answer) return { diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index 91b44862..0c432f52 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -1,3 +1,4 @@ +import re from typing import Any from graphgen.bases import BaseGenerator @@ -38,25 +39,21 @@ def parse_response(response: str) -> Any: :return: QA pairs """ qa_pairs = {} - qa_list = response.strip().split("\n\n") - for qa in qa_list: - if "Question:" in qa and "Answer:" in qa: - question = qa.split("Question:")[1].split("Answer:")[0].strip() - answer = qa.split("Answer:")[1].strip() - elif "问题:" in qa and "答案:" in qa: - question = qa.split("问题:")[1].split("答案:")[0].strip() - answer = qa.split("答案:")[1].strip() - else: - logger.error("Failed to parse QA pair: %s", qa) - continue - question = question.strip('"') - answer = answer.strip('"') - logger.debug("Question: %s", question) - logger.debug("Answer: %s", answer) - qa_pairs[compute_content_hash(question)] = { - "question": question, - "answer": answer, - } + pattern = r"(.*?)\s*(.*?)" + matches = re.findall(pattern, response, re.DOTALL) + + if matches: + for question, answer in matches: + question = question.strip().strip('"').strip("'") + answer = answer.strip().strip('"').strip("'") + logger.debug("Question: %s", question) + logger.debug("Answer: %s", answer) + qa_pairs[compute_content_hash(question)] = { + "question": question, + "answer": answer, + } + return qa_pairs + logger.warning("Error parsing the response %s", response) return qa_pairs async def generate( diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py index 1ae2f067..db784d08 100644 --- a/graphgen/operators/generate/generate_service.py +++ b/graphgen/operators/generate/generate_service.py @@ -61,6 +61,9 @@ def generate(self, items: list[dict]) -> list[dict]: unit="batch", ) + # Filter out empty results + results = [res for res in results if res] + results = self.generator.format_generation_results( results, output_data_format=self.data_format ) diff --git a/graphgen/templates/generation/aggregated_generation.py b/graphgen/templates/generation/aggregated_generation.py index 305064e7..a2542fac 100644 --- a/graphgen/templates/generation/aggregated_generation.py +++ b/graphgen/templates/generation/aggregated_generation.py @@ -132,6 +132,8 @@ - Logical consistency throughout - Clear cause-and-effect relationships +**Attention: Please directly provide the rephrased text without any additional content or analysis.** + ################ -ENTITIES- ################ @@ -175,6 +177,8 @@ - 整体逻辑一致性 - 清晰的因果关系 +**注意: 请你直接给出重述文本,不要输出任何额外的内容,也不要进行任何分析。** + ################ -实体- ################ @@ -191,6 +195,9 @@ ################ 请在下方直接输出连贯的重述文本,不要输出任何额外的内容。 +输出格式: +rephrased_text_here + 重述文本: """ @@ -198,25 +205,42 @@ ################ Please directly output the coherent rephrased text below, without any additional content. +Output format: +rephrased_text_here + Rephrased Text: """ QUESTION_GENERATION_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer. -################ -Answer: -{answer} -################ +The answer for which a question needs to be generated is as follows: +{answer} + +Please note the following requirements: +1. Only output one question text without any additional explanations or analysis. +2. Do not repeat the content of the answer or any fragments of it. +3. The question must be independently understandable and fully match the answer. + +Output format: +question_text + Question: """ QUESTION_GENERATION_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。 -################ -答案: -{answer} -################ -问题: +需要生成问题的答案如下: +{answer} + +请注意下列要求: +1. 仅输出一个问题文本,不得包含任何额外说明或分析 +2. 不得重复答案内容或其中任何片段 +3. 问题必须可独立理解且与答案完全匹配 + +输出格式: +question_text + +问题: """ AGGREGATED_GENERATION_PROMPT = { diff --git a/graphgen/templates/generation/atomic_generation.py b/graphgen/templates/generation/atomic_generation.py index 499100f7..c7dac79a 100644 --- a/graphgen/templates/generation/atomic_generation.py +++ b/graphgen/templates/generation/atomic_generation.py @@ -1,28 +1,44 @@ # pylint: disable=C0301 TEMPLATE_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. -The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. -For example: -Question: What is the effect of overexpressing the BG1 gene on grain size and development? -Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. -Question: What role does TAC4 play in the gravitropism of rice shoots? -Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. +Please note the following requirements: +1. Output only one QA pair without any additional explanations or analysis. +2. Do not repeat the content of the answer or any part of it. +3. The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. + +Output format: +question_text +answer_text + +For example: +What is the effect of overexpressing the BG1 gene on grain size and development? +Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. Here is the text passage you need to generate a QA pair for: {context} + +Output: """ TEMPLATE_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 -答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 -例如: -问题:过表达BG1基因对谷粒大小和发育有什么影响? -答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 -问题:TAC4在水稻茎的重力性状中扮演什么角色? -答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 +请注意下列要求: +1. 仅输出一个问答(QA)对,不得包含任何额外说明或分析 +2. 不得重复答案内容或其中任何片段 +3. 答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 + +输出格式如下: +question_text +answer_text + +例如: +过表达BG1基因对谷粒大小和发育有什么影响? +BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 以下是你需要为其生成QA对的文本段落: {context} + +输出: """ diff --git a/graphgen/templates/generation/cot_generation.py b/graphgen/templates/generation/cot_generation.py index 849a7c71..9ce242c6 100644 --- a/graphgen/templates/generation/cot_generation.py +++ b/graphgen/templates/generation/cot_generation.py @@ -81,7 +81,7 @@ Output: """ -COT_TEMPLATE_DESIGN_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\ +COT_TEMPLATE_DESIGN_ZH: str = """你是一位“元推理架构师”。你的任务不是回答问题,\ 而是根据给定的知识图谱中的实体和关系的名称以及描述信息,设计一条可复用、可泛化的 CoT 推理路径模板。\ -步骤- @@ -115,8 +115,8 @@ 4. 不要出现具体数值或结论,不要出现“识别实体”、“识别关系”这类无意义的操作描述。 5. 使用中文作为输出语言。 6. 输出格式为: -问题: -推理路径设计: +问题文本 +推理路径设计文本 -真实数据- 输入: @@ -130,7 +130,7 @@ """ -COT_TEMPLATE_DESIGN_EN = """You are a “meta-reasoning architect”. \ +COT_TEMPLATE_DESIGN_EN: str = """You are a “meta-reasoning architect”. \ Your task is NOT to answer the question, but to design a reusable, generalizable CoT reasoning-path \ template based solely on the names and descriptions of entities and \ relationships in the provided knowledge graph. @@ -168,8 +168,8 @@ and DO NOT describing meaningless operations like "Identify the entity" or "Identify the relationship". 5. Use English as the output language. 6. The output format is: -Question: -Reasoning-Path Design: +question text +reasoning path design text Please summarize the information expressed by the knowledge graph based on the following [Entities:] and [Relationships:] provided. diff --git a/graphgen/templates/generation/multi_hop_generation.py b/graphgen/templates/generation/multi_hop_generation.py index 73857ebb..b077e64c 100644 --- a/graphgen/templates/generation/multi_hop_generation.py +++ b/graphgen/templates/generation/multi_hop_generation.py @@ -1,56 +1,68 @@ # pylint: disable=C0301 -TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。你的任务是提出一个问题,该问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。 +TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。 +你的任务是生成一个问答对,其中问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。 + +请注意下列要求: +1. 仅输出一个问答(QA)对,不得包含任何额外说明或分析 +2. 不得重复答案内容或其中任何片段,不要直接复制示例问题和答案 +3. 答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 + +输出格式: +question_text +answer_text 例如: -######## +输入为: --实体-- 1. 苹果 2. 水果 3. 维生素C -######## --关系-- 1. 苹果-水果:苹果是一种水果 2. 水果-维生素C:水果中富含维生素C -######## -问题:通过吃苹果补充的什么物质,有助于维持健康? -答案:维生素C -######## -######### +输出: +通过吃苹果补充的什么物质,有助于维持健康? +维生素C + +真实输入如下: --实体-- {entities} -######### --关系-- {relationships} -######### -直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。 + +输出: """ -TEMPLATE_EN: str = """Please generate a multi-hop reasoning question and answer based on the following knowledge subgraph. You will be provided with a knowledge subgraph that contains a series of entities, relations, and facts. Your task is to generate a question that requires multiple steps of reasoning to answer. The answer to the question should be inferred from the given knowledge subgraph. Ensure that the question is of moderate difficulty and requires multiple steps of reasoning to answer. +TEMPLATE_EN: str = """Please generate a multi-hop reasoning question and answer based on the following knowledge subgraph. You will be provided with a knowledge subgraph that contains a series of entities, relations, and facts. +Your task is to generate a question-answer (QA) pair where the question requires multiple steps of reasoning to answer. The answer to the question should be inferred from the given knowledge subgraph. Ensure that the question is of moderate difficulty and requires multiple steps of reasoning to answer. + +Please note the following requirements: +1. Output only one QA pair without any additional explanations or analysis. +2. Do not repeat the content of the answer or any part of it. Do not directly copy the example question and answer. +3. The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. For example: -######## +Input: --Entities-- 1. Apple 2. Fruit 3. Vitamin C -######## --Relations-- 1. Apple-Fruit: Apple is a type of fruit 2. Fruit-Vitamin C: Fruits are rich in Vitamin C -######## -Question: What substance, obtained through eating apples, helps maintain health? -Answer: Vitamin C -######## -######## +Output: +What substance, obtained by eating apples, helps maintain health? +Vitamin C + +Real input: --Entities-- {entities} -######## --Relations-- {relationships} -######## -Output the generated question and answer directly, please do not copy the example question and answer directly, and do not provide irrelevant information. + +Output: """ MULTI_HOP_GENERATION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH} diff --git a/graphgen/templates/generation/vqa_generation.py b/graphgen/templates/generation/vqa_generation.py index 4826be0e..b8804fc8 100644 --- a/graphgen/templates/generation/vqa_generation.py +++ b/graphgen/templates/generation/vqa_generation.py @@ -39,14 +39,16 @@ ################ {relationships} ################ -Directly output the generated questions and answers, please do not directly copy the example questions and answers, and do not provide irrelevant information. -Here is the response format you should follow: -Question: -Answer: -Question: -Answer: +Please directly output the generated questions and answers, do not directly copy the example questions and answers, and do not provide irrelevant information. + +Here is the response format you should follow: +question1 +answer1 +question2 +answer2 +Output: """ TEMPLATE_ZH: str = """---角色--- @@ -91,14 +93,15 @@ ################ {relationships} ################ -直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。 -以下是你应该遵循的响应格式: -问题: <问题1> -答案: <答案1> -问题: <问题2> -答案: <答案2> +请直接输出生成的问题和答案,不要直接复制示例问题和答案,也不要提供无关信息。 +以下是你应遵循的响应格式: +question1 +answer1 +question2 +answer2 +输出: """ VQA_GENERATION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH} From 467504e5762fa5cfc64a7cc7f83e0e63c4a708ca Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Dec 2025 20:42:30 +0800 Subject: [PATCH 3/4] feat: change temperature & max_token in vllmwrapper --- graphgen/models/llm/local/vllm_wrapper.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py index c6e5feac..fc412b51 100644 --- a/graphgen/models/llm/local/vllm_wrapper.py +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -16,7 +16,7 @@ def __init__( model: str, tensor_parallel_size: int = 1, gpu_memory_utilization: float = 0.9, - temperature: float = 0.0, + temperature: float = 0.6, top_p: float = 1.0, topk: int = 5, **kwargs: Any, @@ -66,7 +66,7 @@ async def generate_answer( sp = self.SamplingParams( temperature=self.temperature if self.temperature > 0 else 1.0, top_p=self.top_p if self.temperature > 0 else 1.0, - max_tokens=extra.get("max_new_tokens", 512), + max_tokens=extra.get("max_new_tokens", 2048), ) result_generator = self.engine.generate(full_prompt, sp, request_id=request_id) @@ -82,7 +82,7 @@ async def generate_answer( async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> List[Token]: + ) -> List[Token]: full_prompt = self._build_inputs(text, history) request_id = f"graphgen_topk_{uuid.uuid4()}" @@ -110,7 +110,9 @@ async def generate_topk_per_token( candidate_tokens = [] for _, logprob_obj in top_logprobs.items(): - tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else "" + tok_str = ( + logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else "" + ) prob = float(math.exp(logprob_obj.logprob)) candidate_tokens.append(Token(tok_str, prob)) @@ -120,7 +122,7 @@ async def generate_topk_per_token( main_token = Token( text=candidate_tokens[0].text, prob=candidate_tokens[0].prob, - top_candidates=candidate_tokens + top_candidates=candidate_tokens, ) return [main_token] return [] From 85ae0d7df0fe836bda19211b8d033ab722359105 Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:00:21 +0800 Subject: [PATCH 4/4] Update graphgen/models/generator/vqa_generator.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- graphgen/models/generator/vqa_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index 0c432f52..5839d327 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -52,8 +52,8 @@ def parse_response(response: str) -> Any: "question": question, "answer": answer, } - return qa_pairs - logger.warning("Error parsing the response %s", response) + else: + logger.warning("Error parsing the response %s", response) return qa_pairs async def generate(