Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
88f4a6b
Increase CPU verify_timeout default from 600s to 1200s.
Xreki May 15, 2026
fb11dd8
Add llm_timeout parameter to GraphNetAgent with 600s default.
Xreki May 15, 2026
e8c26a1
Increase LLM timeout and skip forward verify on CPU timeout.
Xreki May 15, 2026
76cb7dd
Track verify-timeout success and expose in progress/summary logs.
Xreki May 15, 2026
1042f00
Merge branch 'develop' into opt_extract_agent
Xreki May 15, 2026
69d3826
Improve prompt.
Xreki May 18, 2026
42b27d5
feat(agent): add orphan worker cleanup to prevent GPU leak on SIGKILL
Xreki May 18, 2026
612392a
feat(agent): add error classification and smart LLM retry
Xreki May 18, 2026
1f159c9
refactor(agent): move error category enum into exceptions.py
Xreki May 18, 2026
26bed31
refactor(agent): replace markdown_report with plain list output
Xreki May 18, 2026
9ba299f
fix(agent): change llm_timeout default back to 600
Xreki May 18, 2026
b46b9b4
feat(agent): expose error_category to parallel_extract results
Xreki May 18, 2026
ca59ba9
docs(agent): translate Chinese comments to English
Xreki May 19, 2026
99d1089
fix(agent): skip non-fixable LLM retry errors
Xreki May 19, 2026
2722a94
fix(agent): restore CPU verify timeout default
Xreki May 19, 2026
d93dfdb
feat(agent): precheck HuggingFace model accessibility
Xreki May 19, 2026
95e4163
refactor(agent): constrain LLM fixer output
Xreki May 19, 2026
6744a79
fix(agent): restore LLM timeout default
Xreki May 19, 2026
307ca47
fix(agent): support current HuggingFace model_info API
Xreki May 19, 2026
25a0746
refactor(agent): generate minimal extraction scripts
Xreki May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graph_net/agent/code_generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ def generate(
Path to generated script file

Raises:
CodeGenError: If code generation fails
CodeGenerationError: If code generation fails
"""
pass
77 changes: 58 additions & 19 deletions graph_net/agent/code_generator/llm_code_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from pathlib import Path
from typing import Optional

from graph_net.agent.utils.exceptions import CodeGenError
from graph_net.agent.utils.exceptions import (
CodeGenerationError,
GraphExtractionErrorCategory,
)

# Candidate binary names / paths to search for ducc CLI
_DUCC_CANDIDATES = [
Expand All @@ -21,7 +24,7 @@

_SYSTEM_PROMPT = """\
你是 PyTorch / HuggingFace 模型计算图抽取专家。
任务:修复一段失败的图抽取脚本,输出完整、可直接运行的 Python 脚本。
任务:修复一段失败的图抽取脚本,输出完整、可直接运行但最小化的 Python 脚本。

## 【硬性约束 - 违反即输出无效】
1. 抽取调用格式固定为:
Expand All @@ -33,6 +36,8 @@
3. 设备选择固定写法:device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4. 只允许使用 torch、transformers、graph_net 及 Python 标准库(os/pathlib/json 等)
5. 只输出代码块,格式:```python\\n...代码...\\n```,禁止输出任何说明文字
6. 必须输出完整但最小化的脚本,只保留:必要 import、模型/config 加载、输入 tensor 构造、graph_net.torch.extract(...)、一次 forward 调用
7. 禁止添加注释、helper 函数、错误处理、try/except、fallback 逻辑、重试逻辑、文件系统遍历、额外校验或无关打印。只修复导致报错的输入构造或调用方式,保持行数尽可能少

## 【输入构造规范 - 按 model_type 选择对应方案】

Expand Down Expand Up @@ -141,11 +146,11 @@ def __init__(
):
"""
Args:
timeout: Max seconds to wait for ducc response.
timeout: Max seconds to wait for ducc response (default 360s).
model: Override the LLM model (e.g. 'sonnet', 'haiku').
If None, uses whatever ducc default is configured.
"""
self.timeout = timeout
self.timeout = timeout if timeout is not None else 360
self.model = model
self.logger = logging.getLogger(self.__class__.__name__)
self._ducc_bin = _find_ducc()
Expand Down Expand Up @@ -176,7 +181,7 @@ def fix(

Args:
script_path: Path to the (failed) script to fix
error_msg: Captured stderr / ExtractionError message
error_msg: Captured stderr / GraphExtractionError message
model_dir: Local model directory (contains config.json)
model_id: HuggingFace model ID (e.g. 'prajjwal1/bert-tiny')
output_dir: Directory where the fixed script should be written
Expand All @@ -186,10 +191,10 @@ def fix(
Path to the fixed script (run_model_llm_1.py / run_model_llm_2.py)

Raises:
CodeGenError: If LLM call fails or returns no valid code
CodeGenerationError: If LLM call fails or returns no valid code
"""
if not self.available:
raise CodeGenError(
raise CodeGenerationError(
"ducc/claude binary not available; cannot perform LLM fix."
)

Expand All @@ -213,7 +218,7 @@ def fix(

code = _extract_code_block(llm_output)
if not code:
raise CodeGenError(
raise CodeGenerationError(
f"LLM response contained no Python code block.\n"
f"Response (first 500 chars):\n{llm_output[:500]}"
)
Expand All @@ -228,6 +233,26 @@ def fix(
# Internal helpers
# ------------------------------------------------------------------

@staticmethod
def _compact_script(script: str) -> str:
"""Remove blank lines and pure comment lines to shrink prompt size."""
lines = script.splitlines()
compacted = []
for line in lines:
stripped = line.strip()
if stripped == "" or stripped.startswith("#"):
continue
compacted.append(line.rstrip())
return "\n".join(compacted)

@staticmethod
def _truncate_error(error_msg: str, max_chars: int = 1200) -> str:
if len(error_msg) <= max_chars:
return error_msg
# Keep tail (usually contains the actual error) + head for context
half = max_chars // 2
return error_msg[:half] + "\n... (truncated) ...\n" + error_msg[-half:]

def _build_prompt(
self,
original_script: str,
Expand All @@ -240,19 +265,29 @@ def _build_prompt(
model_dir_str = str(model_dir).replace("\\", "/")
system = _SYSTEM_PROMPT.format(name=safe_name)
key_fields = self._extract_key_fields(model_dir)

# Compact script to reduce prompt bloat (keep structure, drop empty/comment lines)
compact_script = self._compact_script(original_script)
# If still very long, fall back to raw script so we don't lose critical logic
if len(compact_script) < len(original_script) * 0.3:
compact_script = original_script

truncated_error = self._truncate_error(error_msg)

return (
f"{system}\n\n"
f"---\n\n"
f"## 当前任务\n\n"
f"### 模型信息\n"
f"- model_id: `{model_id}`\n"
f"- config_dir: `{model_dir_str}`\n"
f"- 关键配置字段(优先以此为准):\n```json\n{key_fields}\n```\n\n"
f"### config.json(完整参考)\n```json\n{config_json}\n```\n\n"
f"### 失败脚本\n```python\n{original_script}\n```\n\n"
f"### 错误信息\n```\n{error_msg}\n```\n\n"
f"- 关键配置字段:\n```json\n{key_fields}\n```\n\n"
f"### 失败脚本\n```python\n{compact_script}\n```\n\n"
f"### 错误信息\n```\n{truncated_error}\n```\n\n"
f"### 输出要求\n"
f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明:"
f"直接输出修复后的完整最小脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明。"
f"只保留必要 import、模型/config 加载、输入 tensor 构造、extract 调用和一次 forward。"
f"禁止注释、helper、try/except、fallback、重试、文件遍历、额外校验或无关打印。"
)

def _call_ducc(self, prompt: str) -> str:
Expand Down Expand Up @@ -281,17 +316,21 @@ def _call_ducc(self, prompt: str) -> str:
timeout=self.timeout,
)
except subprocess.TimeoutExpired:
raise CodeGenError(f"ducc -p timed out after {self.timeout}s")
raise CodeGenerationError(
f"ducc -p timed out after {self.timeout}s",
error_category=GraphExtractionErrorCategory.LLM_TIMEOUT,
)

if result.returncode != 0:
raise CodeGenError(
raise CodeGenerationError(
f"ducc -p exited with code {result.returncode}.\n"
f"stderr: {result.stderr[:500]}"
f"stderr: {result.stderr[:500]}",
error_category=GraphExtractionErrorCategory.LLM_EXIT_ERROR,
)

output = result.stdout.strip()
if not output:
raise CodeGenError("ducc -p returned empty output.")
raise CodeGenerationError("ducc -p returned empty output.")

return output

Expand All @@ -312,7 +351,7 @@ def _read_config(model_dir: Path) -> str:

@staticmethod
def _extract_key_fields(model_dir: Path) -> str:
"""config.json 提取对输入构造最关键的字段,方便 LLM 直接读取。"""
"""Extract the most important input-construction fields from config.json for the LLM."""
config_path = model_dir / "config.json"
if not config_path.exists():
return "{}"
Expand Down Expand Up @@ -358,7 +397,7 @@ def _extract_key_fields(model_dir: Path) -> str:
"sample_rate",
]
result = {k: cfg[k] for k in keys if k in cfg}
# 对嵌套 config 只取关键字段
# Keep only key fields from nested configs.
for nested in ("audio_config", "vision_config", "text_config"):
if isinstance(result.get(nested), dict):
result[nested] = {
Expand Down
90 changes: 24 additions & 66 deletions graph_net/agent/code_generator/template_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata
from graph_net.agent.code_generator.base import BaseCodeGenerator
from graph_net.agent.utils.exceptions import CodeGenError
from graph_net.agent.utils.exceptions import CodeGenerationError

# Constants for safe vocab size calculation
DEFAULT_VOCAB_SIZE = 30522
Expand Down Expand Up @@ -57,7 +57,7 @@ def generate(

return script_path
except Exception as e:
raise CodeGenError(f"Failed to generate code: {e}") from e
raise CodeGenerationError(f"Failed to generate code: {e}") from e

@staticmethod
def _model_short_name(model_id: str) -> str:
Expand All @@ -79,34 +79,15 @@ def _generate_standard_code(
short_name = self._model_short_name(model_metadata.model_id)

code = f"""import torch
try:
from transformers import AutoModel
except ImportError:
raise ImportError("transformers is required. Install with: pip install transformers")

import graph_net

def main():
# Load model
{self._indent(load_code, 4)}

# Prepare inputs
{self._indent(input_code, 4)}

# Extract graph
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# Move inputs to same device as model
inputs = {{k: v.to(device) for k, v in inputs.items()}}
{load_code}

wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
{input_code}

with torch.no_grad():
wrapped(**inputs)

if __name__ == "__main__":
main()
graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(**inputs)
"""
return code

Expand All @@ -118,37 +99,19 @@ def _generate_diffusion_code(
input_code = self._generate_input_code(model_metadata)
short_name = self._model_short_name(model_metadata.model_id)

# Diffusion model forward takes positional args, not **inputs dict
code = f"""import torch
try:
from diffusers import UNet2DConditionModel
except ImportError:
raise ImportError("diffusers is required. Install with: pip install diffusers")

import graph_net

def main():
# Load model
{self._indent(load_code, 4)}

# Prepare inputs
{self._indent(input_code, 4)}
{load_code}

# Extract graph
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
{input_code}

sample = inputs["sample"].to(device)
timestep = inputs["timestep"].to(device)
encoder_hidden_states = inputs["encoder_hidden_states"].to(device)

wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()

with torch.no_grad():
wrapped(sample, timestep, encoder_hidden_states)

if __name__ == "__main__":
main()
sample = inputs["sample"]
timestep = inputs["timestep"]
encoder_hidden_states = inputs["encoder_hidden_states"]
graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(sample, timestep, encoder_hidden_states)
"""
return code

Expand Down Expand Up @@ -199,7 +162,7 @@ def _generate_model_loader(

def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
"""Generate input tensor construction code based on model metadata"""
lines = ["inputs = {}"]
lines = ["inputs = {"]

for name, shape in model_metadata.input_shapes.items():
dtype = model_metadata.input_dtypes.get(name, "int64")
Expand All @@ -209,18 +172,18 @@ def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
if dtype == "int64":
if "input_ids" in name.lower() or "decoder_input_ids" in name.lower():
safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
lines.append(
f'inputs["{name}"] = torch.randint(0, {safe_vocab_size}, {shape_tuple}, dtype={torch_dtype})'
value = (
f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, "
f"dtype={torch_dtype}).to(device)"
)
else:
lines.append(
f'inputs["{name}"] = torch.ones({shape_tuple}, dtype={torch_dtype})'
)
value = f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)"
else:
lines.append(
f'inputs["{name}"] = torch.randn({shape_tuple}, dtype={torch_dtype})'
)
value = f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)"

lines.append(f' "{name}": {value},')

lines.append("}")
return "\n".join(lines)

def _get_torch_dtype(self, dtype: str) -> str:
Expand Down Expand Up @@ -261,8 +224,3 @@ def _is_large_vocab_model_type(self, model_type: str) -> bool:
or "xlm_roberta" in model_type
or "roberta" in model_type
)

def _indent(self, text: str, spaces: int) -> str:
"""Indent text by specified spaces"""
indent = " " * spaces
return "\n".join(indent + line for line in text.split("\n"))
2 changes: 1 addition & 1 deletion graph_net/agent/graph_extractor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ def extract(self, code_path: Path, model_id: str) -> Path:
Path to extracted sample directory

Raises:
ExtractionError: If extraction fails
GraphExtractionError: If extraction fails
"""
pass
Loading
Loading