diff --git a/graph_net/agent/code_generator/base.py b/graph_net/agent/code_generator/base.py index d574a5170c..2771c95e4a 100644 --- a/graph_net/agent/code_generator/base.py +++ b/graph_net/agent/code_generator/base.py @@ -28,6 +28,6 @@ def generate( Path to generated script file Raises: - CodeGenError: If code generation fails + CodeGenerationError: If code generation fails """ pass diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 5c6242f25e..2e0fcef2c9 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -9,7 +9,10 @@ from pathlib import Path from typing import Optional -from graph_net.agent.utils.exceptions import CodeGenError +from graph_net.agent.utils.exceptions import ( + CodeGenerationError, + GraphExtractionErrorCategory, +) # Candidate binary names / paths to search for ducc CLI _DUCC_CANDIDATES = [ @@ -21,7 +24,7 @@ _SYSTEM_PROMPT = """\ 你是 PyTorch / HuggingFace 模型计算图抽取专家。 -任务:修复一段失败的图抽取脚本,输出完整、可直接运行的 Python 脚本。 +任务:修复一段失败的图抽取脚本,输出完整、可直接运行但最小化的 Python 脚本。 ## 【硬性约束 - 违反即输出无效】 1. 抽取调用格式固定为: @@ -33,6 +36,8 @@ 3. 设备选择固定写法:device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4. 只允许使用 torch、transformers、graph_net 及 Python 标准库(os/pathlib/json 等) 5. 只输出代码块,格式:```python\\n...代码...\\n```,禁止输出任何说明文字 +6. 必须输出完整但最小化的脚本,只保留:必要 import、模型/config 加载、输入 tensor 构造、graph_net.torch.extract(...)、一次 forward 调用 +7. 禁止添加注释、helper 函数、错误处理、try/except、fallback 逻辑、重试逻辑、文件系统遍历、额外校验或无关打印。只修复导致报错的输入构造或调用方式,保持行数尽可能少 ## 【输入构造规范 - 按 model_type 选择对应方案】 @@ -141,11 +146,11 @@ def __init__( ): """ Args: - timeout: Max seconds to wait for ducc response. + timeout: Max seconds to wait for ducc response (default 360s). model: Override the LLM model (e.g. 'sonnet', 'haiku'). If None, uses whatever ducc default is configured. """ - self.timeout = timeout + self.timeout = timeout if timeout is not None else 360 self.model = model self.logger = logging.getLogger(self.__class__.__name__) self._ducc_bin = _find_ducc() @@ -176,7 +181,7 @@ def fix( Args: script_path: Path to the (failed) script to fix - error_msg: Captured stderr / ExtractionError message + error_msg: Captured stderr / GraphExtractionError message model_dir: Local model directory (contains config.json) model_id: HuggingFace model ID (e.g. 'prajjwal1/bert-tiny') output_dir: Directory where the fixed script should be written @@ -186,10 +191,10 @@ def fix( Path to the fixed script (run_model_llm_1.py / run_model_llm_2.py) Raises: - CodeGenError: If LLM call fails or returns no valid code + CodeGenerationError: If LLM call fails or returns no valid code """ if not self.available: - raise CodeGenError( + raise CodeGenerationError( "ducc/claude binary not available; cannot perform LLM fix." ) @@ -213,7 +218,7 @@ def fix( code = _extract_code_block(llm_output) if not code: - raise CodeGenError( + raise CodeGenerationError( f"LLM response contained no Python code block.\n" f"Response (first 500 chars):\n{llm_output[:500]}" ) @@ -228,6 +233,26 @@ def fix( # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _compact_script(script: str) -> str: + """Remove blank lines and pure comment lines to shrink prompt size.""" + lines = script.splitlines() + compacted = [] + for line in lines: + stripped = line.strip() + if stripped == "" or stripped.startswith("#"): + continue + compacted.append(line.rstrip()) + return "\n".join(compacted) + + @staticmethod + def _truncate_error(error_msg: str, max_chars: int = 1200) -> str: + if len(error_msg) <= max_chars: + return error_msg + # Keep tail (usually contains the actual error) + head for context + half = max_chars // 2 + return error_msg[:half] + "\n... (truncated) ...\n" + error_msg[-half:] + def _build_prompt( self, original_script: str, @@ -240,6 +265,15 @@ def _build_prompt( model_dir_str = str(model_dir).replace("\\", "/") system = _SYSTEM_PROMPT.format(name=safe_name) key_fields = self._extract_key_fields(model_dir) + + # Compact script to reduce prompt bloat (keep structure, drop empty/comment lines) + compact_script = self._compact_script(original_script) + # If still very long, fall back to raw script so we don't lose critical logic + if len(compact_script) < len(original_script) * 0.3: + compact_script = original_script + + truncated_error = self._truncate_error(error_msg) + return ( f"{system}\n\n" f"---\n\n" @@ -247,12 +281,13 @@ def _build_prompt( f"### 模型信息\n" f"- model_id: `{model_id}`\n" f"- config_dir: `{model_dir_str}`\n" - f"- 关键配置字段(优先以此为准):\n```json\n{key_fields}\n```\n\n" - f"### config.json(完整参考)\n```json\n{config_json}\n```\n\n" - f"### 失败脚本\n```python\n{original_script}\n```\n\n" - f"### 错误信息\n```\n{error_msg}\n```\n\n" + f"- 关键配置字段:\n```json\n{key_fields}\n```\n\n" + f"### 失败脚本\n```python\n{compact_script}\n```\n\n" + f"### 错误信息\n```\n{truncated_error}\n```\n\n" f"### 输出要求\n" - f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明:" + f"直接输出修复后的完整最小脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明。" + f"只保留必要 import、模型/config 加载、输入 tensor 构造、extract 调用和一次 forward。" + f"禁止注释、helper、try/except、fallback、重试、文件遍历、额外校验或无关打印。" ) def _call_ducc(self, prompt: str) -> str: @@ -281,17 +316,21 @@ def _call_ducc(self, prompt: str) -> str: timeout=self.timeout, ) except subprocess.TimeoutExpired: - raise CodeGenError(f"ducc -p timed out after {self.timeout}s") + raise CodeGenerationError( + f"ducc -p timed out after {self.timeout}s", + error_category=GraphExtractionErrorCategory.LLM_TIMEOUT, + ) if result.returncode != 0: - raise CodeGenError( + raise CodeGenerationError( f"ducc -p exited with code {result.returncode}.\n" - f"stderr: {result.stderr[:500]}" + f"stderr: {result.stderr[:500]}", + error_category=GraphExtractionErrorCategory.LLM_EXIT_ERROR, ) output = result.stdout.strip() if not output: - raise CodeGenError("ducc -p returned empty output.") + raise CodeGenerationError("ducc -p returned empty output.") return output @@ -312,7 +351,7 @@ def _read_config(model_dir: Path) -> str: @staticmethod def _extract_key_fields(model_dir: Path) -> str: - """从 config.json 提取对输入构造最关键的字段,方便 LLM 直接读取。""" + """Extract the most important input-construction fields from config.json for the LLM.""" config_path = model_dir / "config.json" if not config_path.exists(): return "{}" @@ -358,7 +397,7 @@ def _extract_key_fields(model_dir: Path) -> str: "sample_rate", ] result = {k: cfg[k] for k in keys if k in cfg} - # 对嵌套 config 只取关键字段 + # Keep only key fields from nested configs. for nested in ("audio_config", "vision_config", "text_config"): if isinstance(result.get(nested), dict): result[nested] = { diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index a3332f695b..e2e051b318 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -5,7 +5,7 @@ from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata from graph_net.agent.code_generator.base import BaseCodeGenerator -from graph_net.agent.utils.exceptions import CodeGenError +from graph_net.agent.utils.exceptions import CodeGenerationError # Constants for safe vocab size calculation DEFAULT_VOCAB_SIZE = 30522 @@ -57,7 +57,7 @@ def generate( return script_path except Exception as e: - raise CodeGenError(f"Failed to generate code: {e}") from e + raise CodeGenerationError(f"Failed to generate code: {e}") from e @staticmethod def _model_short_name(model_id: str) -> str: @@ -79,34 +79,15 @@ def _generate_standard_code( short_name = self._model_short_name(model_metadata.model_id) code = f"""import torch -try: - from transformers import AutoModel -except ImportError: - raise ImportError("transformers is required. Install with: pip install transformers") - import graph_net -def main(): - # Load model -{self._indent(load_code, 4)} - - # Prepare inputs -{self._indent(input_code, 4)} - - # Extract graph - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device).eval() - - # Move inputs to same device as model - inputs = {{k: v.to(device) for k, v in inputs.items()}} +{load_code} - wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device).eval() +{input_code} - with torch.no_grad(): - wrapped(**inputs) - -if __name__ == "__main__": - main() +graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(**inputs) """ return code @@ -118,37 +99,19 @@ def _generate_diffusion_code( input_code = self._generate_input_code(model_metadata) short_name = self._model_short_name(model_metadata.model_id) - # Diffusion model forward takes positional args, not **inputs dict code = f"""import torch -try: - from diffusers import UNet2DConditionModel -except ImportError: - raise ImportError("diffusers is required. Install with: pip install diffusers") - import graph_net -def main(): - # Load model -{self._indent(load_code, 4)} - - # Prepare inputs -{self._indent(input_code, 4)} +{load_code} - # Extract graph - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device).eval() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device).eval() +{input_code} - sample = inputs["sample"].to(device) - timestep = inputs["timestep"].to(device) - encoder_hidden_states = inputs["encoder_hidden_states"].to(device) - - wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() - - with torch.no_grad(): - wrapped(sample, timestep, encoder_hidden_states) - -if __name__ == "__main__": - main() +sample = inputs["sample"] +timestep = inputs["timestep"] +encoder_hidden_states = inputs["encoder_hidden_states"] +graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(sample, timestep, encoder_hidden_states) """ return code @@ -199,7 +162,7 @@ def _generate_model_loader( def _generate_input_code(self, model_metadata: ModelMetadata) -> str: """Generate input tensor construction code based on model metadata""" - lines = ["inputs = {}"] + lines = ["inputs = {"] for name, shape in model_metadata.input_shapes.items(): dtype = model_metadata.input_dtypes.get(name, "int64") @@ -209,18 +172,18 @@ def _generate_input_code(self, model_metadata: ModelMetadata) -> str: if dtype == "int64": if "input_ids" in name.lower() or "decoder_input_ids" in name.lower(): safe_vocab_size = self._calculate_safe_vocab_size(model_metadata) - lines.append( - f'inputs["{name}"] = torch.randint(0, {safe_vocab_size}, {shape_tuple}, dtype={torch_dtype})' + value = ( + f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, " + f"dtype={torch_dtype}).to(device)" ) else: - lines.append( - f'inputs["{name}"] = torch.ones({shape_tuple}, dtype={torch_dtype})' - ) + value = f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)" else: - lines.append( - f'inputs["{name}"] = torch.randn({shape_tuple}, dtype={torch_dtype})' - ) + value = f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)" + + lines.append(f' "{name}": {value},') + lines.append("}") return "\n".join(lines) def _get_torch_dtype(self, dtype: str) -> str: @@ -261,8 +224,3 @@ def _is_large_vocab_model_type(self, model_type: str) -> bool: or "xlm_roberta" in model_type or "roberta" in model_type ) - - def _indent(self, text: str, spaces: int) -> str: - """Indent text by specified spaces""" - indent = " " * spaces - return "\n".join(indent + line for line in text.split("\n")) diff --git a/graph_net/agent/graph_extractor/base.py b/graph_net/agent/graph_extractor/base.py index 362451cc7d..798112bcbf 100644 --- a/graph_net/agent/graph_extractor/base.py +++ b/graph_net/agent/graph_extractor/base.py @@ -20,6 +20,6 @@ def extract(self, code_path: Path, model_id: str) -> Path: Path to extracted sample directory Raises: - ExtractionError: If extraction fails + GraphExtractionError: If extraction fails """ pass diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index 3b001693bb..b579d7140b 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -5,12 +5,16 @@ import signal import subprocess import sys +import threading import time from pathlib import Path from typing import Optional from graph_net.agent.graph_extractor.base import BaseGraphExtractor -from graph_net.agent.utils.exceptions import ExtractionError +from graph_net.agent.utils.exceptions import ( + GraphExtractionError, + GraphExtractionErrorCategory, +) # Constants DEFAULT_TIMEOUT = 1000 # ~17 minutes for large models @@ -18,6 +22,54 @@ HASH_DIR_LENGTH = 40 # SHA1 hash length ERROR_MSG_MAX_LINES = 20 # Keep first and last N lines of error messages +# --------------------------------------------------------------------------- +# Active child process group tracking (for orphan worker cleanup) +# --------------------------------------------------------------------------- + + +class ProcessGroupTracker: + """Track and manage active child process groups for clean orphan worker teardown. + + Uses class-level storage so any code path (extractor, orphan watcher, etc.) + can register/unregister/kill without passing instances around. + """ + + _pgids: set[int] = set() + _lock = threading.Lock() + + @classmethod + def register(cls, pgid: int) -> None: + with cls._lock: + cls._pgids.add(pgid) + + @classmethod + def unregister(cls, pgid: int) -> None: + with cls._lock: + cls._pgids.discard(pgid) + + @classmethod + def kill_all(cls, sig: int = signal.SIGKILL) -> None: + """Kill all tracked process groups and clear the registry.""" + with cls._lock: + pgids = list(cls._pgids) + for pgid in pgids: + try: + os.killpg(pgid, sig) + except (ProcessLookupError, PermissionError, OSError): + pass + with cls._lock: + cls._pgids.clear() + + @classmethod + def is_empty(cls) -> bool: + with cls._lock: + return len(cls._pgids) == 0 + + +def kill_all_active_children() -> None: + """Convenience alias for backward compatibility.""" + ProcessGroupTracker.kill_all() + class SubprocessGraphExtractor(BaseGraphExtractor): """Extractor that runs script in subprocess""" @@ -44,7 +96,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: Path to extracted sample directory Raises: - ExtractionError: If extraction fails + GraphExtractionError: If extraction fails """ try: # Get GraphNet root directory for PYTHONPATH @@ -72,45 +124,52 @@ def extract(self, code_path: Path, model_id: str) -> Path: stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - # 用新进程组,方便整组 kill(避免遗留孙进程占显存) + # Start a new process group so the whole group can be killed, avoiding orphaned child processes holding GPU memory. start_new_session=True, ) + pgid = os.getpgid(proc.pid) + ProcessGroupTracker.register(pgid) try: stdout, stderr = proc.communicate(timeout=self.timeout) except subprocess.TimeoutExpired: - # 先 kill 整个进程组,确保 GPU 显存释放 + # Kill the entire process group first to ensure GPU memory is released. try: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + os.killpg(pgid, signal.SIGKILL) except ProcessLookupError: proc.kill() - proc.communicate() # 回收僵尸进程 - raise ExtractionError( - f"Script execution timed out after {self.timeout} seconds" + proc.communicate() # Reap the zombie process + raise GraphExtractionError( + f"Script execution timed out after {self.timeout} seconds", + error_category=GraphExtractionErrorCategory.SCRIPT_TIMEOUT, ) + finally: + ProcessGroupTracker.unregister(pgid) if proc.returncode != 0: error_msg = self._format_error_message(stderr or stdout) - raise ExtractionError( + raise GraphExtractionError( f"Script execution failed with return code {proc.returncode}.\n" f"Command: {sys.executable} {code_path}\n" - f"Error output:\n{error_msg}" + f"Error output:\n{error_msg}", + error_category=GraphExtractionErrorCategory.SCRIPT_EXECUTION_FAILED, ) # Find output directory using multiple strategies output_dir = self._find_output_dir_robust(model_id) if not output_dir or not output_dir.exists(): - raise ExtractionError( + raise GraphExtractionError( f"Output directory not found for model: {model_id}.\n" f"Searched in workspace: {self.workspace}\n" - f"Please check if the extraction script executed successfully." + f"Please check if the extraction script executed successfully.", + error_category=GraphExtractionErrorCategory.OUTPUT_DIR_NOT_FOUND, ) return output_dir - except ExtractionError: + except GraphExtractionError: raise except Exception as e: - raise ExtractionError(f"Failed to extract graph: {e}") from e + raise GraphExtractionError(f"Failed to extract graph: {e}") from e def _format_error_message(self, error_msg: str) -> str: """Format error message, truncating if too long""" @@ -208,8 +267,8 @@ def _find_hash_named_dir(self, workspace_path: Path) -> Optional[Path]: def _is_valid_sample_dir(self, dir_path: Path) -> bool: """Check if a directory is a valid sample directory""" required_files = ["model.py", "graph_net.json"] - # 单图:根目录下有文件 + # Single graph: files exist in the root directory. if all((dir_path / f).exists() for f in required_files): return True - # 多子图:subgraph_* 子目录下有文件 + # Multiple subgraphs: files exist under subgraph_* directories. return any(dir_path.glob("subgraph_*/model.py")) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 2a2b234eb8..0cf9718079 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -13,11 +13,14 @@ from graph_net.agent.code_generator.llm_code_fixer import LLMCodeFixer from graph_net.agent.graph_extractor import SubprocessGraphExtractor from graph_net.agent.model_fetcher import HFFetcher +from graph_net.agent.utils.error_classifier import GraphExtractionErrorClassifier from graph_net.agent.utils.exceptions import ( - AnalysisError, - CodeGenError, - ExtractionError, - VerificationError, + GraphExtractionErrorCategory, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + ModelFetchError, + SampleVerificationError, ) from graph_net.agent.utils.logger import setup_logger from graph_net.agent.utils.workspace_manager import WorkspaceManager @@ -43,20 +46,22 @@ def __init__( llm_retry: bool = True, extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, + llm_timeout: int = 360, ): """ Initialize GraphNet Agent Args: - workspace: Workspace root directory. Defaults to - $GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace. - hf_token: HuggingFace API token (optional) - llm_retry: If True and ducc/claude CLI is available, retry failed - extractions up to 2 times with LLM-fixed scripts. - extract_timeout: Timeout in seconds for graph extraction subprocess - (default None -> 1000s). - verify_timeout: Timeout in seconds for forward verification subprocess - (default None -> 300s). + workspace: Workspace root directory. Defaults to + $GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace. + hf_token: HuggingFace API token (optional) + llm_retry: If True and ducc/claude CLI is available, retry failed + extractions up to 2 times with LLM-fixed scripts. + extract_timeout: Timeout in seconds for graph extraction subprocess + (default None -> 1000s). + verify_timeout: Timeout in seconds for forward verification subprocess + (default None -> 300s). + llm_timeout: Timeout in seconds for LLM script fix (default: 360). """ if workspace is None: workspace = os.environ.get( @@ -85,7 +90,15 @@ def __init__( self.sample_verifier = ForwardVerifier(timeout=verify_timeout) # LLM fixer — only created when llm_retry is requested - self.llm_fixer: Optional[LLMCodeFixer] = LLMCodeFixer() if llm_retry else None + self.llm_fixer: Optional[LLMCodeFixer] = ( + LLMCodeFixer(timeout=llm_timeout) if llm_retry else None + ) + + # Track whether the last verify succeeded only because of timeout skip + self.last_timeout_success = False + + # Error classifier for post-run reporting + self.error_classifier = GraphExtractionErrorClassifier() def extract_sample(self, model_id: str) -> ExtractionStatus: """ @@ -104,6 +117,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: ExtractionStatus.EXTRACT_FAILED – extraction (or pre-extraction) failed ExtractionStatus.ERROR – unexpected error """ + self.last_timeout_success = False try: self.logger.info(f"Starting extraction for model: {model_id}") @@ -115,7 +129,12 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: # ── First attempt (template script) ────────────────────────── try: sample_dir = self._extract_graph(script_path, model_id) - except ExtractionError as first_err: + except GraphExtractionError as first_err: + if not self._is_llm_fixable_error(first_err): + self.logger.warning( + f"Extraction error is not fixable by LLM, skipping retry: {first_err}" + ) + raise first_err sample_dir = self._llm_retry( first_err, script_path, model_dir, model_id ) @@ -129,24 +148,53 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: if not self.sample_verifier.verify(sample_dir): self.logger.error("Sample verification failed") + self.error_classifier.classify_and_record( + model_id, + Exception("Sample verification failed"), + ) return ExtractionStatus.VERIFY_FAILED + if getattr(self.sample_verifier, "last_timeout_success", False): + self.last_timeout_success = True + self.logger.info( + f"Sample verification for {model_id} passed via timeout skip" + ) + self.logger.info(f"Successfully extracted sample for {model_id}") return ExtractionStatus.OK - except VerificationError as e: + except SampleVerificationError as e: self.logger.error(f"Extraction failed for {model_id}: {e}") + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.VERIFY_FAILED - except (AnalysisError, CodeGenError, ExtractionError) as e: + except ( + ModelFetchError, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + ) as e: self.logger.error(f"Extraction failed for {model_id}: {e}") + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.EXTRACT_FAILED except Exception as e: self.logger.error(f"Unexpected error for {model_id}: {e}", exc_info=True) + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.ERROR + @staticmethod + def _is_llm_fixable_error(err: GraphExtractionError) -> bool: + """Decide whether an extraction error is worth retrying with LLM. + + Only allow LLM retry for script logic errors (non-zero return code). + All other categories (timeout, infrastructure, missing model, etc.) + are not fixable by rewriting the script. + """ + category = GraphExtractionErrorClassifier.classify_from_exception(err) + return category == GraphExtractionErrorCategory.SCRIPT_EXECUTION_FAILED + def _llm_retry( self, - first_err: ExtractionError, + first_err: GraphExtractionError, script_path: Path, model_dir: Path, model_id: str, @@ -158,7 +206,7 @@ def _llm_retry( Returns: (sample_dir, successful_script_path) - Raises ExtractionError if LLM fix is unavailable or both attempts fail. + Raises GraphExtractionError if LLM fix is unavailable or both attempts fail. """ if self.llm_fixer is None or not self.llm_fixer.available: self.logger.warning( @@ -187,14 +235,33 @@ def _llm_retry( try: sample_dir = self._extract_graph(fixed_path, model_id) return sample_dir - except ExtractionError as retry_err: + except GraphExtractionError as retry_err: + if not self._is_llm_fixable_error(retry_err): + self.logger.warning( + "LLM-fixed script failed with non-fixable error, " + f"skipping remaining retries: {retry_err}" + ) + raise err = retry_err - current_script = fixed_path # 第二次把上一次修复的脚本+新报错再喂给 LLM + current_script = fixed_path # On the second attempt, feed the previous fixed script and new error back to the LLM raise err def _fetch_model(self, model_id: str) -> Path: """Download model from HuggingFace Hub""" + self.logger.info(f"Checking model repo accessibility: {model_id}") + try: + self.model_fetcher.check_accessible(model_id) + except ModelFetchError as e: + if e.error_category in ( + GraphExtractionErrorCategory.MODEL_NOT_FOUND, + GraphExtractionErrorCategory.MODEL_FORBIDDEN, + ): + raise + self.logger.warning( + f"Model repo precheck failed for {model_id}, continuing to download: {e}" + ) + self.logger.info(f"Fetching model: {model_id}") model_dir = self.model_fetcher.download(model_id) self.logger.info(f"Model downloaded to: {model_dir}") @@ -255,7 +322,7 @@ def _extract_graph(self, script_path: Path, model_id: str) -> Path: return sample_dir def _fix_model_name(self, sample_dir: Path, model_id: str) -> None: - """将 graph_net.json 中的 model_name 修正为原始 HuggingFace model_id(org/model)""" + """Update model_name in graph_net.json to the original HuggingFace model_id (org/model).""" for json_path in [ sample_dir / "graph_net.json", *sample_dir.glob("subgraph_*/graph_net.json"), diff --git a/graph_net/agent/metadata_analyzer/base.py b/graph_net/agent/metadata_analyzer/base.py index 8e0a9955dd..dd39ec3042 100644 --- a/graph_net/agent/metadata_analyzer/base.py +++ b/graph_net/agent/metadata_analyzer/base.py @@ -21,6 +21,6 @@ def analyze(self, model_dir: Path) -> ModelMetadata: ModelMetadata object containing model information Raises: - AnalysisError: If analysis fails + MetadataAnalysisError: If analysis fails """ pass diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index b15df28b0a..3e62133063 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -6,7 +6,10 @@ from graph_net.agent.metadata_analyzer.base import BaseMetadataAnalyzer from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata -from graph_net.agent.utils.exceptions import AnalysisError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + MetadataAnalysisError, +) # Cap sequence length to avoid OOM: attention is O(n²), graph extraction @@ -47,11 +50,14 @@ def analyze(self, model_dir: Path) -> ModelMetadata: ModelMetadata object Raises: - AnalysisError: If analysis fails + MetadataAnalysisError: If analysis fails """ config_path = model_dir / "config.json" if not config_path.exists(): - raise AnalysisError(f"config.json not found in {model_dir}") + raise MetadataAnalysisError( + f"config.json not found in {model_dir}", + error_category=GraphExtractionErrorCategory.CONFIG_NOT_FOUND, + ) try: # Primary path: load via AutoConfig to get a rich PretrainedConfig object @@ -101,11 +107,17 @@ def analyze(self, model_dir: Path) -> ModelMetadata: architecture_type=arch_type, ) except json.JSONDecodeError as e: - raise AnalysisError(f"Failed to parse config.json: {e}") from e - except AnalysisError: + raise MetadataAnalysisError( + f"Failed to parse config.json: {e}", + error_category=GraphExtractionErrorCategory.CONFIG_PARSE_ERROR, + ) from e + except MetadataAnalysisError: raise except Exception as e: - raise AnalysisError(f"Failed to analyze model: {e}") from e + raise MetadataAnalysisError( + f"Failed to analyze model: {e}", + error_category=GraphExtractionErrorCategory.METADATA_ANALYSIS_FAILED, + ) from e # ------------------------------------------------------------------ # Architecture classification diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index 903e4984b6..b53b484525 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -6,12 +6,16 @@ from typing import Optional try: - from huggingface_hub import snapshot_download + from huggingface_hub import HfApi, snapshot_download except ImportError: + HfApi = None snapshot_download = None from graph_net.agent.model_fetcher.base import BaseModelFetcher -from graph_net.agent.utils.exceptions import ModelFetchError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + ModelFetchError, +) # Network-related exceptions that are worth retrying _RETRYABLE_ERRORS = ( @@ -33,7 +37,18 @@ _RETRYABLE_ERRORS = _RETRYABLE_ERRORS + (LocalEntryNotFoundError,) except ImportError: - pass + LocalEntryNotFoundError = None + +try: + from huggingface_hub.errors import ( + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + ) +except ImportError: + GatedRepoError = None + HfHubHTTPError = None + RepositoryNotFoundError = None class HFFetcher(BaseModelFetcher): @@ -67,6 +82,53 @@ def __init__( # Resolve endpoint: explicit param > env var self.endpoint = endpoint or os.environ.get("HF_ENDPOINT") + def check_accessible(self, model_id: str) -> None: + """Check whether a HuggingFace model repo is reachable without downloading files.""" + if HfApi is None: + raise ModelFetchError( + "huggingface_hub is not installed. " + "Please install it with: pip install huggingface_hub" + ) + + try: + if self.endpoint: + os.environ["HF_ENDPOINT"] = self.endpoint + api = HfApi(endpoint=self.endpoint) + api.model_info( + repo_id=model_id, + token=self.token, + files_metadata=False, + ) + except Exception as e: + error_category = self._classify_hf_error(e) + raise ModelFetchError( + f"Model repo is not accessible for {model_id}: {e}", + error_category=error_category, + ) from e + + @staticmethod + def _classify_hf_error(error: Exception) -> GraphExtractionErrorCategory: + """Classify HuggingFace API/download errors into extraction categories.""" + if RepositoryNotFoundError is not None and isinstance( + error, RepositoryNotFoundError + ): + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if GatedRepoError is not None and isinstance(error, GatedRepoError): + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + if HfHubHTTPError is not None and isinstance(error, HfHubHTTPError): + status_code = getattr(getattr(error, "response", None), "status_code", None) + if status_code == 404: + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if status_code in (401, 403): + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + + err_text = str(error) + if "404 Client Error" in err_text: + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if "401 Client Error" in err_text or "403 Client Error" in err_text: + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + return GraphExtractionErrorCategory.MODEL_DOWNLOAD_ERROR + def download(self, model_id: str) -> Path: """ Download model from HuggingFace Hub with retry on network errors. @@ -143,7 +205,8 @@ def download(self, model_id: str) -> Path: ) from e except Exception as e: raise ModelFetchError( - f"Failed to download model {model_id}: {e}" + f"Failed to download model {model_id}: {e}", + error_category=self._classify_hf_error(e), ) from e # Should not reach here, but just in case diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index 834e68cde1..0f93162386 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -171,6 +171,31 @@ def worker_fn( flush=True, ) + # Orphan watcher: if main process is killed with SIGKILL, worker becomes + # orphaned (ppid == 1). Detect this and kill all child process groups to + # prevent GPU memory leaks from run_model.py subprocesses. + import threading + + def _orphan_watcher(): + while True: + time.sleep(5) + if os.getppid() == 1: + print( + f"{prefix} Parent died (orphaned), cleaning up child processes...", + flush=True, + ) + # Multiple rounds to catch any late-starting children + for _ in range(5): + from graph_net.agent.graph_extractor.subprocess_graph_extractor import ( + kill_all_active_children, + ) + + kill_all_active_children() + time.sleep(1) + os._exit(1) + + threading.Thread(target=_orphan_watcher, daemon=True).start() + try: agent = GraphNetAgent( workspace=workspace, @@ -215,16 +240,29 @@ def worker_fn( status = agent.extract_sample(model_id) elapsed = time.time() - t0 ok = status == ExtractionStatus.OK + timeout_success = getattr(agent, "last_timeout_success", False) label = "OK" if ok else status.name.replace("_", " ") + if ok and timeout_success: + label = "OK(timeout)" print(f"{prefix} {label} {model_id} ({elapsed:.1f}s)", flush=True) result_dict["success"] = ok result_dict["status"] = status.value + result_dict["timeout_success"] = timeout_success + # Expose error category so the main process can decide policy + rec = agent.error_classifier.get_record(model_id) + if rec is not None: + result_dict["error_category"] = rec.category.value + result_dict["error_message"] = rec.message except Exception as e: elapsed = time.time() - t0 print(f"{prefix} ERROR {model_id}: {e} ({elapsed:.1f}s)", flush=True) result_dict["success"] = False result_dict["status"] = ExtractionStatus.ERROR.value result_dict["error"] = str(e) + result_dict["timeout_success"] = False + raw_cat = getattr(e, "error_category", None) + if raw_cat is not None: + result_dict["error_category"] = str(raw_cat) result_dict["elapsed"] = round(elapsed, 2) result_dict["timestamp"] = datetime.now().isoformat() @@ -249,6 +287,7 @@ def _print_summary(results: Dict) -> None: details = results.get("details", []) total = len(details) success = sum(1 for d in details if d.get("success")) + timeout_success = sum(1 for d in details if d.get("timeout_success")) extract_success = sum( 1 for d in details @@ -257,25 +296,29 @@ def _print_summary(results: Dict) -> None: ) failed = total - success rate = (success / total * 100) if total else 0.0 + timeout_rate = (timeout_success / total * 100) if total else 0.0 extract_rate = (extract_success / total * 100) if total else 0.0 print("\n" + "=" * 60) print("[SUMMARY] Parallel Extraction Summary") print("=" * 60) print(f" Total : {total}") print(f" Success : {success} (verify ok)") + print(f" Timeout : {timeout_success} (verify skipped by timeout)") print(f" Extract : {extract_success} (graph extracted)") print(f" Failed : {failed}") - print(f" Rate : {rate:.2f}% (overall)") + print(f" Rate : {rate:.2f}% (overall, timeout_success={timeout_rate:.2f}%)") print(f" Extract : {extract_rate:.2f}% (extraction only)") # Per-GPU breakdown gpu_stats: Dict[int, Dict] = {} for d in details: g = d.get("gpu", -1) if g not in gpu_stats: - gpu_stats[g] = {"total": 0, "success": 0, "extract": 0} + gpu_stats[g] = {"total": 0, "success": 0, "extract": 0, "timeout": 0} gpu_stats[g]["total"] += 1 if d.get("success"): gpu_stats[g]["success"] += 1 + if d.get("timeout_success"): + gpu_stats[g]["timeout"] += 1 if d.get("status") in ( ExtractionStatus.OK.value, ExtractionStatus.VERIFY_FAILED.value, @@ -288,9 +331,11 @@ def _print_summary(results: Dict) -> None: gs = gpu_stats[g] gr = (gs["success"] / gs["total"] * 100) if gs["total"] else 0.0 er = (gs["extract"] / gs["total"] * 100) if gs["total"] else 0.0 + tr = (gs["timeout"] / gs["total"] * 100) if gs["total"] else 0.0 print( f" {label} {g}: success={gs['success']}/{gs['total']} ({gr:.1f}%), " - f"extract={gs['extract']}/{gs['total']} ({er:.1f}%)" + f"extract={gs['extract']}/{gs['total']} ({er:.1f}%), " + f"timeout={gs['timeout']}/{gs['total']} ({tr:.1f}%)" ) print("=" * 60) @@ -470,6 +515,7 @@ def main() -> int: details.append(entry) done = len(details) ok_so_far = sum(1 for d in details if d.get("success")) + timeout_so_far = sum(1 for d in details if d.get("timeout_success")) extract_ok_so_far = sum( 1 for d in details @@ -478,7 +524,7 @@ def main() -> int: ) print( f"[PROGRESS] {done}/{len(model_ids)} done, " - f"success={ok_so_far/done*100:.1f}%, " + f"success={ok_so_far/done*100:.1f}%(timeout_success={timeout_so_far/done*100:.1f}%), " f"extract={extract_ok_so_far/done*100:.1f}%", flush=True, ) @@ -494,6 +540,7 @@ def main() -> int: end_time = datetime.now() success_count = sum(1 for d in details if d.get("success")) + timeout_success_count = sum(1 for d in details if d.get("timeout_success")) extract_success_count = sum( 1 for d in details @@ -508,14 +555,19 @@ def main() -> int: "workspace": workspace, "total": len(details), "success": success_count, + "timeout_success": timeout_success_count, "extract_success": extract_success_count, "failed": len(details) - success_count, "success_rate": 0.0, + "timeout_success_rate": 0.0, "extract_success_rate": 0.0, "details": details, } if results["total"] > 0: results["success_rate"] = round(results["success"] / results["total"] * 100, 2) + results["timeout_success_rate"] = round( + results["timeout_success"] / results["total"] * 100, 2 + ) results["extract_success_rate"] = round( results["extract_success"] / results["total"] * 100, 2 ) diff --git a/graph_net/agent/sample_verifier/base.py b/graph_net/agent/sample_verifier/base.py index 8e3ff87a04..d1cf99f03c 100644 --- a/graph_net/agent/sample_verifier/base.py +++ b/graph_net/agent/sample_verifier/base.py @@ -19,6 +19,6 @@ def verify(self, sample_dir: Path) -> bool: True if sample is valid, False otherwise Raises: - VerificationError: If verification process fails + SampleVerificationError: If verification process fails """ pass diff --git a/graph_net/agent/sample_verifier/basic_sample_verifier.py b/graph_net/agent/sample_verifier/basic_sample_verifier.py index 70e50e20d0..fa00800b54 100644 --- a/graph_net/agent/sample_verifier/basic_sample_verifier.py +++ b/graph_net/agent/sample_verifier/basic_sample_verifier.py @@ -4,7 +4,10 @@ from pathlib import Path from graph_net.agent.sample_verifier.base import BaseSampleVerifier -from graph_net.agent.utils.exceptions import VerificationError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + SampleVerificationError, +) class BasicSampleVerifier(BaseSampleVerifier): @@ -38,4 +41,7 @@ def verify(self, sample_dir: Path) -> bool: return True except Exception as e: - raise VerificationError(f"Verification failed: {e}") from e + raise SampleVerificationError( + f"Verification failed: {e}", + error_category=GraphExtractionErrorCategory.SAMPLE_INCOMPLETE, + ) from e diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py index c7849eac76..b1b9ee3fc1 100644 --- a/graph_net/agent/sample_verifier/forward_verifier.py +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -7,7 +7,10 @@ from graph_net.agent.sample_verifier.base import BaseSampleVerifier from graph_net.agent.sample_verifier.basic_sample_verifier import BasicSampleVerifier -from graph_net.agent.utils.exceptions import VerificationError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + SampleVerificationError, +) # Inline eager runner — executed in a subprocess to isolate CUDA state. # Loads GraphModule from model.py, reconstructs tensors from weight_meta.py, @@ -50,6 +53,7 @@ def __init__(self, timeout: int = 300): self._basic = BasicSampleVerifier() self.timeout = timeout if timeout is not None else 300 self.logger = logging.getLogger(self.__class__.__name__) + self.last_timeout_success = False def verify(self, sample_dir: Path) -> bool: """ @@ -61,6 +65,7 @@ def verify(self, sample_dir: Path) -> bool: Returns: True if all checks pass, False otherwise """ + self.last_timeout_success = False try: # Stage 1: file structure check if not self._basic.verify(sample_dir): @@ -72,16 +77,28 @@ def verify(self, sample_dir: Path) -> bool: targets = subgraph_dirs if subgraph_dirs else [sample_dir] for target in targets: - if not self._run_forward(target): + ok, is_timeout = self._run_forward(target) + if not ok: return False + if is_timeout: + self.last_timeout_success = True return True except Exception as e: - raise VerificationError(f"Forward verification failed: {e}") from e + raise SampleVerificationError( + f"Forward verification failed: {e}", + error_category=GraphExtractionErrorCategory.FORWARD_VERIFY_FAILED, + ) from e - def _run_forward(self, model_path: Path) -> bool: - """Run an eager forward pass on one model directory in a subprocess.""" + def _run_forward(self, model_path: Path) -> tuple[bool, bool]: + """Run an eager forward pass on one model directory in a subprocess. + + Returns: + (success, is_timeout): success=True means the check passed; + is_timeout=True means it passed only because + the subprocess timed out (treated as skip). + """ self.logger.info(f"Forward verify (eager): {model_path.name}") try: result = subprocess.run( @@ -92,14 +109,15 @@ def _run_forward(self, model_path: Path) -> bool: ) if result.returncode == 0: self.logger.info(f"Forward verify OK: {model_path.name}") - return True + return True, False else: self.logger.warning( f"Forward verify FAIL: {model_path.name}\n{result.stderr[-2000:]}" ) - return False + return False, False except subprocess.TimeoutExpired: self.logger.warning( - f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}" + f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}, " + "treating as pass (skip verification for large models on CPU)" ) - return False + return True, True diff --git a/graph_net/agent/utils/__init__.py b/graph_net/agent/utils/__init__.py index e8ebbd410c..fb94a5f6c2 100644 --- a/graph_net/agent/utils/__init__.py +++ b/graph_net/agent/utils/__init__.py @@ -3,17 +3,17 @@ from graph_net.agent.utils.exceptions import ( AgentError, ModelFetchError, - AnalysisError, - CodeGenError, - ExtractionError, - VerificationError, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + SampleVerificationError, ) __all__ = [ "AgentError", "ModelFetchError", - "AnalysisError", - "CodeGenError", - "ExtractionError", - "VerificationError", + "MetadataAnalysisError", + "CodeGenerationError", + "GraphExtractionError", + "SampleVerificationError", ] diff --git a/graph_net/agent/utils/error_classifier.py b/graph_net/agent/utils/error_classifier.py new file mode 100644 index 0000000000..9b5b5b62d3 --- /dev/null +++ b/graph_net/agent/utils/error_classifier.py @@ -0,0 +1,129 @@ +"""Error classification for extraction failures. + +Classification is driven entirely by the exception's `error_category` +attribute (set at the raise-site). No string keyword matching is +performed here — keywords belong in the code that raises the exception. +""" + +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional + +from graph_net.agent.utils.exceptions import GraphExtractionErrorCategory + + +@dataclass +class ErrorRecord: + """Single error occurrence.""" + + model_id: str + category: GraphExtractionErrorCategory + message: str + + +class GraphExtractionErrorClassifier: + """Classify extraction errors and keep per-model records. + + Usage: + classifier = GraphExtractionErrorClassifier() + category = classifier.classify_from_exception(exc) + classifier.record(model_id, category, str(exc)) + + # After run + report = classifier.summary() + """ + + def __init__(self): + self.records: List[ErrorRecord] = [] + self._by_model: Dict[str, ErrorRecord] = {} + + @staticmethod + def classify_from_exception(exc: Exception) -> GraphExtractionErrorCategory: + """Classify from the exception's ``error_category`` attribute. + + Falls back to UNKNOWN when the attribute is missing or invalid. + """ + raw = getattr(exc, "error_category", None) + if raw is not None: + try: + return GraphExtractionErrorCategory(raw) + except ValueError: + pass + return GraphExtractionErrorCategory.UNKNOWN + + def record( + self, + model_id: str, + category: GraphExtractionErrorCategory, + message: str, + ) -> None: + """Store one error record.""" + rec = ErrorRecord(model_id=model_id, category=category, message=message) + self.records.append(rec) + self._by_model[model_id] = rec + + def classify_and_record( + self, + model_id: str, + exception: Exception, + ) -> GraphExtractionErrorCategory: + """Convenience: classify from exception and store.""" + category = self.classify_from_exception(exception) + self.record(model_id, category, str(exception)) + return category + + def get_record(self, model_id: str) -> Optional[ErrorRecord]: + return self._by_model.get(model_id) + + def get_models_by_category( + self, category: GraphExtractionErrorCategory + ) -> List[str]: + return [rec.model_id for rec in self.records if rec.category == category] + + def summary(self) -> Dict[str, object]: + counts: Dict[str, int] = defaultdict(int) + per_category: Dict[str, List[str]] = defaultdict(list) + for rec in self.records: + cat_name = rec.category.value + counts[cat_name] += 1 + per_category[cat_name].append(rec.model_id) + return { + "total_errors": len(self.records), + "category_counts": dict(counts), + "models_per_category": dict(per_category), + } + + def report_lines(self) -> List[str]: + """Plain-text report as a list of lines (no markdown).""" + lines: List[str] = [] + lines.append("Extraction Error Report") + lines.append("") + lines.append(f"Total errors: {len(self.records)}") + lines.append("") + + counts: Dict[GraphExtractionErrorCategory, int] = defaultdict(int) + per_cat: Dict[GraphExtractionErrorCategory, List[ErrorRecord]] = defaultdict( + list + ) + for rec in self.records: + counts[rec.category] += 1 + per_cat[rec.category].append(rec) + + lines.append("Summary by Category:") + for cat, cnt in sorted(counts.items(), key=lambda x: -x[1]): + lines.append(f" {cat.value}: {cnt}") + lines.append("") + + lines.append("Details:") + for cat, recs in sorted(per_cat.items(), key=lambda x: -len(x[1])): + lines.append(f" {cat.value} ({len(recs)}):") + for rec in recs[:10]: + msg = ( + rec.message[:120] + "..." if len(rec.message) > 120 else rec.message + ) + lines.append(f" - {rec.model_id}: {msg}") + if len(recs) > 10: + lines.append(f" - ... and {len(recs) - 10} more") + lines.append("") + + return lines diff --git a/graph_net/agent/utils/exceptions.py b/graph_net/agent/utils/exceptions.py index 95f4f88f31..d274338f7b 100644 --- a/graph_net/agent/utils/exceptions.py +++ b/graph_net/agent/utils/exceptions.py @@ -1,37 +1,114 @@ -"""Custom exception classes for Agent""" +"""Custom exception classes for Agent. + +Each exception may carry an `error_category` so that +error_classifier.py can route without string matching. +""" + +from enum import Enum +from typing import Optional + + +class GraphExtractionErrorCategory(str, Enum): + """Known categories of extraction failure.""" + + # Pre-extraction failures + MODEL_NOT_FOUND = "model_not_found" + MODEL_FORBIDDEN = "model_forbidden" + MODEL_DOWNLOAD_ERROR = "model_download_error" + + # Config / metadata analysis failures + CONFIG_NOT_FOUND = "config_not_found" + CONFIG_PARSE_ERROR = "config_parse_error" + METADATA_ANALYSIS_FAILED = "metadata_analysis_failed" + + # Script generation failures + CODE_GEN_ERROR = "code_gen_error" + + # Script execution failures + SCRIPT_EXECUTION_FAILED = "script_execution_failed" + SCRIPT_TIMEOUT = "script_timeout" + OUTPUT_DIR_NOT_FOUND = "output_dir_not_found" + + # LLM retry failures + LLM_TIMEOUT = "llm_timeout" + LLM_EXIT_ERROR = "llm_exit_error" + + # Post-extraction failures + SAMPLE_INCOMPLETE = "sample_incomplete" + FORWARD_VERIFY_FAILED = "forward_verify_failed" + VERIFICATION_TIMEOUT = "verification_timeout" + VERIFICATION_FAILED = "verification_failed" + + # Catch-all + UNKNOWN = "unknown" class AgentError(Exception): - """Base exception for Agent errors""" + """Base exception for Agent errors. + + Subclasses can set `default_category` so that raise-sites do not + need to repeat the category when the default is sufficient. + """ - pass + default_category: Optional[GraphExtractionErrorCategory] = None + + def __init__( + self, + message: str, + error_category: Optional[GraphExtractionErrorCategory] = None, + ): + super().__init__(message) + self.error_category = error_category or self.default_category class ModelFetchError(AgentError): - """Raised when model fetching fails""" + """Raised when model fetching fails. + + Default: MODEL_DOWNLOAD_ERROR. + Raise-sites should override for 404 (MODEL_NOT_FOUND) + or 403 (MODEL_FORBIDDEN). + """ + + default_category = GraphExtractionErrorCategory.MODEL_DOWNLOAD_ERROR + + +class MetadataAnalysisError(AgentError): + """Raised when model metadata/config analysis fails. + + Covers config missing, JSON parse errors, and unsupported architectures. + """ - pass + default_category = GraphExtractionErrorCategory.METADATA_ANALYSIS_FAILED -class AnalysisError(AgentError): - """Raised when model analysis fails""" +class CodeGenerationError(AgentError): + """Raised when code generation fails. - pass + Default: CODE_GEN_ERROR. + Raise-sites should override for LLM-specific failures + (LLM_TIMEOUT / LLM_EXIT_ERROR). + """ + default_category = GraphExtractionErrorCategory.CODE_GEN_ERROR -class CodeGenError(AgentError): - """Raised when code generation fails""" - pass +class GraphExtractionError(AgentError): + """Raised when graph extraction fails. + Default: UNKNOWN — raise-sites MUST override with one of: + - SCRIPT_EXECUTION_FAILED + - SCRIPT_TIMEOUT + - OUTPUT_DIR_NOT_FOUND + """ -class ExtractionError(AgentError): - """Raised when graph extraction fails""" + default_category = GraphExtractionErrorCategory.UNKNOWN - pass +class SampleVerificationError(AgentError): + """Raised when sample verification fails. -class VerificationError(AgentError): - """Raised when sample verification fails""" + Default: VERIFICATION_FAILED. + Raise-sites may override with VERIFICATION_TIMEOUT. + """ - pass + default_category = GraphExtractionErrorCategory.VERIFICATION_FAILED