diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f034853..d2a2f19 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: run: uv sync --extra dev - name: Run tests - run: uv run pytest --cov --cov-report=xml --cov-fail-under=79 -v -m "not integration" + run: uv run pytest --cov --cov-report=xml --cov-fail-under=80 -v -m "not integration" - name: Upload coverage uses: codecov/codecov-action@v4 diff --git a/.gitignore b/.gitignore index aa63c69..91b0661 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,4 @@ deriva/adapters/neo4j/logs/* deriva/adapters/database/sql.db deriva/adapters/database/sql.db.wal .coverage -coverage.xml -.export/* -todo/* +coverage.xml \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index afb32b5..0c26a81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,17 +6,47 @@ Deriving ArchiMate models from code using knowledge graphs, heuristics and LLM's # v0.6.x - Deriva (December 2025 - January 2026) -## v0.6.8 - PydanticAI Migration (Unreleased) +## v0.6.8 - Library Migration & Overall Cleanup (Unreleased) + +Big migration replacing 6 custom implementations with off-the-shelf libraries, reducing the amount of code and improving maintainability. ### LLM Adapter Rewrite -- **PydanticAI Integration**: Replaced custom REST provider implementations with PydanticAI library -- **Code Reduction**: Same (or better) llm adapter with way less code, deleted entire `providers.py` +- **PydanticAI Integration**: Replaced custom REST provider implementations with `pydantic-ai` library +- **Model Registry**: New `model_registry.py` maps Deriva config to PydanticAI model identifiers with URL normalization for Azure/LM Studio +- **Code Reduction**: Same (or better) LLM adapter with way less code, deleted entire `providers.py` - **Native Structured Output**: PydanticAI handles validation and retry automatically -- **Removed ClaudeCode Provider**: Use `anthropic` provider directly instead (CLI subprocess no longer supported) ### Configuration +- **Pydantic Settings**: New `config_models.py` with type-safe environment validation using `pydantic-settings` - **Standard API Keys**: Added PydanticAI standard env vars (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `MISTRAL_API_KEY`, `AZURE_OPENAI_*`) -- **Updated .env.example**: Removed `claudecode` provider, added Anthropic direct API configs + +### Caching +- **diskcache Integration**: Replaced custom SQLite-based caching with `diskcache` library +- **Simplified Cache Utils**: Rewrote `cache_utils.py` to wrap diskcache with `BaseDiskCache` class +- **Preserved Features**: Kept `hash_inputs()`, `bench_hash` isolation, and `export_to_json()` functionality +- **LLM & Graph Caches**: Updated both adapters to use new base cache class + +### Retry Logic +- **backoff Library**: Replaced custom retry implementation with `backoff` library +- **New retry.py**: Centralized retry decorator with exponential backoff and jitter +- **Simplified Rate Limiter**: Token bucket rate limiting now separate from retry logic + +### Small CLI refactor +- **Typer Framework**: Replaced argparse-based CLI with `typer` +- **Command Modules**: Split CLI into `deriva/cli/commands/` with separate files for `benchmark.py`, `config.py`, `repo.py`, `run.py` +- **Modern CLI Features**: Auto-completion, better help generation, type hints via `Annotated` +- **Subcommand Groups**: `config`, `repo`, `benchmark` as typer subapps + +### Logging + +- **structlog Integration**: Rewrote `logging.py` with `structlog` 
for structured logging +- **Preserved API**: Same RunLogger, StepContext, and JSONL output format +- **OCEL Unchanged**: OCEL module kept intact for benchmark process mining + +### Tests & Quality +- **CLI Tests Rewritten**: Updated all 51 CLI tests to use typer's `CliRunner` +- **Tree-sitter Test Consolidation**: Merged per-language test files into single `test_languages.py` +- **Coverage Threshold**: Updated CI coverage threshold to 80% --- diff --git a/deriva/adapters/llm/__init__.py b/deriva/adapters/llm/__init__.py index e56cfc9..a541b45 100644 --- a/deriva/adapters/llm/__init__.py +++ b/deriva/adapters/llm/__init__.py @@ -48,6 +48,7 @@ class Concept(BaseModel): ValidationError, ) from .rate_limiter import RateLimitConfig, RateLimiter, get_default_rate_limit +from .retry import create_retry_decorator, retry_on_rate_limit __all__ = [ # Main service @@ -70,6 +71,9 @@ class Concept(BaseModel): "RateLimitConfig", "RateLimiter", "get_default_rate_limit", + # Retry + "create_retry_decorator", + "retry_on_rate_limit", # Exceptions "LLMError", "ConfigurationError", diff --git a/deriva/adapters/llm/manager.py b/deriva/adapters/llm/manager.py index fa0e801..1432007 100644 --- a/deriva/adapters/llm/manager.py +++ b/deriva/adapters/llm/manager.py @@ -40,11 +40,11 @@ class Concept(BaseModel): from dotenv import load_dotenv from pydantic import BaseModel from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings from .cache import CacheManager from .model_registry import VALID_PROVIDERS, get_pydantic_ai_model from .models import ( - APIError, BenchmarkModelConfig, CachedResponse, ConfigurationError, @@ -208,7 +208,9 @@ def from_config( load_dotenv(override=True) effective_temperature = ( - temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.7")) + temperature + if temperature is not None + else float(os.getenv("LLM_TEMPERATURE", "0.7")) ) instance = object.__new__(cls) @@ -265,8 +267,12 @@ def _load_config_from_env(self) -> dict[str, Any]: if default_model: benchmark_models = load_benchmark_models() if default_model not in benchmark_models: - available = ", ".join(benchmark_models.keys()) if benchmark_models else "none" - raise ConfigurationError(f"LLM_DEFAULT_MODEL '{default_model}' not found. Available: {available}") + available = ( + ", ".join(benchmark_models.keys()) if benchmark_models else "none" + ) + raise ConfigurationError( + f"LLM_DEFAULT_MODEL '{default_model}' not found. 
Available: {available}" + ) config = benchmark_models[default_model] provider = config.provider api_url = config.get_api_url() @@ -288,7 +294,9 @@ def _load_config_from_env(self) -> dict[str, Any]: api_key = os.getenv("LLM_ANTHROPIC_API_KEY") model = os.getenv("LLM_ANTHROPIC_MODEL", "claude-sonnet-4-20250514") elif provider == "ollama": - api_url = os.getenv("LLM_OLLAMA_API_URL", "http://localhost:11434/api/chat") + api_url = os.getenv( + "LLM_OLLAMA_API_URL", "http://localhost:11434/api/chat" + ) api_key = None model = os.getenv("LLM_OLLAMA_MODEL", "llama3.2") elif provider == "mistral": @@ -296,7 +304,9 @@ def _load_config_from_env(self) -> dict[str, Any]: api_key = os.getenv("LLM_MISTRAL_API_KEY") model = os.getenv("LLM_MISTRAL_MODEL", "mistral-large-latest") elif provider == "lmstudio": - api_url = os.getenv("LLM_LMSTUDIO_API_URL", "http://localhost:1234/v1/chat/completions") + api_url = os.getenv( + "LLM_LMSTUDIO_API_URL", "http://localhost:1234/v1/chat/completions" + ) api_key = None model = os.getenv("LLM_LMSTUDIO_MODEL", "local-model") else: @@ -325,7 +335,9 @@ def _validate_config(self) -> None: """Validate configuration has required fields.""" provider = self.config.get("provider", "") if provider not in VALID_PROVIDERS: - raise ConfigurationError(f"Invalid provider: {provider}. Must be one of {VALID_PROVIDERS}") + raise ConfigurationError( + f"Invalid provider: {provider}. Must be one of {VALID_PROVIDERS}" + ) # Ollama and LM Studio don't require api_key if provider in ("ollama", "lmstudio"): @@ -335,7 +347,9 @@ def _validate_config(self) -> None: missing = [f for f in required_fields if not self.config.get(f)] if missing: - raise ConfigurationError(f"Missing required config fields: {', '.join(missing)}") + raise ConfigurationError( + f"Missing required config fields: {', '.join(missing)}" + ) @overload def query( @@ -394,7 +408,9 @@ def query( If response_model is provided: Validated Pydantic model instance or FailedResponse Otherwise: LiveResponse, CachedResponse, or FailedResponse """ - effective_temperature = temperature if temperature is not None else self.temperature + effective_temperature = ( + temperature if temperature is not None else self.temperature + ) effective_max_tokens = max_tokens if max_tokens is not None else self.max_tokens # Generate cache key @@ -446,12 +462,12 @@ def query( ) # Run query + settings: ModelSettings = {"temperature": effective_temperature} + if effective_max_tokens is not None: + settings["max_tokens"] = effective_max_tokens result = agent.run_sync( prompt, - model_settings={ - "temperature": effective_temperature, - "max_tokens": effective_max_tokens, - }, + model_settings=settings, ) self._rate_limiter.record_success() @@ -461,7 +477,8 @@ def query( if hasattr(result, "usage") and result.usage: usage = { "prompt_tokens": getattr(result.usage, "request_tokens", 0) or 0, - "completion_tokens": getattr(result.usage, "response_tokens", 0) or 0, + "completion_tokens": getattr(result.usage, "response_tokens", 0) + or 0, "total_tokens": getattr(result.usage, "total_tokens", 0) or 0, } @@ -469,13 +486,21 @@ def query( if response_model: # Cache the serialized model if write_cache: - content = result.output.model_dump_json() if hasattr(result.output, "model_dump_json") else str(result.output) - self.cache.set_response(cache_key, content, prompt, self.model, usage) + content = ( + result.output.model_dump_json() + if hasattr(result.output, "model_dump_json") + else str(result.output) + ) + self.cache.set_response( + cache_key, content, prompt, 
self.model, usage + ) return result.output else: content = str(result.output) if result.output else "" if write_cache: - self.cache.set_response(cache_key, content, prompt, self.model, usage) + self.cache.set_response( + cache_key, content, prompt, self.model, usage + ) return LiveResponse( prompt=prompt, model=self.model, @@ -541,7 +566,9 @@ def get_token_usage_stats(self) -> dict[str, Any]: "total_tokens": total_prompt + total_completion, "total_calls": total_calls, "avg_prompt_tokens": total_prompt / total_calls if total_calls else 0, - "avg_completion_tokens": total_completion / total_calls if total_calls else 0, + "avg_completion_tokens": total_completion / total_calls + if total_calls + else 0, } def __repr__(self) -> str: diff --git a/deriva/adapters/llm/model_registry.py b/deriva/adapters/llm/model_registry.py index 84c11b2..96b86e2 100644 --- a/deriva/adapters/llm/model_registry.py +++ b/deriva/adapters/llm/model_registry.py @@ -12,7 +12,9 @@ from pydantic_ai.models import Model # Valid provider names -VALID_PROVIDERS = frozenset({"azure", "openai", "anthropic", "ollama", "mistral", "lmstudio"}) +VALID_PROVIDERS = frozenset( + {"azure", "openai", "anthropic", "ollama", "mistral", "lmstudio"} +) def get_pydantic_ai_model(config: dict[str, Any]) -> "Model | str": @@ -63,12 +65,16 @@ def get_pydantic_ai_model(config: dict[str, Any]) -> "Model | str": from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.openai import OpenAIProvider - base_url = _normalize_openai_url(api_url) if api_url else "http://localhost:1234/v1" + base_url = ( + _normalize_openai_url(api_url) if api_url else "http://localhost:1234/v1" + ) openai_provider = OpenAIProvider(base_url=base_url) return OpenAIChatModel(model, provider=openai_provider) else: - raise ValueError(f"Unknown provider: {provider}. Valid providers: {VALID_PROVIDERS}") + raise ValueError( + f"Unknown provider: {provider}. Valid providers: {VALID_PROVIDERS}" + ) def _normalize_azure_url(url: str) -> str: diff --git a/deriva/adapters/llm/models.py b/deriva/adapters/llm/models.py index ad96d96..9f92907 100644 --- a/deriva/adapters/llm/models.py +++ b/deriva/adapters/llm/models.py @@ -67,7 +67,9 @@ class BenchmarkModelConfig: def __post_init__(self): """Validate provider.""" if self.provider not in VALID_PROVIDERS: - raise ValueError(f"Invalid provider: {self.provider}. Must be one of {VALID_PROVIDERS}") + raise ValueError( + f"Invalid provider: {self.provider}. Must be one of {VALID_PROVIDERS}" + ) def get_api_key(self) -> str | None: """Get API key from direct value or environment variable.""" diff --git a/deriva/adapters/llm/rate_limiter.py b/deriva/adapters/llm/rate_limiter.py index 3b655a5..c1dc3d5 100644 --- a/deriva/adapters/llm/rate_limiter.py +++ b/deriva/adapters/llm/rate_limiter.py @@ -1,10 +1,11 @@ """ Rate limiter for LLM API requests. -Implements token bucket algorithm with support for: +Implements token bucket algorithm for: - Requests per minute (RPM) limits - Minimum delay between requests -- Exponential backoff on rate limit errors + +For retry logic with exponential backoff, use retry.py instead. 
""" from __future__ import annotations @@ -27,7 +28,6 @@ "mistral": 24, # Mistral: varies by tier "ollama": 0, # Local - no limit "lmstudio": 0, # Local - no limit - "claudecode": 30, # Claude Code CLI: conservative default } @@ -37,15 +37,12 @@ class RateLimitConfig: requests_per_minute: int = 60 # 0 = no limit min_request_delay: float = 0.0 # Minimum seconds between requests - backoff_base: float = 2.0 # Base for exponential backoff - backoff_max: float = 60.0 # Maximum backoff delay in seconds - backoff_jitter: float = 0.1 # Jitter factor (0-1) for randomization @dataclass class RateLimiter: """ - Token bucket rate limiter with exponential backoff support. + Token bucket rate limiter for API requests. Thread-safe implementation that tracks request timestamps and enforces rate limits across concurrent calls. @@ -56,8 +53,8 @@ class RateLimiter: config: RateLimitConfig = field(default_factory=RateLimitConfig) _request_times: deque = field(default_factory=deque) # O(1) popleft _lock: threading.Lock = field(default_factory=threading.Lock) - _consecutive_rate_limits: int = field(default=0) _last_request_time: float = field(default=0.0) + _successful_requests: int = field(default=0) def wait_if_needed(self) -> float: """ @@ -110,54 +107,9 @@ def wait_if_needed(self) -> float: return wait_time def record_success(self) -> None: - """Record a successful request, resetting backoff counter.""" + """Record a successful request.""" with self._lock: - self._consecutive_rate_limits = 0 - - def record_rate_limit(self, retry_after: float | None = None) -> float: - """ - Record a rate limit error and calculate backoff delay. - - Args: - retry_after: Optional retry-after value from API response header (seconds). - If provided, uses this value instead of exponential backoff. 
- - Returns: - float: Recommended wait time before retry - """ - import random - - with self._lock: - self._consecutive_rate_limits += 1 - - # Use retry-after header if provided - if retry_after is not None and retry_after > 0: - delay = min(retry_after, self.config.backoff_max) - logger.warning( - "Rate limit hit (attempt %d), using retry-after header: %.2fs", - self._consecutive_rate_limits, - delay, - ) - return delay - - # Calculate exponential backoff with jitter - delay = min( - self.config.backoff_base**self._consecutive_rate_limits, - self.config.backoff_max, - ) - - # Add jitter to prevent thundering herd - if self.config.backoff_jitter > 0: - jitter = delay * self.config.backoff_jitter * random.random() - delay += jitter - - logger.warning( - "Rate limit hit (attempt %d), backing off %.2fs", - self._consecutive_rate_limits, - delay, - ) - - return delay + self._successful_requests += 1 def get_stats(self) -> dict[str, float | int]: """Get current rate limiter statistics.""" @@ -172,7 +124,7 @@ def get_stats(self) -> dict[str, float | int]: return { "requests_last_minute": recent_requests, "rpm_limit": self.config.requests_per_minute, - "consecutive_rate_limits": self._consecutive_rate_limits, + "successful_requests": self._successful_requests, "min_request_delay": self.config.min_request_delay, } @@ -216,3 +168,12 @@ def parse_retry_after(headers: dict[str, str] | None) -> float | None: continue return None + + +__all__ = [ + "RateLimitConfig", + "RateLimiter", + "get_default_rate_limit", + "parse_retry_after", + "DEFAULT_RATE_LIMITS", +] diff --git a/deriva/adapters/llm/retry.py b/deriva/adapters/llm/retry.py new file mode 100644 index 0000000..0ad4c29 --- /dev/null +++ b/deriva/adapters/llm/retry.py @@ -0,0 +1,158 @@ +""" +Retry utilities with exponential backoff. + +Uses the backoff library for robust retry handling with jitter. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TypeVar + +import backoff +from backoff._typing import Details + +logger = logging.getLogger(__name__) + +# Type variable for decorated function return type +T = TypeVar("T") + +# Exceptions that should trigger retry +RETRIABLE_EXCEPTIONS = ( + ConnectionError, + TimeoutError, + OSError, # Includes network errors +) + + +def on_backoff(details: Details) -> None: + """Log backoff events.""" + wait = details.get("wait", 0) + tries = details.get("tries", 0) + target = details.get("target") + target_name = getattr(target, "__name__", "unknown") if target else "unknown" + exception = details.get("exception") + + logger.warning( + "Retry %d for %s, backing off %.2fs. Error: %s", + tries, + target_name, + wait, + exception, + ) + + +def on_giveup(details: Details) -> None: + """Log when retries are exhausted.""" + tries = details.get("tries", 0) + target = details.get("target") + target_name = getattr(target, "__name__", "unknown") if target else "unknown" + exception = details.get("exception") + + logger.error( + "Giving up on %s after %d attempts. Final error: %s", + target_name, + tries, + exception, + ) + + +def create_retry_decorator( + max_retries: int = 3, + base_delay: float = 2.0, + max_delay: float = 60.0, + exceptions: tuple = RETRIABLE_EXCEPTIONS, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + Create a retry decorator with exponential backoff. 
+ + Args: + max_retries: Maximum number of retry attempts (default: 3) + base_delay: Base delay factor for exponential backoff (default: 2.0) + max_delay: Maximum delay between retries (default: 60.0) + exceptions: Tuple of exception types to retry on + + Returns: + Decorator function + + Example: + @create_retry_decorator(max_retries=5) + def flaky_api_call(): + ... + """ + return backoff.on_exception( + backoff.expo, + exception=exceptions, + max_tries=max_retries + 1, # backoff counts total tries, not retries + factor=base_delay, + max_value=max_delay, + jitter=backoff.full_jitter, + on_backoff=on_backoff, + on_giveup=on_giveup, + ) + + +def retry_on_rate_limit( + max_retries: int = 3, + base_delay: float = 2.0, + max_delay: float = 60.0, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + Decorator for retrying on rate limit errors (HTTP 429). + + Uses exponential backoff with jitter to handle rate limits gracefully. + + Args: + max_retries: Maximum retry attempts + base_delay: Base delay in seconds + max_delay: Maximum delay in seconds + + Returns: + Decorator function + + Example: + @retry_on_rate_limit(max_retries=5) + def api_call(): + response = requests.get(url) + if response.status_code == 429: + raise RateLimitError("Rate limited") + return response + """ + # Import here to avoid issues if these aren't installed + try: + from httpx import HTTPStatusError + from pydantic_ai import exceptions as pai_exceptions + + rate_limit_exceptions = ( + ConnectionError, + TimeoutError, + HTTPStatusError, + ) + + # Add PydanticAI rate limit exception if available + if hasattr(pai_exceptions, "RateLimitError"): + rate_limit_exceptions = ( + *rate_limit_exceptions, + pai_exceptions.RateLimitError, + ) + except ImportError: + rate_limit_exceptions = (ConnectionError, TimeoutError) + + return backoff.on_exception( + backoff.expo, + exception=rate_limit_exceptions, + max_tries=max_retries + 1, + factor=base_delay, + max_value=max_delay, + jitter=backoff.full_jitter, + on_backoff=on_backoff, + on_giveup=on_giveup, + ) + + +__all__ = [ + "create_retry_decorator", + "retry_on_rate_limit", + "RETRIABLE_EXCEPTIONS", +] diff --git a/deriva/cli/cli.py b/deriva/cli/cli.py index 42039a2..2262f80 100644 --- a/deriva/cli/cli.py +++ b/deriva/cli/cli.py @@ -2,371 +2,135 @@ CLI entry point for Deriva. Provides headless command-line interface for pipeline operations. -Uses PipelineSession from the services layer. +Uses Typer for modern CLI with auto-completion and help generation. 
Usage: - python -m deriva.cli.cli repo clone - python -m deriva.cli.cli repo list --detailed - python -m deriva.cli.cli repo delete --force - python -m deriva.cli.cli repo info - python -m deriva.cli.cli config list extraction - python -m deriva.cli.cli config list derivation --phase generate - python -m deriva.cli.cli config enable extraction BusinessConcept - python -m deriva.cli.cli run extraction - python -m deriva.cli.cli run derivation --phase generate - python -m deriva.cli.cli run all --repo flask_invoice_generator - python -m deriva.cli.cli clear graph - python -m deriva.cli.cli clear model - python -m deriva.cli.cli status + deriva repo clone + deriva repo list --detailed + deriva repo delete --force + deriva repo info + deriva config list extraction + deriva config list derivation --phase generate + deriva config enable extraction BusinessConcept + deriva run extraction + deriva run derivation --phase generate + deriva run all --repo flask_invoice_generator + deriva clear graph + deriva clear model + deriva status """ from __future__ import annotations -import argparse import sys -from typing import Any +from typing import Annotated -from deriva.cli.progress import ( - create_benchmark_progress_reporter, - create_progress_reporter, +import typer + +from deriva.cli.commands.benchmark import app as benchmark_app +from deriva.cli.commands.config import app as config_app +from deriva.cli.commands.repo import app as repo_app +from deriva.cli.commands.run import ( + _print_derivation_result, + _print_extraction_result, + _print_pipeline_result, ) -from deriva.services import config +from deriva.cli.progress import create_progress_reporter from deriva.services.session import PipelineSession +# Create main app +app = typer.Typer( + name="deriva", + help="Deriva CLI - Generate ArchiMate models from code repositories", + no_args_is_help=True, +) -# ============================================================================= -# Config Commands -# ============================================================================= - - -def cmd_config_list(args: argparse.Namespace) -> int: - """List configurations for a step type.""" - with PipelineSession() as session: - step_type = args.step_type - enabled_only = args.enabled - - steps = session.list_steps(step_type, enabled_only=enabled_only) - - if not steps: - print(f"No {step_type} configurations found.") - return 0 - - print(f"\n{step_type.upper()} CONFIGURATIONS") - print("-" * 60) - - for step in steps: - status = "enabled" if step["enabled"] else "disabled" - name = step["name"] - seq = step["sequence"] - print(f" [{seq}] {name:<30} ({status})") - - print() - return 0 - - -def cmd_config_show(args: argparse.Namespace) -> int: - """Show detailed configuration for a specific step.""" - with PipelineSession() as session: - step_type = args.step_type - name = args.name - - if step_type == "extraction": - cfg = config.get_extraction_config(session._engine, name) - if not cfg: - print(f"Extraction config '{name}' not found.") - return 1 - print(f"\nEXTRACTION CONFIG: {cfg.node_type}") - print("-" * 60) - print(f" Sequence: {cfg.sequence}") - print(f" Enabled: {cfg.enabled}") - print(f" Sources: {cfg.input_sources or 'None'}") - print(f" Instruction: {(cfg.instruction or '')[:100]}...") - print(f" Example: {(cfg.example or '')[:100]}...") - - elif step_type == "derivation": - cfg = config.get_derivation_config(session._engine, name) - if not cfg: - print(f"Derivation config '{name}' not found.") - return 1 - print(f"\nDERIVATION CONFIG: 
{cfg.element_type}") - print("-" * 60) - print(f" Sequence: {cfg.sequence}") - print(f" Enabled: {cfg.enabled}") - print(f" Query: {(cfg.input_graph_query or '')[:100]}...") - print(f" Instruction: {(cfg.instruction or '')[:100]}...") - - else: - print(f"Unknown step type: {step_type}") - return 1 - - print() - return 0 - - -def cmd_config_enable(args: argparse.Namespace) -> int: - """Enable a configuration step.""" - with PipelineSession() as session: - if session.enable_step(args.step_type, args.name): - print(f"Enabled {args.step_type} step: {args.name}") - return 0 - else: - print(f"Step not found: {args.step_type}/{args.name}") - return 1 - - -def cmd_config_disable(args: argparse.Namespace) -> int: - """Disable a configuration step.""" - with PipelineSession() as session: - if session.disable_step(args.step_type, args.name): - print(f"Disabled {args.step_type} step: {args.name}") - return 0 - else: - print(f"Step not found: {args.step_type}/{args.name}") - return 1 - - -def cmd_config_update(args: argparse.Namespace) -> int: - """Update a configuration with versioning.""" - import json - - with PipelineSession() as session: - step_type = args.step_type - name = args.name - instruction = args.instruction - example = args.example - params = getattr(args, "params", None) - - # Read instruction from file if provided - if args.instruction_file: - try: - with open(args.instruction_file, encoding="utf-8") as f: - instruction = f.read() - except Exception as e: - print(f"Error reading instruction file: {e}") - return 1 - - # Read example from file if provided - if args.example_file: - try: - with open(args.example_file, encoding="utf-8") as f: - example = f.read() - except Exception as e: - print(f"Error reading example file: {e}") - return 1 - - # Read params from file if provided - params_file = getattr(args, "params_file", None) - if params_file: - try: - with open(params_file, encoding="utf-8") as f: - params = f.read() - except Exception as e: - print(f"Error reading params file: {e}") - return 1 - - # Validate params is valid JSON if provided - if params: - try: - json.loads(params) - except json.JSONDecodeError as e: - print(f"Error: params must be valid JSON: {e}") - return 1 - - if step_type == "derivation": - result = config.create_derivation_config_version( - session._engine, - name, - instruction=instruction, - example=example, - input_graph_query=args.query, - params=params, - ) - elif step_type == "extraction": - result = config.create_extraction_config_version( - session._engine, - name, - instruction=instruction, - example=example, - input_sources=args.sources, - ) - else: - print(f"Versioned updates not yet supported for: {step_type}") - return 1 - - if result.get("success"): - print(f"Updated {step_type} config: {name}") - print(f" Version: {result['old_version']} -> {result['new_version']}") - if params: - print(" Params: updated") - return 0 - else: - print(f"Error: {result.get('error', 'Unknown error')}") - return 1 - - -def cmd_config_versions(args: argparse.Namespace) -> int: - """Show active config versions.""" - with PipelineSession() as session: - versions = config.get_active_config_versions(session._engine) - - print("\nACTIVE CONFIG VERSIONS") - print("=" * 60) - - for step_type in ["extraction", "derivation"]: - if versions.get(step_type): - print(f"\n{step_type.upper()}:") - for name, version in sorted(versions[step_type].items()): - print(f" {name:<30} v{version}") - - print() - return 0 - - -# 
============================================================================= -# File Type Commands -# ============================================================================= - - -def cmd_filetype_list(args: argparse.Namespace) -> int: - """List all registered file types.""" - with PipelineSession() as session: - file_types = session.get_file_types() - - if not file_types: - print("No file types registered.") - return 0 - - # Group by file_type - by_type: dict[str, list] = {} - for ft in file_types: - ft_type = ft.get("file_type", "unknown") - if ft_type not in by_type: - by_type[ft_type] = [] - by_type[ft_type].append(ft) - - print(f"\n{'=' * 60}") - print("FILE TYPE REGISTRY") - print(f"{'=' * 60}") - print(f"Total: {len(file_types)} registered\n") - - for ft_type in sorted(by_type.keys()): - entries = by_type[ft_type] - print(f"{ft_type.upper()} ({len(entries)}):") - for ft in sorted(entries, key=lambda x: x.get("extension", "")): - ext = ft.get("extension", "") - subtype = ft.get("subtype", "") - print(f" {ext:<25} -> {subtype}") - print() - - return 0 - - -def cmd_filetype_add(args: argparse.Namespace) -> int: - """Add a new file type.""" - extension = args.extension - file_type = args.file_type - subtype = args.subtype - - with PipelineSession() as session: - success = session.add_file_type(extension, file_type, subtype) - - if success: - print(f"Added file type: {extension} -> {file_type}/{subtype}") - return 0 - else: - print(f"Failed to add file type (may already exist): {extension}") - return 1 - - -def cmd_filetype_delete(args: argparse.Namespace) -> int: - """Delete a file type.""" - extension = args.extension - - with PipelineSession() as session: - success = session.delete_file_type(extension) - - if success: - print(f"Deleted file type: {extension}") - return 0 - else: - print(f"File type not found: {extension}") - return 1 - - -def cmd_filetype_stats(args: argparse.Namespace) -> int: - """Show file type statistics.""" - with PipelineSession() as session: - stats = session.get_file_type_stats() - - print(f"\n{'=' * 60}") - print("FILE TYPE STATISTICS") - print(f"{'=' * 60}\n") - - for ft_type, count in sorted(stats.items(), key=lambda x: -x[1]): - print(f" {ft_type:<20} {count}") - - print(f"\n {'Total':<20} {sum(stats.values())}") - - return 0 +# Add subcommand groups +app.add_typer(config_app, name="config") +app.add_typer(repo_app, name="repo") +app.add_typer(benchmark_app, name="benchmark") # ============================================================================= -# Run Commands +# Run Command (standalone, not a subgroup) # ============================================================================= -def cmd_run(args: argparse.Namespace) -> int: +@app.command("run") +def run_stage( + stage: Annotated[ + str, typer.Argument(help="Pipeline stage to run (extraction, derivation, all)") + ], + repo: Annotated[ + str | None, typer.Option("--repo", help="Specific repository to process") + ] = None, + phase: Annotated[ + str | None, typer.Option("--phase", help="Run specific phase") + ] = None, + verbose: Annotated[ + bool, typer.Option("-v", "--verbose", help="Print detailed progress") + ] = False, + quiet: Annotated[ + bool, typer.Option("-q", "--quiet", help="Disable progress bar") + ] = False, + no_llm: Annotated[ + bool, typer.Option("--no-llm", help="Skip LLM-based steps") + ] = False, +) -> None: """Run pipeline stages.""" - stage = args.stage - repo_name = getattr(args, "repo", None) - verbose = getattr(args, "verbose", False) - no_llm = getattr(args, 
"no_llm", False) - phase = getattr(args, "phase", None) - quiet = getattr(args, "quiet", False) + if stage not in ("extraction", "derivation", "all"): + typer.echo( + f"Error: stage must be 'extraction', 'derivation', or 'all', got '{stage}'", + err=True, + ) + raise typer.Exit(1) # Validate phase is appropriate for stage extraction_phases = {"classify", "parse"} derivation_phases = {"prep", "generate", "refine"} if phase: if stage == "extraction" and phase not in extraction_phases: - print(f"Error: Phase '{phase}' is not valid for extraction.") - print(f"Valid extraction phases: {', '.join(sorted(extraction_phases))}") - return 1 + typer.echo(f"Error: Phase '{phase}' is not valid for extraction.", err=True) + typer.echo( + f"Valid extraction phases: {', '.join(sorted(extraction_phases))}" + ) + raise typer.Exit(1) if stage == "derivation" and phase not in derivation_phases: - print(f"Error: Phase '{phase}' is not valid for derivation.") - print(f"Valid derivation phases: {', '.join(sorted(derivation_phases))}") - return 1 + typer.echo(f"Error: Phase '{phase}' is not valid for derivation.", err=True) + typer.echo( + f"Valid derivation phases: {', '.join(sorted(derivation_phases))}" + ) + raise typer.Exit(1) - print(f"\n{'=' * 60}") - print(f"DERIVA - Running {stage.upper()} pipeline") - print(f"{'=' * 60}") + typer.echo(f"\n{'=' * 60}") + typer.echo(f"DERIVA - Running {stage.upper()} pipeline") + typer.echo(f"{'=' * 60}") - if repo_name: - print(f"Repository: {repo_name}") + if repo: + typer.echo(f"Repository: {repo}") if phase: - print(f"Phase: {phase}") + typer.echo(f"Phase: {phase}") with PipelineSession() as session: - print("Connected to Neo4j") + typer.echo("Connected to Neo4j") # Show LLM status llm_info = session.llm_info if llm_info and not no_llm: - print(f"LLM configured: {llm_info['provider']}/{llm_info['model']}") + typer.echo(f"LLM configured: {llm_info['provider']}/{llm_info['model']}") elif no_llm: - print("LLM disabled (--no-llm)") + typer.echo("LLM disabled (--no-llm)") else: - print("Warning: LLM not configured. LLM-based steps will be skipped.") + typer.echo("Warning: LLM not configured. LLM-based steps will be skipped.") - # Create progress reporter (Rich-based if available, quiet if --quiet) + # Create progress reporter progress_reporter = create_progress_reporter(quiet=quiet or verbose) if stage == "extraction": - # Convert phase to phases list for extraction phases = [phase] if phase else None with progress_reporter: result = session.run_extraction( - repo_name=repo_name, + repo_name=repo, verbose=verbose, no_llm=no_llm, progress=progress_reporter, @@ -376,8 +140,11 @@ def cmd_run(args: argparse.Namespace) -> int: elif stage == "derivation": if not llm_info: - print("Error: Derivation requires LLM. Configure LLM in .env file.") - return 1 + typer.echo( + "Error: Derivation requires LLM. 
Configure LLM in .env file.", + err=True, + ) + raise typer.Exit(1) phases = [phase] if phase else None with progress_reporter: result = session.run_derivation( @@ -390,188 +157,116 @@ def cmd_run(args: argparse.Namespace) -> int: elif stage == "all": with progress_reporter: result = session.run_pipeline( - repo_name=repo_name, + repo_name=repo, verbose=verbose, progress=progress_reporter, ) _print_pipeline_result(result) - else: - print(f"Unknown stage: {stage}") - return 1 - - return 0 if result.get("success") else 1 - - -def _print_extraction_result(result: dict) -> None: - """Print extraction results.""" - print(f"\n{'-' * 60}") - print("EXTRACTION RESULTS") - print(f"{'-' * 60}") - stats = result.get("stats", {}) - print(f" Repos processed: {stats.get('repos_processed', 0)}") - print(f" Nodes created: {stats.get('nodes_created', 0)}") - print(f" Edges created: {stats.get('edges_created', 0)}") - print(f" Steps completed: {stats.get('steps_completed', 0)}") - print(f" Steps skipped: {stats.get('steps_skipped', 0)}") - - if result.get("warnings"): - print(f"\nWarnings ({len(result['warnings'])}):") - for warn in result["warnings"][:5]: - print(f" - {warn}") - if len(result["warnings"]) > 5: - print(f" ... and {len(result['warnings']) - 5} more") - - if result.get("errors"): - print(f"\nErrors ({len(result['errors'])}):") - for err in result["errors"][:5]: - print(f" - {err}") - if len(result["errors"]) > 5: - print(f" ... and {len(result['errors']) - 5} more") - - -def _print_derivation_result(result: dict) -> None: - """Print derivation results.""" - print(f"\n{'-' * 60}") - print("DERIVATION RESULTS") - print(f"{'-' * 60}") - stats = result.get("stats", {}) - print(f" Elements created: {stats.get('elements_created', 0)}") - print(f" Relationships created: {stats.get('relationships_created', 0)}") - print(f" Elements validated: {stats.get('elements_validated', 0)}") - print(f" Issues found: {stats.get('issues_found', 0)}") - print(f" Steps completed: {stats.get('steps_completed', 0)}") - - issues = result.get("issues", []) - if issues: - print(f"\nIssues ({len(issues)}):") - for issue in issues[:10]: - severity = issue.get("severity", "warning") - msg = issue.get("message", "") - print(f" [{severity.upper()}] {msg}") - if len(issues) > 10: - print(f" ... 
and {len(issues) - 10} more") - - if result.get("errors"): - print(f"\nErrors ({len(result['errors'])}):") - for err in result["errors"][:5]: - print(f" - {err}") - - -def _print_pipeline_result(result: dict) -> None: - """Print full pipeline results.""" - print(f"\n{'=' * 60}") - print("PIPELINE COMPLETE") - print(f"{'=' * 60}") - - results = result.get("results", {}) - - if results.get("classification"): - stats = results["classification"].get("stats", {}) - print("\nClassification:") - print(f" Files classified: {stats.get('files_classified', 0)}") - print(f" Files undefined: {stats.get('files_undefined', 0)}") - - if results.get("extraction"): - stats = results["extraction"].get("stats", {}) - print("\nExtraction:") - print(f" Nodes created: {stats.get('nodes_created', 0)}") - - if results.get("derivation"): - stats = results["derivation"].get("stats", {}) - print("\nDerivation:") - print(f" Elements created: {stats.get('elements_created', 0)}") - print(f" Issues found: {stats.get('issues_found', 0)}") - - if result.get("errors"): - print(f"\nTotal errors: {len(result['errors'])}") + if not result.get("success"): + raise typer.Exit(1) # ============================================================================= -# Status Commands +# Status Command # ============================================================================= -def cmd_status(args: argparse.Namespace) -> int: +@app.command("status") +def status() -> None: """Show current pipeline status.""" with PipelineSession() as session: - print("\nDERIVA STATUS") - print("=" * 60) + typer.echo("\nDERIVA STATUS") + typer.echo("=" * 60) # Count enabled steps per type for step_type in ["extraction", "derivation"]: all_steps = session.list_steps(step_type) enabled = [s for s in all_steps if s["enabled"]] - print( + typer.echo( f" {step_type.capitalize()}: {len(enabled)}/{len(all_steps)} steps enabled" ) # File types file_types = session.get_file_types() - print(f" File Types: {len(file_types)} registered") + typer.echo(f" File Types: {len(file_types)} registered") # Graph stats try: graph_stats = session.get_graph_stats() - print(f" Graph Nodes: {graph_stats['total_nodes']}") + typer.echo(f" Graph Nodes: {graph_stats['total_nodes']}") except Exception: - print(" Graph Nodes: (not connected)") + typer.echo(" Graph Nodes: (not connected)") # ArchiMate stats try: archimate_stats = session.get_archimate_stats() - print(f" ArchiMate Elements: {archimate_stats['total_elements']}") + typer.echo(f" ArchiMate Elements: {archimate_stats['total_elements']}") except Exception: - print(" ArchiMate Elements: (not connected)") + typer.echo(" ArchiMate Elements: (not connected)") - print() - return 0 + typer.echo("") # ============================================================================= -# Export Commands +# Export Command # ============================================================================= -def cmd_export(args: argparse.Namespace) -> int: +@app.command("export") +def export( + output: Annotated[ + str, typer.Option("-o", "--output", help="Output file path") + ] = "workspace/output/model.xml", + name: Annotated[str | None, typer.Option("-n", "--name", help="Model name")] = None, + verbose: Annotated[ + bool, typer.Option("-v", "--verbose", help="Print detailed progress") + ] = False, +) -> None: """Export ArchiMate model to file.""" - output_path = args.output - model_name = args.name or "Deriva Model" - verbose = getattr(args, "verbose", False) + model_name = name or "Deriva Model" - print(f"\n{'=' * 60}") - print("DERIVA - 
Exporting ArchiMate Model") - print(f"{'=' * 60}") + typer.echo(f"\n{'=' * 60}") + typer.echo("DERIVA - Exporting ArchiMate Model") + typer.echo(f"{'=' * 60}") with PipelineSession() as session: if verbose: - print("Connected to Neo4j") + typer.echo("Connected to Neo4j") - result = session.export_model(output_path=output_path, model_name=model_name) + result = session.export_model(output_path=output, model_name=model_name) if result["success"]: - print(f" Elements exported: {result['elements_exported']}") - print(f" Relationships exported: {result['relationships_exported']}") - print(f"\nExported to: {result['output_path']}") - print("Model can be opened with Archi or other ArchiMate-compatible tools.") - return 0 + typer.echo(f" Elements exported: {result['elements_exported']}") + typer.echo(f" Relationships exported: {result['relationships_exported']}") + typer.echo(f"\nExported to: {result['output_path']}") + typer.echo( + "Model can be opened with Archi or other ArchiMate-compatible tools." + ) else: - print(f"Error: {result.get('error', 'Unknown error')}") - return 1 + typer.echo(f"Error: {result.get('error', 'Unknown error')}", err=True) + raise typer.Exit(1) # ============================================================================= -# Clear Commands +# Clear Command # ============================================================================= -def cmd_clear(args: argparse.Namespace) -> int: +@app.command("clear") +def clear( + target: Annotated[str, typer.Argument(help="Data layer to clear (graph, model)")], +) -> None: """Clear graph or model data.""" - target = args.target + if target not in ("graph", "model"): + typer.echo( + f"Error: target must be 'graph' or 'model', got '{target}'", err=True + ) + raise typer.Exit(1) - print(f"\n{'=' * 60}") - print(f"DERIVA - Clearing {target.upper()}") - print(f"{'=' * 60}") + typer.echo(f"\n{'=' * 60}") + typer.echo(f"DERIVA - Clearing {target.upper()}") + typer.echo(f"{'=' * 60}") with PipelineSession() as session: if target == "graph": @@ -579,590 +274,14 @@ def cmd_clear(args: argparse.Namespace) -> int: elif target == "model": result = session.clear_model() else: - print(f"Unknown clear target: {target}") - return 1 + typer.echo(f"Unknown clear target: {target}", err=True) + raise typer.Exit(1) if result.get("success"): - print(result.get("message", "Done")) - return 0 - else: - print(f"Error: {result.get('error', 'Unknown error')}") - return 1 - - -# ============================================================================= -# Repository Commands -# ============================================================================= - - -def cmd_repo_clone(args: argparse.Namespace) -> int: - """Clone a repository.""" - url = args.url - name = getattr(args, "name", None) - branch = getattr(args, "branch", None) - overwrite = getattr(args, "overwrite", False) - - print(f"\n{'=' * 60}") - print("DERIVA - Cloning Repository") - print(f"{'=' * 60}") - print(f"URL: {url}") - if name: - print(f"Name: {name}") - if branch: - print(f"Branch: {branch}") - - with PipelineSession() as session: - result = session.clone_repository( - url=url, name=name, branch=branch, overwrite=overwrite - ) - if result.get("success"): - print("\nRepository cloned successfully!") - print(f" Name: {result.get('name', 'N/A')}") - print(f" Path: {result.get('path', 'N/A')}") - print(f" URL: {result.get('url', url)}") - return 0 - else: - print(f"\nError: {result.get('error', 'Unknown error')}") - return 1 - - -def cmd_repo_list(args: argparse.Namespace) -> int: - 
"""List all repositories.""" - detailed = getattr(args, "detailed", False) - - with PipelineSession() as session: - repos = session.get_repositories(detailed=detailed) - - if not repos: - print("\nNo repositories found.") - print(f"Workspace: {session.workspace_dir}") - print("\nClone a repository with:") - print(" deriva repo clone ") - return 0 - - print(f"\n{'=' * 60}") - print("REPOSITORIES") - print(f"{'=' * 60}") - print(f"Workspace: {session.workspace_dir}\n") - - for repo in repos: - if detailed: - dirty = " (dirty)" if repo.get("is_dirty") else "" - print(f" {repo['name']}{dirty}") - print(f" URL: {repo.get('url', 'N/A')}") - print(f" Branch: {repo.get('branch', 'N/A')}") - print(f" Size: {repo.get('size_mb', 0):.2f} MB") - print(f" Cloned: {repo.get('cloned_at', 'N/A')}") - print() - else: - print(f" {repo['name']}") - - print(f"\nTotal: {len(repos)} repositories") - return 0 - - -def cmd_repo_delete(args: argparse.Namespace) -> int: - """Delete a repository.""" - name = args.name - force = getattr(args, "force", False) - - print(f"\n{'=' * 60}") - print("DERIVA - Deleting Repository") - print(f"{'=' * 60}") - print(f"Repository: {name}") - - with PipelineSession() as session: - try: - result = session.delete_repository(name=name, force=force) - if result.get("success"): - print(f"\nRepository '{name}' deleted successfully.") - return 0 - else: - print(f"\nError: {result.get('error', 'Unknown error')}") - return 1 - except Exception as e: - print(f"\nError: {e}") - if "uncommitted changes" in str(e).lower(): - print("Use --force to delete anyway.") - return 1 - - -def cmd_repo_info(args: argparse.Namespace) -> int: - """Show repository details.""" - name = args.name - - with PipelineSession() as session: - try: - info = session.get_repository_info(name) - - if not info: - print(f"\nRepository '{name}' not found.") - return 1 - - print(f"\n{'=' * 60}") - print(f"REPOSITORY: {info['name']}") - print(f"{'=' * 60}") - print(f" Path: {info.get('path', 'N/A')}") - print(f" URL: {info.get('url', 'N/A')}") - print(f" Branch: {info.get('branch', 'N/A')}") - print(f" Last Commit: {info.get('last_commit', 'N/A')}") - print(f" Dirty: {info.get('is_dirty', False)}") - print(f" Size: {info.get('size_mb', 0):.2f} MB") - print(f" Cloned At: {info.get('cloned_at', 'N/A')}") - print() - return 0 - except Exception as e: - print(f"\nError: {e}") - return 1 - - -# ============================================================================= -# Consistency Commands -# ============================================================================= - - -# ============================================================================= -# Benchmark Commands -# ============================================================================= - - -def cmd_benchmark_run(args: argparse.Namespace) -> int: - """Run benchmark matrix.""" - repos = [r.strip() for r in args.repos.split(",")] - models = [m.strip() for m in args.models.split(",")] - runs = getattr(args, "runs", 3) - stages = [s.strip() for s in args.stages.split(",")] if args.stages else None - description = getattr(args, "description", "") - verbose = getattr(args, "verbose", False) - quiet = getattr(args, "quiet", False) - use_cache = not getattr(args, "no_cache", False) - export_models = not getattr(args, "no_export_models", False) - clear_between_runs = not getattr(args, "no_clear", False) - bench_hash = getattr(args, "bench_hash", False) - defer_relationships = getattr(args, "defer_relationships", False) - per_repo = getattr(args, "per_repo", False) - 
nocache_configs_str = getattr(args, "nocache_configs", None) - nocache_configs = ( - [c.strip() for c in nocache_configs_str.split(",")] - if nocache_configs_str - else None - ) - - # Calculate total runs based on mode - if per_repo: - total_runs = len(repos) * len(models) * runs + typer.echo(result.get("message", "Done")) else: - total_runs = len(models) * runs - - print(f"\n{'=' * 60}") - print("DERIVA - Multi-Model Benchmark") - print(f"{'=' * 60}") - print(f"Repositories: {repos}") - print(f"Models: {models}") - print(f"Runs per combination: {runs}") - print(f"Mode: {'per-repo' if per_repo else 'combined'}") - print(f"Total runs: {total_runs}") - if stages: - print(f"Stages: {stages}") - print(f"Cache: {'enabled' if use_cache else 'disabled'}") - print(f"Export models: {'enabled' if export_models else 'disabled'}") - print(f"Clear between runs: {'yes' if clear_between_runs else 'no'}") - if bench_hash: - print("Bench hash: enabled (per-run cache isolation)") - if defer_relationships: - print("Defer relationships: enabled (two-phase derivation)") - if nocache_configs: - print(f"No-cache configs: {nocache_configs}") - print(f"{'=' * 60}\n") - - with PipelineSession() as session: - print("Connected to Neo4j") - - # Create benchmark progress reporter (Rich-based if available) - progress_reporter = create_benchmark_progress_reporter(quiet=quiet or verbose) - - with progress_reporter: - result = session.run_benchmark( - repositories=repos, - models=models, - runs=runs, - stages=stages, - description=description, - verbose=verbose, - use_cache=use_cache, - nocache_configs=nocache_configs, - progress=progress_reporter, - export_models=export_models, - clear_between_runs=clear_between_runs, - bench_hash=bench_hash, - defer_relationships=defer_relationships, - per_repo=per_repo, - ) - - print(f"\n{'=' * 60}") - print("BENCHMARK COMPLETE") - print(f"{'=' * 60}") - print(f"Session ID: {result.session_id}") - print(f"Runs completed: {result.runs_completed}") - print(f"Runs failed: {result.runs_failed}") - print(f"Duration: {result.duration_seconds:.1f}s") - print(f"OCEL log: {result.ocel_path}") - if export_models: - print(f"Model files: workspace/benchmarks/{result.session_id}/models/") - - if result.errors: - print(f"\nErrors ({len(result.errors)}):") - for err in result.errors[:5]: - print(f" - {err}") - if len(result.errors) > 5: - print(f" ... 
and {len(result.errors) - 5} more") - - print("\nTo analyze results:") - print(f" deriva benchmark analyze {result.session_id}") - - return 0 if result.success else 1 - - -def cmd_benchmark_list(args: argparse.Namespace) -> int: - """List benchmark sessions.""" - limit = getattr(args, "limit", 10) - - with PipelineSession() as session: - sessions = session.list_benchmarks(limit=limit) - - if not sessions: - print("No benchmark sessions found.") - return 0 - - print(f"\n{'=' * 60}") - print("BENCHMARK SESSIONS") - print(f"{'=' * 60}") - - for s in sessions: - status_icon = ( - "" - if s["status"] == "completed" - else "" - if s["status"] == "failed" - else "" - ) - print(f"\n{status_icon} {s['session_id']}") - print(f" Status: {s['status']}") - print(f" Started: {s['started_at']}") - if s.get("description"): - print(f" Description: {s['description']}") - - print() - return 0 - - -def _get_run_stats_from_ocel(analyzer: Any) -> dict[str, list[tuple[int, int]]]: - """Extract node/edge counts per run from OCEL CompleteRun events.""" - result: dict[str, list[tuple[int, int]]] = {} - - for event in analyzer.ocel_log.events: - if event.activity != "CompleteRun": - continue - - model = event.objects.get("Model", [None])[0] - if not model: - continue - - stats = event.attributes.get("stats", {}) - extraction = stats.get("extraction", {}) - nodes = extraction.get("nodes_created", 0) - edges = extraction.get("edges_created", 0) - - if model not in result: - result[model] = [] - result[model].append((nodes, edges)) - - return result - - -def cmd_benchmark_analyze(args: argparse.Namespace) -> int: - """Analyze benchmark results.""" - session_id = args.session_id - output = getattr(args, "output", None) - fmt = getattr(args, "format", "json") - - print(f"\n{'=' * 60}") - print(f"ANALYZING BENCHMARK: {session_id}") - print(f"{'=' * 60}\n") - - with PipelineSession() as session: - try: - analyzer = session.analyze_benchmark(session_id) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Compute and display summary - summary = analyzer.compute_full_analysis() - - # Get raw node/edge counts per run from OCEL - run_stats = _get_run_stats_from_ocel(analyzer) - - # Show extraction stats per model (from run stats) - if run_stats: - print("INTRA-MODEL CONSISTENCY (stability across runs)") - print("-" * 75) - print(f" {'Model':<22} {'Nodes':<25} {'Edges':<25}") - print( - f" {'':<22} {'Min-Max (Stable/Var)':<25} {'Min-Max (Stable/Var)':<25}" - ) - print("-" * 75) - for model, runs in sorted(run_stats.items()): - node_vals = [n for n, e in runs] - edge_vals = [e for n, e in runs] - node_min, node_max = min(node_vals), max(node_vals) - edge_min, edge_max = min(edge_vals), max(edge_vals) - node_var = node_max - node_min - edge_var = edge_max - edge_min - # Stable = minimum (guaranteed in all runs), Unstable = variance - node_str = f"{node_min}-{node_max} ({node_min}/{node_var})" - edge_str = f"{edge_min}-{edge_max} ({edge_min}/{edge_var})" - print(f" {model:<22} {node_str:<25} {edge_str:<25}") - print() - - # Intra-model consistency (structural edges from OCEL) - if summary.intra_model: - print("STRUCTURAL EDGE CONSISTENCY (OCEL tracked)") - print("-" * 70) - print(f" {'Model':<22} {'Stable':<10} {'Unstable':<10} {'%':<8}") - print("-" * 70) - for m in summary.intra_model: - stable_edges = len(m.stable_edges) - unstable_edges = len(m.unstable_edges) - total_edges = stable_edges + unstable_edges - edge_pct = ( - (stable_edges / total_edges * 100) if total_edges > 0 else 100 - ) - print( - f" 
{m.model:<22} {stable_edges:<10} {unstable_edges:<10} {edge_pct:.0f}%" - ) - print() - - # Inter-model consistency - if summary.inter_model: - print("INTER-MODEL CONSISTENCY (agreement across models)") - print("-" * 70) - for im in summary.inter_model: - edges_sets = [set(e) for e in im.edges_by_model.values()] - total_edges = len(set().union(*edges_sets)) if edges_sets else 0 - overlap_edges = len(im.edge_overlap) - pct = im.edge_jaccard * 100 - print(f" {im.repository}:") - print( - f" Structural edges: {overlap_edges}/{total_edges} stable ({pct:.0f}%)" - ) - print() - - # Hotspots - if summary.localization.hotspots: - print("INCONSISTENCY HOTSPOTS") - print("-" * 50) - for h in summary.localization.hotspots: - print( - f" [{h['severity'].upper()}] {h['type']}: {h['name']} ({h['consistency']:.1f}%)" - ) - print() - - # Export - output_path = analyzer.export_summary(output, format=fmt) - print(f"Analysis exported to: {output_path}") - - # Save to DB - analyzer.save_metrics_to_db() - print("Metrics saved to database.") - - return 0 - - -def cmd_benchmark_models(args: argparse.Namespace) -> int: - """List available model configurations.""" - with PipelineSession() as session: - models = session.list_benchmark_models() - - if not models: - print("\nNo benchmark model configurations found.") - print("\nConfigure models in .env with pattern:") - print(" LLM_BENCH_{NAME}_PROVIDER=azure|openai|anthropic|ollama") - print(" LLM_BENCH_{NAME}_MODEL=model-id") - print(" LLM_BENCH_{NAME}_URL=api-url (optional)") - print(" LLM_BENCH_{NAME}_KEY=api-key (optional)") - return 0 - - print(f"\n{'=' * 60}") - print("AVAILABLE BENCHMARK MODELS") - print(f"{'=' * 60}\n") - - for name, cfg in sorted(models.items()): - print(f" {name}") - print(f" Provider: {cfg.provider}") - print(f" Model: {cfg.model}") - if cfg.api_url: - print(f" URL: {cfg.api_url[:50]}...") - print() - - return 0 - - -def cmd_benchmark_deviations(args: argparse.Namespace) -> int: - """Analyze config deviations for a benchmark session.""" - session_id = args.session_id - output = getattr(args, "output", None) - sort_by = getattr(args, "sort_by", "deviation_count") - - print(f"\n{'=' * 60}") - print(f"CONFIG DEVIATION ANALYSIS: {session_id}") - print(f"{'=' * 60}\n") - - with PipelineSession() as session: - try: - analyzer = session.analyze_config_deviations(session_id) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Run analysis - report = analyzer.analyze() - - # Display summary - print(f"Total runs analyzed: {report.total_runs}") - print(f"Total deviations: {report.total_deviations}") - print(f"Overall consistency: {report.overall_consistency:.1%}\n") - - if report.config_deviations: - print("CONFIG DEVIATIONS (sorted by deviation count)") - print("-" * 60) - - for cd in report.config_deviations: - status = ( - "LOW" - if cd.consistency_score >= 0.8 - else "MEDIUM" - if cd.consistency_score >= 0.5 - else "HIGH" - ) - print(f" [{status}] {cd.config_type}: {cd.config_id}") - print(f" Consistency: {cd.consistency_score:.1%}") - print(f" Deviations: {cd.deviation_count}/{cd.total_objects}") - if cd.deviating_objects[:3]: - print(f" Sample: {', '.join(cd.deviating_objects[:3])}") - print() - - # Get recommendations - from deriva.modules.analysis import generate_recommendations - - recommendations = generate_recommendations(report.config_deviations) - if recommendations: - print("RECOMMENDATIONS") - print("-" * 60) - for rec in recommendations: - print(f" • {rec}") - print() - - # Export - if sort_by != "deviation_count": - 
output_path = analyzer.export_sorted_json(output, sort_by=sort_by) - else: - output_path = analyzer.export_json(output) - - print(f"Report exported to: {output_path}") - - return 0 - - -def cmd_benchmark_comprehensive(args: argparse.Namespace) -> int: - """Run comprehensive benchmark analysis.""" - session_ids = args.session_ids - output = getattr(args, "output", "workspace/analysis") - format_type = getattr(args, "format", "both") - _include_semantic = not getattr( - args, "no_semantic", False - ) # Reserved for future use - - print(f"\n{'=' * 60}") - print("BENCHMARK ANALYSIS") - print(f"{'=' * 60}\n") - print(f"Sessions: {', '.join(session_ids)}") - - from deriva.services.analysis import BenchmarkAnalyzer - - with PipelineSession() as session: - try: - analyzer = BenchmarkAnalyzer( - session_ids=list(session_ids), - engine=session._engine, - ) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Generate report - print("\nRunning analysis...") - report = analyzer.generate_report() - - # Display summary - print(f"\nRepositories: {', '.join(report.repositories)}") - print(f"Models: {', '.join(report.models)}") - print("\nOVERALL METRICS") - print("-" * 40) - print(f" Consistency: {report.overall_consistency:.1%}") - print(f" Precision: {report.overall_precision:.1%}") - print(f" Recall: {report.overall_recall:.1%}") - - # Show per-repo summary - if report.stability_reports: - print("\nPER-REPOSITORY STABILITY") - print("-" * 40) - for repo, phases in report.stability_reports.items(): - if "derivation" in phases: - print( - f" {repo}: {phases['derivation'].overall_consistency:.1%} derivation consistency" - ) - - # Show semantic match summary - if report.semantic_reports: - print("\nSEMANTIC MATCH SUMMARY") - print("-" * 40) - for repo, sr in report.semantic_reports.items(): - print( - f" {repo}: P={sr.element_precision:.1%} R={sr.element_recall:.1%} F1={sr.element_f1:.2f}" - ) - - # Show best/worst types from cross-repo - if report.cross_repo: - if report.cross_repo.best_element_types: - print("\nBEST ELEMENT TYPES (highest consistency)") - print("-" * 40) - for t, score in report.cross_repo.best_element_types[:5]: - print(f" {t}: {score:.1%}") - - if report.cross_repo.worst_element_types: - print("\nWORST ELEMENT TYPES (lowest consistency)") - print("-" * 40) - for t, score in report.cross_repo.worst_element_types[:5]: - print(f" {t}: {score:.1%}") - - # Show recommendations - if report.recommendations: - print("\nRECOMMENDATIONS") - print("-" * 40) - for rec in report.recommendations[:10]: - print(f" • {rec}") - - # Export - print(f"\nExporting to: {output}") - paths = analyzer.export_all(output) - - if format_type in ("json", "both"): - print(f" JSON: {paths.get('json', 'N/A')}") - if format_type in ("markdown", "both"): - print(f" Markdown: {paths.get('markdown', 'N/A')}") - - return 0 + typer.echo(f"Error: {result.get('error', 'Unknown error')}", err=True) + raise typer.Exit(1) # ============================================================================= @@ -1170,563 +289,13 @@ def cmd_benchmark_comprehensive(args: argparse.Namespace) -> int: # ============================================================================= -def create_parser() -> argparse.ArgumentParser: - """Create the argument parser.""" - parser = argparse.ArgumentParser( - prog="deriva", - description="Deriva CLI - Generate ArchiMate models from code repositories", - ) - - subparsers = parser.add_subparsers(dest="command", help="Available commands") - - # 
------------------------------------------------------------------------- - # config command - # ------------------------------------------------------------------------- - config_parser = subparsers.add_parser( - "config", help="Manage pipeline configurations" - ) - config_subparsers = config_parser.add_subparsers( - dest="config_action", help="Config actions" - ) - - # config list - config_list = config_subparsers.add_parser("list", help="List configurations") - config_list.add_argument( - "step_type", - choices=["extraction", "derivation"], - help="Type of configuration to list", - ) - config_list.add_argument( - "--enabled", - action="store_true", - help="Only show enabled configurations", - ) - config_list.add_argument( - "--phase", - choices=["prep", "generate", "refine"], - help="Filter derivation by phase", - ) - config_list.set_defaults(func=cmd_config_list) - - # config show - config_show = config_subparsers.add_parser( - "show", help="Show configuration details" - ) - config_show.add_argument( - "step_type", - choices=["extraction", "derivation"], - help="Type of configuration", - ) - config_show.add_argument( - "name", help="Name of the configuration (node_type or step_name)" - ) - config_show.set_defaults(func=cmd_config_show) - - # config enable - config_enable = config_subparsers.add_parser( - "enable", help="Enable a configuration" - ) - config_enable.add_argument( - "step_type", - choices=["extraction", "derivation"], - help="Type of configuration", - ) - config_enable.add_argument("name", help="Name to enable") - config_enable.set_defaults(func=cmd_config_enable) - - # config disable - config_disable = config_subparsers.add_parser( - "disable", help="Disable a configuration" - ) - config_disable.add_argument( - "step_type", - choices=["extraction", "derivation"], - help="Type of configuration", - ) - config_disable.add_argument("name", help="Name to disable") - config_disable.set_defaults(func=cmd_config_disable) - - # config update (versioned) - config_update = config_subparsers.add_parser( - "update", help="Update configuration (creates new version)" - ) - config_update.add_argument( - "step_type", - choices=["extraction", "derivation"], - help="Type of configuration to update", - ) - config_update.add_argument("name", help="Name of the configuration to update") - config_update.add_argument( - "-i", - "--instruction", - type=str, - default=None, - help="New instruction text", - ) - config_update.add_argument( - "-e", - "--example", - type=str, - default=None, - help="New example text", - ) - config_update.add_argument( - "--instruction-file", - type=str, - default=None, - help="Read instruction from file", - ) - config_update.add_argument( - "--example-file", - type=str, - default=None, - help="Read example from file", - ) - config_update.add_argument( - "-q", - "--query", - type=str, - default=None, - help="New input_graph_query (derivation only)", - ) - config_update.add_argument( - "-s", - "--sources", - type=str, - default=None, - help="New input_sources (extraction only)", - ) - config_update.add_argument( - "-p", - "--params", - type=str, - default=None, - help="New params JSON (derivation only)", - ) - config_update.add_argument( - "--params-file", - type=str, - default=None, - help="Read params JSON from file (derivation only)", - ) - config_update.set_defaults(func=cmd_config_update) - - # config versions - config_versions = config_subparsers.add_parser( - "versions", help="Show active config versions" - ) - 
config_versions.set_defaults(func=cmd_config_versions) - - # config filetype (sub-subparser) - config_filetype = config_subparsers.add_parser( - "filetype", help="Manage file type registry" - ) - filetype_subparsers = config_filetype.add_subparsers( - dest="filetype_action", help="File type actions" - ) - - # filetype list - filetype_list = filetype_subparsers.add_parser( - "list", help="List registered file types" - ) - filetype_list.set_defaults(func=cmd_filetype_list) - - # filetype add - filetype_add = filetype_subparsers.add_parser("add", help="Add a file type") - filetype_add.add_argument( - "extension", help="File extension (e.g., '.py', 'Dockerfile', '*.test.js')" - ) - filetype_add.add_argument( - "file_type", - help="File type category (source, config, docs, data, dependency, test, template, unknown)", - ) - filetype_add.add_argument("subtype", help="Subtype (e.g., 'python', 'javascript')") - filetype_add.set_defaults(func=cmd_filetype_add) - - # filetype delete - filetype_delete = filetype_subparsers.add_parser( - "delete", help="Delete a file type" - ) - filetype_delete.add_argument("extension", help="Extension to delete") - filetype_delete.set_defaults(func=cmd_filetype_delete) - - # filetype stats - filetype_stats = filetype_subparsers.add_parser( - "stats", help="Show file type statistics" - ) - filetype_stats.set_defaults(func=cmd_filetype_stats) - - # ------------------------------------------------------------------------- - # run command - # ------------------------------------------------------------------------- - run_parser = subparsers.add_parser("run", help="Run pipeline stages") - run_parser.add_argument( - "stage", - choices=["extraction", "derivation", "all"], - help="Pipeline stage to run", - ) - run_parser.add_argument( - "--repo", - type=str, - default=None, - help="Specific repository to process (default: all repos)", - ) - run_parser.add_argument( - "--phase", - choices=["classify", "parse", "prep", "generate", "refine"], - help="Run specific phase: extraction (classify, parse) or derivation (prep, generate, refine)", - ) - run_parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Print detailed progress (disables progress bar)", - ) - run_parser.add_argument( - "-q", - "--quiet", - action="store_true", - help="Disable progress bar display", - ) - run_parser.add_argument( - "--no-llm", - action="store_true", - help="Skip LLM-based steps (structural extraction only)", - ) - run_parser.set_defaults(func=cmd_run) - - # ------------------------------------------------------------------------- - # status command - # ------------------------------------------------------------------------- - status_parser = subparsers.add_parser("status", help="Show pipeline status") - status_parser.set_defaults(func=cmd_status) - - # ------------------------------------------------------------------------- - # export command - # ------------------------------------------------------------------------- - export_parser = subparsers.add_parser( - "export", help="Export ArchiMate model to file" - ) - export_parser.add_argument( - "-o", - "--output", - type=str, - default="workspace/output/model.xml", - help="Output file path (default: workspace/output/model.xml)", - ) - export_parser.add_argument( - "-n", - "--name", - type=str, - default=None, - help="Model name (default: Deriva Model)", - ) - export_parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Print detailed progress", - ) - export_parser.set_defaults(func=cmd_export) - - # 
------------------------------------------------------------------------- - # clear command - # ------------------------------------------------------------------------- - clear_parser = subparsers.add_parser("clear", help="Clear graph or model data") - clear_parser.add_argument( - "target", - choices=["graph", "model"], - help="Data layer to clear", - ) - clear_parser.set_defaults(func=cmd_clear) - - # ------------------------------------------------------------------------- - # repo command - # ------------------------------------------------------------------------- - repo_parser = subparsers.add_parser("repo", help="Manage repositories") - repo_subparsers = repo_parser.add_subparsers( - dest="repo_action", help="Repository actions" - ) - - # repo clone - repo_clone = repo_subparsers.add_parser("clone", help="Clone a repository") - repo_clone.add_argument("url", help="Repository URL to clone") - repo_clone.add_argument( - "-n", - "--name", - type=str, - default=None, - help="Custom name for the repository (default: derived from URL)", - ) - repo_clone.add_argument( - "-b", - "--branch", - type=str, - default=None, - help="Branch to clone (default: default branch)", - ) - repo_clone.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing repository if it exists", - ) - repo_clone.set_defaults(func=cmd_repo_clone) - - # repo list - repo_list = repo_subparsers.add_parser("list", help="List all repositories") - repo_list.add_argument( - "-d", - "--detailed", - action="store_true", - help="Show detailed information", - ) - repo_list.set_defaults(func=cmd_repo_list) - - # repo delete - repo_delete = repo_subparsers.add_parser("delete", help="Delete a repository") - repo_delete.add_argument("name", help="Repository name to delete") - repo_delete.add_argument( - "-f", - "--force", - action="store_true", - help="Force delete even with uncommitted changes", - ) - repo_delete.set_defaults(func=cmd_repo_delete) - - # repo info - repo_info = repo_subparsers.add_parser("info", help="Show repository details") - repo_info.add_argument("name", help="Repository name") - repo_info.set_defaults(func=cmd_repo_info) - - # ------------------------------------------------------------------------- - # benchmark command - # ------------------------------------------------------------------------- - benchmark_parser = subparsers.add_parser( - "benchmark", help="Multi-model benchmarking" - ) - benchmark_subparsers = benchmark_parser.add_subparsers( - dest="benchmark_action", help="Benchmark actions" - ) - - # benchmark run - benchmark_run = benchmark_subparsers.add_parser("run", help="Run benchmark matrix") - benchmark_run.add_argument( - "--repos", - type=str, - required=True, - help="Comma-separated list of repository names", - ) - benchmark_run.add_argument( - "--models", - type=str, - required=True, - help="Comma-separated list of model config names (from LLM_BENCH_* env vars)", - ) - benchmark_run.add_argument( - "-n", - "--runs", - type=int, - default=3, - help="Number of runs per (repo, model) combination (default: 3)", - ) - benchmark_run.add_argument( - "--stages", - type=str, - default=None, - help="Comma-separated list of stages (default: all)", - ) - benchmark_run.add_argument( - "-d", - "--description", - type=str, - default="", - help="Optional description for the benchmark session", - ) - benchmark_run.add_argument( - "-v", - "--verbose", - action="store_true", - help="Print detailed progress (disables progress bar)", - ) - benchmark_run.add_argument( - "-q", - "--quiet", - 
action="store_true", - help="Disable progress bar display", - ) - benchmark_run.add_argument( - "--no-cache", - action="store_true", - help="Disable LLM response caching globally", - ) - benchmark_run.add_argument( - "--nocache-configs", - type=str, - default=None, - help="Comma-separated list of config names to skip cache for (e.g., 'ApplicationComponent,DataObject')", - ) - benchmark_run.add_argument( - "--no-export-models", - action="store_true", - help="Disable exporting ArchiMate model files after each run", - ) - benchmark_run.add_argument( - "--no-clear", - action="store_true", - help="Don't clear graph/model between runs (keep existing data)", - ) - benchmark_run.add_argument( - "--bench-hash", - action="store_true", - help="Include repo/model/run in cache key for per-run isolation. Allows resuming failed runs with cache on.", - ) - benchmark_run.add_argument( - "--defer-relationships", - action="store_true", - help="Two-phase derivation: create all elements first, then derive relationships in one pass", - ) - benchmark_run.add_argument( - "--per-repo", - action="store_true", - help="Run each repository as a separate benchmark (default: combine all repos into one model)", - ) - benchmark_run.set_defaults(func=cmd_benchmark_run) - - # benchmark list - benchmark_list = benchmark_subparsers.add_parser( - "list", help="List benchmark sessions" - ) - benchmark_list.add_argument( - "-l", - "--limit", - type=int, - default=10, - help="Number of sessions to show (default: 10)", - ) - benchmark_list.set_defaults(func=cmd_benchmark_list) - - # benchmark analyze - benchmark_analyze = benchmark_subparsers.add_parser( - "analyze", help="Analyze benchmark results" - ) - benchmark_analyze.add_argument( - "session_id", - help="Benchmark session ID to analyze", - ) - benchmark_analyze.add_argument( - "-o", - "--output", - type=str, - default=None, - help="Output file for analysis (default: workspace/benchmarks/{session}/analysis/summary.json)", - ) - benchmark_analyze.add_argument( - "-f", - "--format", - choices=["json", "markdown"], - default="json", - help="Output format (default: json)", - ) - benchmark_analyze.set_defaults(func=cmd_benchmark_analyze) - - # benchmark models - benchmark_models = benchmark_subparsers.add_parser( - "models", help="List available model configs" - ) - benchmark_models.set_defaults(func=cmd_benchmark_models) - - # benchmark deviations - benchmark_deviations = benchmark_subparsers.add_parser( - "deviations", help="Analyze config deviations for a session" - ) - benchmark_deviations.add_argument( - "session_id", - help="Benchmark session ID to analyze", - ) - benchmark_deviations.add_argument( - "-o", - "--output", - type=str, - default=None, - help="Output file for deviation report (default: workspace/benchmarks/{session}/config_deviations.json)", - ) - benchmark_deviations.add_argument( - "-s", - "--sort-by", - choices=["deviation_count", "consistency_score", "total_objects"], - default="deviation_count", - help="Sort configs by this metric (default: deviation_count)", - ) - benchmark_deviations.set_defaults(func=cmd_benchmark_deviations) - - # benchmark comprehensive-analysis - benchmark_comprehensive = benchmark_subparsers.add_parser( - "comprehensive-analysis", - help="Run comprehensive analysis across multiple sessions", - ) - benchmark_comprehensive.add_argument( - "session_ids", - nargs="+", - help="Benchmark session IDs to analyze (one or more)", - ) - benchmark_comprehensive.add_argument( - "-o", - "--output", - type=str, - default="workspace/analysis", - 
help="Output directory for analysis files (default: workspace/analysis)", - ) - benchmark_comprehensive.add_argument( - "-f", - "--format", - choices=["json", "markdown", "both"], - default="both", - help="Output format (default: both)", - ) - benchmark_comprehensive.add_argument( - "--no-semantic", - action="store_true", - help="Skip semantic matching against reference models", - ) - benchmark_comprehensive.set_defaults(func=cmd_benchmark_comprehensive) - - return parser - - def main() -> int: """Main entry point.""" - parser = create_parser() - args = parser.parse_args() - - if not args.command: - parser.print_help() - return 0 - - if args.command == "config" and not args.config_action: - parser.parse_args(["config", "--help"]) + try: + app() return 0 - - if ( - args.command == "config" - and args.config_action == "filetype" - and not getattr(args, "filetype_action", None) - ): - parser.parse_args(["config", "filetype", "--help"]) - return 0 - - if args.command == "benchmark" and not getattr(args, "benchmark_action", None): - parser.parse_args(["benchmark", "--help"]) - return 0 - - if args.command == "repo" and not getattr(args, "repo_action", None): - parser.parse_args(["repo", "--help"]) - return 0 - - if hasattr(args, "func"): - return args.func(args) - - parser.print_help() - return 0 + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 if __name__ == "__main__": diff --git a/deriva/cli/commands/__init__.py b/deriva/cli/commands/__init__.py new file mode 100644 index 0000000..eed72ba --- /dev/null +++ b/deriva/cli/commands/__init__.py @@ -0,0 +1,7 @@ +""" +CLI command modules. + +Each module defines a typer subapp for its command group. +""" + +from __future__ import annotations diff --git a/deriva/cli/commands/benchmark.py b/deriva/cli/commands/benchmark.py new file mode 100644 index 0000000..4eada3c --- /dev/null +++ b/deriva/cli/commands/benchmark.py @@ -0,0 +1,510 @@ +""" +Benchmark CLI commands. + +Provides commands for multi-model benchmarking. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Any + +import typer + +from deriva.cli.progress import create_benchmark_progress_reporter +from deriva.services.session import PipelineSession + +if TYPE_CHECKING: + pass + +app = typer.Typer(name="benchmark", help="Multi-model benchmarking") + + +def _get_run_stats_from_ocel(analyzer: Any) -> dict[str, list[tuple[int, int]]]: + """Extract node/edge counts per run from OCEL CompleteRun events.""" + result: dict[str, list[tuple[int, int]]] = {} + + for event in analyzer.ocel_log.events: + if event.activity != "CompleteRun": + continue + + model = event.objects.get("Model", [None])[0] + if not model: + continue + + stats = event.attributes.get("stats", {}) + extraction = stats.get("extraction", {}) + nodes = extraction.get("nodes_created", 0) + edges = extraction.get("edges_created", 0) + + if model not in result: + result[model] = [] + result[model].append((nodes, edges)) + + return result + + +@app.command("run") +def benchmark_run( + repos: Annotated[ + str, typer.Option("--repos", help="Comma-separated list of repository names") + ], + models: Annotated[ + str, typer.Option("--models", help="Comma-separated list of model config names") + ], + runs: Annotated[ + int, typer.Option("-n", "--runs", help="Number of runs per combination") + ] = 3, + stages: Annotated[ + str | None, typer.Option("--stages", help="Comma-separated list of stages") + ] = None, + description: Annotated[ + str, typer.Option("-d", "--description", help="Session description") + ] = "", + verbose: Annotated[ + bool, typer.Option("-v", "--verbose", help="Print detailed progress") + ] = False, + quiet: Annotated[ + bool, typer.Option("-q", "--quiet", help="Disable progress bar") + ] = False, + no_cache: Annotated[ + bool, typer.Option("--no-cache", help="Disable LLM response caching") + ] = False, + nocache_configs: Annotated[ + str | None, typer.Option("--nocache-configs", help="Configs to skip cache for") + ] = None, + no_export_models: Annotated[ + bool, typer.Option("--no-export-models", help="Disable model export") + ] = False, + no_clear: Annotated[ + bool, typer.Option("--no-clear", help="Don't clear graph between runs") + ] = False, + bench_hash: Annotated[ + bool, typer.Option("--bench-hash", help="Per-run cache isolation") + ] = False, + defer_relationships: Annotated[ + bool, typer.Option("--defer-relationships", help="Two-phase derivation") + ] = False, + per_repo: Annotated[ + bool, typer.Option("--per-repo", help="Run each repo separately") + ] = False, +) -> None: + """Run benchmark matrix.""" + repos_list = [r.strip() for r in repos.split(",")] + models_list = [m.strip() for m in models.split(",")] + stages_list = [s.strip() for s in stages.split(",")] if stages else None + nocache_configs_list = ( + [c.strip() for c in nocache_configs.split(",")] if nocache_configs else None + ) + + use_cache = not no_cache + export_models = not no_export_models + clear_between_runs = not no_clear + + # Calculate total runs based on mode + if per_repo: + total_runs = len(repos_list) * len(models_list) * runs + else: + total_runs = len(models_list) * runs + + typer.echo(f"\n{'=' * 60}") + typer.echo("DERIVA - Multi-Model Benchmark") + typer.echo(f"{'=' * 60}") + typer.echo(f"Repositories: {repos_list}") + typer.echo(f"Models: {models_list}") + typer.echo(f"Runs per combination: {runs}") + typer.echo(f"Mode: {'per-repo' if per_repo else 'combined'}") + typer.echo(f"Total runs: {total_runs}") + if stages_list: + typer.echo(f"Stages: 
{stages_list}") + typer.echo(f"Cache: {'enabled' if use_cache else 'disabled'}") + typer.echo(f"Export models: {'enabled' if export_models else 'disabled'}") + typer.echo(f"Clear between runs: {'yes' if clear_between_runs else 'no'}") + if bench_hash: + typer.echo("Bench hash: enabled (per-run cache isolation)") + if defer_relationships: + typer.echo("Defer relationships: enabled (two-phase derivation)") + if nocache_configs_list: + typer.echo(f"No-cache configs: {nocache_configs_list}") + typer.echo(f"{'=' * 60}\n") + + with PipelineSession() as session: + typer.echo("Connected to Neo4j") + + progress_reporter = create_benchmark_progress_reporter(quiet=quiet or verbose) + + with progress_reporter: + result = session.run_benchmark( + repositories=repos_list, + models=models_list, + runs=runs, + stages=stages_list, + description=description, + verbose=verbose, + use_cache=use_cache, + nocache_configs=nocache_configs_list, + progress=progress_reporter, + export_models=export_models, + clear_between_runs=clear_between_runs, + bench_hash=bench_hash, + defer_relationships=defer_relationships, + per_repo=per_repo, + ) + + typer.echo(f"\n{'=' * 60}") + typer.echo("BENCHMARK COMPLETE") + typer.echo(f"{'=' * 60}") + typer.echo(f"Session ID: {result.session_id}") + typer.echo(f"Runs completed: {result.runs_completed}") + typer.echo(f"Runs failed: {result.runs_failed}") + typer.echo(f"Duration: {result.duration_seconds:.1f}s") + typer.echo(f"OCEL log: {result.ocel_path}") + if export_models: + typer.echo(f"Model files: workspace/benchmarks/{result.session_id}/models/") + + if result.errors: + typer.echo(f"\nErrors ({len(result.errors)}):") + for err in result.errors[:5]: + typer.echo(f" - {err}") + if len(result.errors) > 5: + typer.echo(f" ... and {len(result.errors) - 5} more") + + typer.echo("\nTo analyze results:") + typer.echo(f" deriva benchmark analyze {result.session_id}") + + if not result.success: + raise typer.Exit(1) + + +@app.command("list") +def benchmark_list( + limit: Annotated[ + int, typer.Option("-l", "--limit", help="Number of sessions to show") + ] = 10, +) -> None: + """List benchmark sessions.""" + with PipelineSession() as session: + sessions = session.list_benchmarks(limit=limit) + + if not sessions: + typer.echo("No benchmark sessions found.") + return + + typer.echo(f"\n{'=' * 60}") + typer.echo("BENCHMARK SESSIONS") + typer.echo(f"{'=' * 60}") + + for s in sessions: + status_icon = ( + "" + if s["status"] == "completed" + else "" + if s["status"] == "failed" + else "" + ) + typer.echo(f"\n{status_icon} {s['session_id']}") + typer.echo(f" Status: {s['status']}") + typer.echo(f" Started: {s['started_at']}") + if s.get("description"): + typer.echo(f" Description: {s['description']}") + + typer.echo("") + + +@app.command("analyze") +def benchmark_analyze( + session_id: Annotated[str, typer.Argument(help="Benchmark session ID to analyze")], + output: Annotated[ + str | None, typer.Option("-o", "--output", help="Output file for analysis") + ] = None, + format: Annotated[ + str, typer.Option("-f", "--format", help="Output format") + ] = "json", +) -> None: + """Analyze benchmark results.""" + if format not in ("json", "markdown"): + typer.echo("Error: format must be 'json' or 'markdown'", err=True) + raise typer.Exit(1) + + typer.echo(f"\n{'=' * 60}") + typer.echo(f"ANALYZING BENCHMARK: {session_id}") + typer.echo(f"{'=' * 60}\n") + + with PipelineSession() as session: + try: + analyzer = session.analyze_benchmark(session_id) + except ValueError as e: + typer.echo(f"Error: {e}", 
err=True) + raise typer.Exit(1) + + # Compute and display summary + summary = analyzer.compute_full_analysis() + + # Get raw node/edge counts per run from OCEL + run_stats = _get_run_stats_from_ocel(analyzer) + + # Show extraction stats per model (from run stats) + if run_stats: + typer.echo("INTRA-MODEL CONSISTENCY (stability across runs)") + typer.echo("-" * 75) + typer.echo(f" {'Model':<22} {'Nodes':<25} {'Edges':<25}") + typer.echo( + f" {'':<22} {'Min-Max (Stable/Var)':<25} {'Min-Max (Stable/Var)':<25}" + ) + typer.echo("-" * 75) + for model, run_list in sorted(run_stats.items()): + node_vals = [n for n, e in run_list] + edge_vals = [e for n, e in run_list] + node_min, node_max = min(node_vals), max(node_vals) + edge_min, edge_max = min(edge_vals), max(edge_vals) + node_var = node_max - node_min + edge_var = edge_max - edge_min + node_str = f"{node_min}-{node_max} ({node_min}/{node_var})" + edge_str = f"{edge_min}-{edge_max} ({edge_min}/{edge_var})" + typer.echo(f" {model:<22} {node_str:<25} {edge_str:<25}") + typer.echo("") + + # Intra-model consistency + if summary.intra_model: + typer.echo("STRUCTURAL EDGE CONSISTENCY (OCEL tracked)") + typer.echo("-" * 70) + typer.echo(f" {'Model':<22} {'Stable':<10} {'Unstable':<10} {'%':<8}") + typer.echo("-" * 70) + for m in summary.intra_model: + stable_edges = len(m.stable_edges) + unstable_edges = len(m.unstable_edges) + total_edges = stable_edges + unstable_edges + edge_pct = ( + (stable_edges / total_edges * 100) if total_edges > 0 else 100 + ) + typer.echo( + f" {m.model:<22} {stable_edges:<10} {unstable_edges:<10} {edge_pct:.0f}%" + ) + typer.echo("") + + # Inter-model consistency + if summary.inter_model: + typer.echo("INTER-MODEL CONSISTENCY (agreement across models)") + typer.echo("-" * 70) + for im in summary.inter_model: + edges_sets = [set(e) for e in im.edges_by_model.values()] + total_edges = len(set().union(*edges_sets)) if edges_sets else 0 + overlap_edges = len(im.edge_overlap) + pct = im.edge_jaccard * 100 + typer.echo(f" {im.repository}:") + typer.echo( + f" Structural edges: {overlap_edges}/{total_edges} stable ({pct:.0f}%)" + ) + typer.echo("") + + # Hotspots + if summary.localization.hotspots: + typer.echo("INCONSISTENCY HOTSPOTS") + typer.echo("-" * 50) + for h in summary.localization.hotspots: + typer.echo( + f" [{h['severity'].upper()}] {h['type']}: {h['name']} ({h['consistency']:.1f}%)" + ) + typer.echo("") + + # Export + output_path = analyzer.export_summary(output, format=format) + typer.echo(f"Analysis exported to: {output_path}") + + # Save to DB + analyzer.save_metrics_to_db() + typer.echo("Metrics saved to database.") + + +@app.command("models") +def benchmark_models() -> None: + """List available model configurations.""" + with PipelineSession() as session: + models_dict = session.list_benchmark_models() + + if not models_dict: + typer.echo("\nNo benchmark model configurations found.") + typer.echo("\nConfigure models in .env with pattern:") + typer.echo(" LLM_BENCH_{NAME}_PROVIDER=azure|openai|anthropic|ollama") + typer.echo(" LLM_BENCH_{NAME}_MODEL=model-id") + typer.echo(" LLM_BENCH_{NAME}_URL=api-url (optional)") + typer.echo(" LLM_BENCH_{NAME}_KEY=api-key (optional)") + return + + typer.echo(f"\n{'=' * 60}") + typer.echo("AVAILABLE BENCHMARK MODELS") + typer.echo(f"{'=' * 60}\n") + + for name, cfg in sorted(models_dict.items()): + typer.echo(f" {name}") + typer.echo(f" Provider: {cfg.provider}") + typer.echo(f" Model: {cfg.model}") + if cfg.api_url: + typer.echo(f" URL: {cfg.api_url[:50]}...") + 
typer.echo("") + + +@app.command("deviations") +def benchmark_deviations( + session_id: Annotated[str, typer.Argument(help="Benchmark session ID to analyze")], + output: Annotated[ + str | None, typer.Option("-o", "--output", help="Output file") + ] = None, + sort_by: Annotated[ + str, typer.Option("-s", "--sort-by", help="Sort metric") + ] = "deviation_count", +) -> None: + """Analyze config deviations for a benchmark session.""" + if sort_by not in ("deviation_count", "consistency_score", "total_objects"): + typer.echo( + "Error: sort-by must be one of: deviation_count, consistency_score, total_objects", + err=True, + ) + raise typer.Exit(1) + + typer.echo(f"\n{'=' * 60}") + typer.echo(f"CONFIG DEVIATION ANALYSIS: {session_id}") + typer.echo(f"{'=' * 60}\n") + + with PipelineSession() as session: + try: + analyzer = session.analyze_config_deviations(session_id) + except ValueError as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) + + report = analyzer.analyze() + + typer.echo(f"Total runs analyzed: {report.total_runs}") + typer.echo(f"Total deviations: {report.total_deviations}") + typer.echo(f"Overall consistency: {report.overall_consistency:.1%}\n") + + if report.config_deviations: + typer.echo("CONFIG DEVIATIONS (sorted by deviation count)") + typer.echo("-" * 60) + + for cd in report.config_deviations: + status = ( + "LOW" + if cd.consistency_score >= 0.8 + else "MEDIUM" + if cd.consistency_score >= 0.5 + else "HIGH" + ) + typer.echo(f" [{status}] {cd.config_type}: {cd.config_id}") + typer.echo(f" Consistency: {cd.consistency_score:.1%}") + typer.echo( + f" Deviations: {cd.deviation_count}/{cd.total_objects}" + ) + if cd.deviating_objects[:3]: + typer.echo(f" Sample: {', '.join(cd.deviating_objects[:3])}") + typer.echo("") + + from deriva.modules.analysis import generate_recommendations + + recommendations = generate_recommendations(report.config_deviations) + if recommendations: + typer.echo("RECOMMENDATIONS") + typer.echo("-" * 60) + for rec in recommendations: + typer.echo(f" - {rec}") + typer.echo("") + + if sort_by != "deviation_count": + output_path = analyzer.export_sorted_json(output, sort_by=sort_by) + else: + output_path = analyzer.export_json(output) + + typer.echo(f"Report exported to: {output_path}") + + +@app.command("comprehensive-analysis") +def benchmark_comprehensive( + session_ids: Annotated[ + list[str], typer.Argument(help="Benchmark session IDs to analyze") + ], + output: Annotated[ + str, typer.Option("-o", "--output", help="Output directory") + ] = "workspace/analysis", + format: Annotated[ + str, typer.Option("-f", "--format", help="Output format") + ] = "both", + no_semantic: Annotated[ + bool, typer.Option("--no-semantic", help="Skip semantic matching") + ] = False, +) -> None: + """Run comprehensive benchmark analysis.""" + if format not in ("json", "markdown", "both"): + typer.echo("Error: format must be 'json', 'markdown', or 'both'", err=True) + raise typer.Exit(1) + + typer.echo(f"\n{'=' * 60}") + typer.echo("BENCHMARK ANALYSIS") + typer.echo(f"{'=' * 60}\n") + typer.echo(f"Sessions: {', '.join(session_ids)}") + + from deriva.services.analysis import BenchmarkAnalyzer + + with PipelineSession() as session: + try: + analyzer = BenchmarkAnalyzer( + session_ids=list(session_ids), + engine=session._engine, + ) + except ValueError as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) + + typer.echo("\nRunning analysis...") + report = analyzer.generate_report() + + typer.echo(f"\nRepositories: {', '.join(report.repositories)}") + 
typer.echo(f"Models: {', '.join(report.models)}") + typer.echo("\nOVERALL METRICS") + typer.echo("-" * 40) + typer.echo(f" Consistency: {report.overall_consistency:.1%}") + typer.echo(f" Precision: {report.overall_precision:.1%}") + typer.echo(f" Recall: {report.overall_recall:.1%}") + + if report.stability_reports: + typer.echo("\nPER-REPOSITORY STABILITY") + typer.echo("-" * 40) + for repo, phases in report.stability_reports.items(): + if "derivation" in phases: + typer.echo( + f" {repo}: {phases['derivation'].overall_consistency:.1%} derivation consistency" + ) + + if report.semantic_reports: + typer.echo("\nSEMANTIC MATCH SUMMARY") + typer.echo("-" * 40) + for repo, sr in report.semantic_reports.items(): + typer.echo( + f" {repo}: P={sr.element_precision:.1%} R={sr.element_recall:.1%} F1={sr.element_f1:.2f}" + ) + + if report.cross_repo: + if report.cross_repo.best_element_types: + typer.echo("\nBEST ELEMENT TYPES (highest consistency)") + typer.echo("-" * 40) + for t, score in report.cross_repo.best_element_types[:5]: + typer.echo(f" {t}: {score:.1%}") + + if report.cross_repo.worst_element_types: + typer.echo("\nWORST ELEMENT TYPES (lowest consistency)") + typer.echo("-" * 40) + for t, score in report.cross_repo.worst_element_types[:5]: + typer.echo(f" {t}: {score:.1%}") + + if report.recommendations: + typer.echo("\nRECOMMENDATIONS") + typer.echo("-" * 40) + for rec in report.recommendations[:10]: + typer.echo(f" - {rec}") + + typer.echo(f"\nExporting to: {output}") + paths = analyzer.export_all(output) + + if format in ("json", "both"): + typer.echo(f" JSON: {paths.get('json', 'N/A')}") + if format in ("markdown", "both"): + typer.echo(f" Markdown: {paths.get('markdown', 'N/A')}") diff --git a/deriva/cli/commands/config.py b/deriva/cli/commands/config.py new file mode 100644 index 0000000..be39ce6 --- /dev/null +++ b/deriva/cli/commands/config.py @@ -0,0 +1,356 @@ +""" +Config CLI commands. + +Provides commands for managing pipeline configurations. 
+""" + +from __future__ import annotations + +import json +from typing import Annotated + +import typer + +from deriva.services import config +from deriva.services.session import PipelineSession + +app = typer.Typer(name="config", help="Manage pipeline configurations") + +# Filetype subapp +filetype_app = typer.Typer(name="filetype", help="Manage file type registry") +app.add_typer(filetype_app) + + +# ============================================================================= +# Config Commands +# ============================================================================= + + +@app.command("list") +def config_list( + step_type: Annotated[str, typer.Argument(help="Type of configuration to list")], + enabled: Annotated[ + bool, typer.Option("--enabled", help="Only show enabled configurations") + ] = False, + phase: Annotated[ + str | None, + typer.Option( + "--phase", help="Filter derivation by phase (prep, generate, refine)" + ), + ] = None, +) -> None: + """List configurations for a step type.""" + if step_type not in ("extraction", "derivation"): + typer.echo( + f"Error: step_type must be 'extraction' or 'derivation', got '{step_type}'", + err=True, + ) + raise typer.Exit(1) + + with PipelineSession() as session: + steps = session.list_steps(step_type, enabled_only=enabled) + + if not steps: + typer.echo(f"No {step_type} configurations found.") + return + + typer.echo(f"\n{step_type.upper()} CONFIGURATIONS") + typer.echo("-" * 60) + + for step in steps: + status = "enabled" if step["enabled"] else "disabled" + name = step["name"] + seq = step["sequence"] + typer.echo(f" [{seq}] {name:<30} ({status})") + + typer.echo("") + + +@app.command("show") +def config_show( + step_type: Annotated[str, typer.Argument(help="Type of configuration")], + name: Annotated[str, typer.Argument(help="Name of the configuration")], +) -> None: + """Show detailed configuration for a specific step.""" + if step_type not in ("extraction", "derivation"): + typer.echo("Error: step_type must be 'extraction' or 'derivation'", err=True) + raise typer.Exit(1) + + with PipelineSession() as session: + if step_type == "extraction": + cfg = config.get_extraction_config(session._engine, name) + if not cfg: + typer.echo(f"Extraction config '{name}' not found.") + raise typer.Exit(1) + typer.echo(f"\nEXTRACTION CONFIG: {cfg.node_type}") + typer.echo("-" * 60) + typer.echo(f" Sequence: {cfg.sequence}") + typer.echo(f" Enabled: {cfg.enabled}") + typer.echo(f" Sources: {cfg.input_sources or 'None'}") + typer.echo(f" Instruction: {(cfg.instruction or '')[:100]}...") + typer.echo(f" Example: {(cfg.example or '')[:100]}...") + + elif step_type == "derivation": + cfg = config.get_derivation_config(session._engine, name) + if not cfg: + typer.echo(f"Derivation config '{name}' not found.") + raise typer.Exit(1) + typer.echo(f"\nDERIVATION CONFIG: {cfg.element_type}") + typer.echo("-" * 60) + typer.echo(f" Sequence: {cfg.sequence}") + typer.echo(f" Enabled: {cfg.enabled}") + typer.echo(f" Query: {(cfg.input_graph_query or '')[:100]}...") + typer.echo(f" Instruction: {(cfg.instruction or '')[:100]}...") + + typer.echo("") + + +@app.command("enable") +def config_enable( + step_type: Annotated[str, typer.Argument(help="Type of configuration")], + name: Annotated[str, typer.Argument(help="Name to enable")], +) -> None: + """Enable a configuration step.""" + if step_type not in ("extraction", "derivation"): + typer.echo("Error: step_type must be 'extraction' or 'derivation'", err=True) + raise typer.Exit(1) + + with PipelineSession() as 
session: + if session.enable_step(step_type, name): + typer.echo(f"Enabled {step_type} step: {name}") + else: + typer.echo(f"Step not found: {step_type}/{name}") + raise typer.Exit(1) + + +@app.command("disable") +def config_disable( + step_type: Annotated[str, typer.Argument(help="Type of configuration")], + name: Annotated[str, typer.Argument(help="Name to disable")], +) -> None: + """Disable a configuration step.""" + if step_type not in ("extraction", "derivation"): + typer.echo("Error: step_type must be 'extraction' or 'derivation'", err=True) + raise typer.Exit(1) + + with PipelineSession() as session: + if session.disable_step(step_type, name): + typer.echo(f"Disabled {step_type} step: {name}") + else: + typer.echo(f"Step not found: {step_type}/{name}") + raise typer.Exit(1) + + +@app.command("update") +def config_update( + step_type: Annotated[str, typer.Argument(help="Type of configuration to update")], + name: Annotated[str, typer.Argument(help="Name of the configuration to update")], + instruction: Annotated[ + str | None, typer.Option("-i", "--instruction", help="New instruction text") + ] = None, + example: Annotated[ + str | None, typer.Option("-e", "--example", help="New example text") + ] = None, + instruction_file: Annotated[ + str | None, + typer.Option("--instruction-file", help="Read instruction from file"), + ] = None, + example_file: Annotated[ + str | None, typer.Option("--example-file", help="Read example from file") + ] = None, + query: Annotated[ + str | None, + typer.Option("-q", "--query", help="New input_graph_query (derivation only)"), + ] = None, + sources: Annotated[ + str | None, + typer.Option("-s", "--sources", help="New input_sources (extraction only)"), + ] = None, + params: Annotated[ + str | None, + typer.Option("-p", "--params", help="New params JSON (derivation only)"), + ] = None, + params_file: Annotated[ + str | None, typer.Option("--params-file", help="Read params JSON from file") + ] = None, +) -> None: + """Update a configuration with versioning.""" + if step_type not in ("extraction", "derivation"): + typer.echo("Error: step_type must be 'extraction' or 'derivation'", err=True) + raise typer.Exit(1) + + # Read instruction from file if provided + if instruction_file: + try: + with open(instruction_file, encoding="utf-8") as f: + instruction = f.read() + except Exception as e: + typer.echo(f"Error reading instruction file: {e}", err=True) + raise typer.Exit(1) + + # Read example from file if provided + if example_file: + try: + with open(example_file, encoding="utf-8") as f: + example = f.read() + except Exception as e: + typer.echo(f"Error reading example file: {e}", err=True) + raise typer.Exit(1) + + # Read params from file if provided + if params_file: + try: + with open(params_file, encoding="utf-8") as f: + params = f.read() + except Exception as e: + typer.echo(f"Error reading params file: {e}", err=True) + raise typer.Exit(1) + + # Validate params is valid JSON if provided + if params: + try: + json.loads(params) + except json.JSONDecodeError as e: + typer.echo(f"Error: params must be valid JSON: {e}", err=True) + raise typer.Exit(1) + + with PipelineSession() as session: + if step_type == "derivation": + result = config.create_derivation_config_version( + session._engine, + name, + instruction=instruction, + example=example, + input_graph_query=query, + params=params, + ) + elif step_type == "extraction": + result = config.create_extraction_config_version( + session._engine, + name, + instruction=instruction, + example=example, + 
input_sources=sources, + ) + else: + typer.echo(f"Versioned updates not yet supported for: {step_type}") + raise typer.Exit(1) + + if result.get("success"): + typer.echo(f"Updated {step_type} config: {name}") + typer.echo(f" Version: {result['old_version']} -> {result['new_version']}") + if params: + typer.echo(" Params: updated") + else: + typer.echo(f"Error: {result.get('error', 'Unknown error')}", err=True) + raise typer.Exit(1) + + +@app.command("versions") +def config_versions() -> None: + """Show active config versions.""" + with PipelineSession() as session: + versions = config.get_active_config_versions(session._engine) + + typer.echo("\nACTIVE CONFIG VERSIONS") + typer.echo("=" * 60) + + for step_type in ["extraction", "derivation"]: + if versions.get(step_type): + typer.echo(f"\n{step_type.upper()}:") + for name, version in sorted(versions[step_type].items()): + typer.echo(f" {name:<30} v{version}") + + typer.echo("") + + +# ============================================================================= +# Filetype Commands +# ============================================================================= + + +@filetype_app.command("list") +def filetype_list() -> None: + """List all registered file types.""" + with PipelineSession() as session: + file_types = session.get_file_types() + + if not file_types: + typer.echo("No file types registered.") + return + + # Group by file_type + by_type: dict[str, list] = {} + for ft in file_types: + ft_type = ft.get("file_type", "unknown") + if ft_type not in by_type: + by_type[ft_type] = [] + by_type[ft_type].append(ft) + + typer.echo(f"\n{'=' * 60}") + typer.echo("FILE TYPE REGISTRY") + typer.echo(f"{'=' * 60}") + typer.echo(f"Total: {len(file_types)} registered\n") + + for ft_type in sorted(by_type.keys()): + entries = by_type[ft_type] + typer.echo(f"{ft_type.upper()} ({len(entries)}):") + for ft in sorted(entries, key=lambda x: x.get("extension", "")): + ext = ft.get("extension", "") + subtype = ft.get("subtype", "") + typer.echo(f" {ext:<25} -> {subtype}") + typer.echo("") + + +@filetype_app.command("add") +def filetype_add( + extension: Annotated[ + str, typer.Argument(help="File extension (e.g., '.py', 'Dockerfile')") + ], + file_type: Annotated[str, typer.Argument(help="File type category")], + subtype: Annotated[ + str, typer.Argument(help="Subtype (e.g., 'python', 'javascript')") + ], +) -> None: + """Add a new file type.""" + with PipelineSession() as session: + success = session.add_file_type(extension, file_type, subtype) + + if success: + typer.echo(f"Added file type: {extension} -> {file_type}/{subtype}") + else: + typer.echo( + f"Failed to add file type (may already exist): {extension}", err=True + ) + raise typer.Exit(1) + + +@filetype_app.command("delete") +def filetype_delete( + extension: Annotated[str, typer.Argument(help="Extension to delete")], +) -> None: + """Delete a file type.""" + with PipelineSession() as session: + success = session.delete_file_type(extension) + + if success: + typer.echo(f"Deleted file type: {extension}") + else: + typer.echo(f"File type not found: {extension}", err=True) + raise typer.Exit(1) + + +@filetype_app.command("stats") +def filetype_stats() -> None: + """Show file type statistics.""" + with PipelineSession() as session: + stats = session.get_file_type_stats() + + typer.echo(f"\n{'=' * 60}") + typer.echo("FILE TYPE STATISTICS") + typer.echo(f"{'=' * 60}\n") + + for ft_type, count in sorted(stats.items(), key=lambda x: -x[1]): + typer.echo(f" {ft_type:<20} {count}") + + typer.echo(f"\n 
{'Total':<20} {sum(stats.values())}") diff --git a/deriva/cli/commands/repo.py b/deriva/cli/commands/repo.py new file mode 100644 index 0000000..0df7c95 --- /dev/null +++ b/deriva/cli/commands/repo.py @@ -0,0 +1,149 @@ +""" +Repository CLI commands. + +Provides commands for managing repositories. +""" + +from __future__ import annotations + +from typing import Annotated + +import typer + +from deriva.services.session import PipelineSession + +app = typer.Typer(name="repo", help="Manage repositories") + + +@app.command("clone") +def repo_clone( + url: Annotated[str, typer.Argument(help="Repository URL to clone")], + name: Annotated[ + str | None, typer.Option("-n", "--name", help="Custom name for the repository") + ] = None, + branch: Annotated[ + str | None, typer.Option("-b", "--branch", help="Branch to clone") + ] = None, + overwrite: Annotated[ + bool, typer.Option("--overwrite", help="Overwrite existing repository") + ] = False, +) -> None: + """Clone a repository.""" + typer.echo(f"\n{'=' * 60}") + typer.echo("DERIVA - Cloning Repository") + typer.echo(f"{'=' * 60}") + typer.echo(f"URL: {url}") + if name: + typer.echo(f"Name: {name}") + if branch: + typer.echo(f"Branch: {branch}") + + with PipelineSession() as session: + result = session.clone_repository( + url=url, name=name, branch=branch, overwrite=overwrite + ) + if result.get("success"): + typer.echo("\nRepository cloned successfully!") + typer.echo(f" Name: {result.get('name', 'N/A')}") + typer.echo(f" Path: {result.get('path', 'N/A')}") + typer.echo(f" URL: {result.get('url', url)}") + else: + typer.echo(f"\nError: {result.get('error', 'Unknown error')}", err=True) + raise typer.Exit(1) + + +@app.command("list") +def repo_list( + detailed: Annotated[ + bool, typer.Option("-d", "--detailed", help="Show detailed information") + ] = False, +) -> None: + """List all repositories.""" + with PipelineSession() as session: + repos = session.get_repositories(detailed=detailed) + + if not repos: + typer.echo("\nNo repositories found.") + typer.echo(f"Workspace: {session.workspace_dir}") + typer.echo("\nClone a repository with:") + typer.echo(" deriva repo clone ") + return + + typer.echo(f"\n{'=' * 60}") + typer.echo("REPOSITORIES") + typer.echo(f"{'=' * 60}") + typer.echo(f"Workspace: {session.workspace_dir}\n") + + for repo in repos: + if detailed: + dirty = " (dirty)" if repo.get("is_dirty") else "" + typer.echo(f" {repo['name']}{dirty}") + typer.echo(f" URL: {repo.get('url', 'N/A')}") + typer.echo(f" Branch: {repo.get('branch', 'N/A')}") + typer.echo(f" Size: {repo.get('size_mb', 0):.2f} MB") + typer.echo(f" Cloned: {repo.get('cloned_at', 'N/A')}") + typer.echo("") + else: + typer.echo(f" {repo['name']}") + + typer.echo(f"\nTotal: {len(repos)} repositories") + + +@app.command("delete") +def repo_delete( + name: Annotated[str, typer.Argument(help="Repository name to delete")], + force: Annotated[ + bool, + typer.Option( + "-f", "--force", help="Force delete even with uncommitted changes" + ), + ] = False, +) -> None: + """Delete a repository.""" + typer.echo(f"\n{'=' * 60}") + typer.echo("DERIVA - Deleting Repository") + typer.echo(f"{'=' * 60}") + typer.echo(f"Repository: {name}") + + with PipelineSession() as session: + try: + result = session.delete_repository(name=name, force=force) + if result.get("success"): + typer.echo(f"\nRepository '{name}' deleted successfully.") + else: + typer.echo(f"\nError: {result.get('error', 'Unknown error')}", err=True) + raise typer.Exit(1) + except Exception as e: + typer.echo(f"\nError: {e}", 
err=True) + if "uncommitted changes" in str(e).lower(): + typer.echo("Use --force to delete anyway.") + raise typer.Exit(1) + + +@app.command("info") +def repo_info( + name: Annotated[str, typer.Argument(help="Repository name")], +) -> None: + """Show repository details.""" + with PipelineSession() as session: + try: + info = session.get_repository_info(name) + + if not info: + typer.echo(f"\nRepository '{name}' not found.", err=True) + raise typer.Exit(1) + + typer.echo(f"\n{'=' * 60}") + typer.echo(f"REPOSITORY: {info['name']}") + typer.echo(f"{'=' * 60}") + typer.echo(f" Path: {info.get('path', 'N/A')}") + typer.echo(f" URL: {info.get('url', 'N/A')}") + typer.echo(f" Branch: {info.get('branch', 'N/A')}") + typer.echo(f" Last Commit: {info.get('last_commit', 'N/A')}") + typer.echo(f" Dirty: {info.get('is_dirty', False)}") + typer.echo(f" Size: {info.get('size_mb', 0):.2f} MB") + typer.echo(f" Cloned At: {info.get('cloned_at', 'N/A')}") + typer.echo("") + except Exception as e: + typer.echo(f"\nError: {e}", err=True) + raise typer.Exit(1) diff --git a/deriva/cli/commands/run.py b/deriva/cli/commands/run.py new file mode 100644 index 0000000..45de9f2 --- /dev/null +++ b/deriva/cli/commands/run.py @@ -0,0 +1,221 @@ +""" +Run CLI commands. + +Provides commands for running pipeline stages, status, export, and clear. +""" + +from __future__ import annotations + +from typing import Annotated + +import typer + +from deriva.cli.progress import create_progress_reporter +from deriva.services.session import PipelineSession + +app = typer.Typer(name="run", help="Run pipeline stages") + + +def _print_extraction_result(result: dict) -> None: + """Print extraction results.""" + typer.echo(f"\n{'-' * 60}") + typer.echo("EXTRACTION RESULTS") + typer.echo(f"{'-' * 60}") + stats = result.get("stats", {}) + typer.echo(f" Repos processed: {stats.get('repos_processed', 0)}") + typer.echo(f" Nodes created: {stats.get('nodes_created', 0)}") + typer.echo(f" Edges created: {stats.get('edges_created', 0)}") + typer.echo(f" Steps completed: {stats.get('steps_completed', 0)}") + typer.echo(f" Steps skipped: {stats.get('steps_skipped', 0)}") + + if result.get("warnings"): + typer.echo(f"\nWarnings ({len(result['warnings'])}):") + for warn in result["warnings"][:5]: + typer.echo(f" - {warn}") + if len(result["warnings"]) > 5: + typer.echo(f" ... and {len(result['warnings']) - 5} more") + + if result.get("errors"): + typer.echo(f"\nErrors ({len(result['errors'])}):") + for err in result["errors"][:5]: + typer.echo(f" - {err}") + if len(result["errors"]) > 5: + typer.echo(f" ... and {len(result['errors']) - 5} more") + + +def _print_derivation_result(result: dict) -> None: + """Print derivation results.""" + typer.echo(f"\n{'-' * 60}") + typer.echo("DERIVATION RESULTS") + typer.echo(f"{'-' * 60}") + stats = result.get("stats", {}) + typer.echo(f" Elements created: {stats.get('elements_created', 0)}") + typer.echo(f" Relationships created: {stats.get('relationships_created', 0)}") + typer.echo(f" Elements validated: {stats.get('elements_validated', 0)}") + typer.echo(f" Issues found: {stats.get('issues_found', 0)}") + typer.echo(f" Steps completed: {stats.get('steps_completed', 0)}") + + issues = result.get("issues", []) + if issues: + typer.echo(f"\nIssues ({len(issues)}):") + for issue in issues[:10]: + severity = issue.get("severity", "warning") + msg = issue.get("message", "") + typer.echo(f" [{severity.upper()}] {msg}") + if len(issues) > 10: + typer.echo(f" ... 
and {len(issues) - 10} more") + + if result.get("errors"): + typer.echo(f"\nErrors ({len(result['errors'])}):") + for err in result["errors"][:5]: + typer.echo(f" - {err}") + + +def _print_pipeline_result(result: dict) -> None: + """Print full pipeline results.""" + typer.echo(f"\n{'=' * 60}") + typer.echo("PIPELINE COMPLETE") + typer.echo(f"{'=' * 60}") + + results = result.get("results", {}) + + if results.get("classification"): + stats = results["classification"].get("stats", {}) + typer.echo("\nClassification:") + typer.echo(f" Files classified: {stats.get('files_classified', 0)}") + typer.echo(f" Files undefined: {stats.get('files_undefined', 0)}") + + if results.get("extraction"): + stats = results["extraction"].get("stats", {}) + typer.echo("\nExtraction:") + typer.echo(f" Nodes created: {stats.get('nodes_created', 0)}") + + if results.get("derivation"): + stats = results["derivation"].get("stats", {}) + typer.echo("\nDerivation:") + typer.echo(f" Elements created: {stats.get('elements_created', 0)}") + typer.echo(f" Issues found: {stats.get('issues_found', 0)}") + + if result.get("errors"): + typer.echo(f"\nTotal errors: {len(result['errors'])}") + + +@app.callback(invoke_without_command=True) +def run_stage( + ctx: typer.Context, + stage: Annotated[ + str, typer.Argument(help="Pipeline stage to run (extraction, derivation, all)") + ], + repo: Annotated[ + str | None, typer.Option("--repo", help="Specific repository to process") + ] = None, + phase: Annotated[ + str | None, typer.Option("--phase", help="Run specific phase") + ] = None, + verbose: Annotated[ + bool, typer.Option("-v", "--verbose", help="Print detailed progress") + ] = False, + quiet: Annotated[ + bool, typer.Option("-q", "--quiet", help="Disable progress bar") + ] = False, + no_llm: Annotated[ + bool, typer.Option("--no-llm", help="Skip LLM-based steps") + ] = False, +) -> None: + """Run pipeline stages.""" + if stage not in ("extraction", "derivation", "all"): + typer.echo( + f"Error: stage must be 'extraction', 'derivation', or 'all', got '{stage}'", + err=True, + ) + raise typer.Exit(1) + + # Validate phase is appropriate for stage + extraction_phases = {"classify", "parse"} + derivation_phases = {"prep", "generate", "refine"} + if phase: + if stage == "extraction" and phase not in extraction_phases: + typer.echo(f"Error: Phase '{phase}' is not valid for extraction.", err=True) + typer.echo( + f"Valid extraction phases: {', '.join(sorted(extraction_phases))}" + ) + raise typer.Exit(1) + if stage == "derivation" and phase not in derivation_phases: + typer.echo(f"Error: Phase '{phase}' is not valid for derivation.", err=True) + typer.echo( + f"Valid derivation phases: {', '.join(sorted(derivation_phases))}" + ) + raise typer.Exit(1) + + typer.echo(f"\n{'=' * 60}") + typer.echo(f"DERIVA - Running {stage.upper()} pipeline") + typer.echo(f"{'=' * 60}") + + if repo: + typer.echo(f"Repository: {repo}") + if phase: + typer.echo(f"Phase: {phase}") + + with PipelineSession() as session: + typer.echo("Connected to Neo4j") + + # Show LLM status + llm_info = session.llm_info + if llm_info and not no_llm: + typer.echo(f"LLM configured: {llm_info['provider']}/{llm_info['model']}") + elif no_llm: + typer.echo("LLM disabled (--no-llm)") + else: + typer.echo("Warning: LLM not configured. 
LLM-based steps will be skipped.") + + # Create progress reporter + progress_reporter = create_progress_reporter(quiet=quiet or verbose) + + if stage == "extraction": + phases = [phase] if phase else None + with progress_reporter: + result = session.run_extraction( + repo_name=repo, + verbose=verbose, + no_llm=no_llm, + progress=progress_reporter, + phases=phases, + ) + _print_extraction_result(result) + + elif stage == "derivation": + if not llm_info: + typer.echo( + "Error: Derivation requires LLM. Configure LLM in .env file.", + err=True, + ) + raise typer.Exit(1) + phases = [phase] if phase else None + with progress_reporter: + result = session.run_derivation( + verbose=verbose, + phases=phases, + progress=progress_reporter, + ) + _print_derivation_result(result) + + elif stage == "all": + with progress_reporter: + result = session.run_pipeline( + repo_name=repo, + verbose=verbose, + progress=progress_reporter, + ) + _print_pipeline_result(result) + + if not result.get("success"): + raise typer.Exit(1) + + +# ============================================================================= +# Additional Commands (status, export, clear) +# ============================================================================= + + +# These are standalone commands, not subcommands of run +# They will be added to the main app directly diff --git a/deriva/common/cache_utils.py b/deriva/common/cache_utils.py index f418f21..fb27734 100644 --- a/deriva/common/cache_utils.py +++ b/deriva/common/cache_utils.py @@ -1,9 +1,9 @@ """ Common caching utilities for Deriva. -Provides a base class for two-tier (memory + disk) caching and utilities -for generating cache keys. Used by LLM cache, graph cache, and other -caching implementations. +Provides a base class for disk caching using diskcache (SQLite-backed) +and utilities for generating cache keys. Used by LLM cache, graph cache, +and other caching implementations. Usage: from deriva.common.cache_utils import BaseDiskCache, hash_inputs @@ -21,6 +21,8 @@ def generate_key(self, *args) -> str: from pathlib import Path from typing import Any +import diskcache + from deriva.common.exceptions import CacheError @@ -122,16 +124,15 @@ def tuple_to_dict(t: tuple) -> dict | list | Any: class BaseDiskCache: """ - Base class for two-tier (memory + disk) caching with JSON persistence. - - Provides a generic caching interface that stores entries in both memory - (for fast access) and on disk (for persistence across runs). + Base class for disk caching using diskcache (SQLite-backed). - Subclasses should implement domain-specific key generation. + Provides a generic caching interface that stores entries in SQLite + for efficient persistence and retrieval. Includes export functionality + for auditing. Attributes: - cache_dir: Path to the directory storing cache files - _memory_cache: In-memory cache dictionary + cache_dir: Path to the directory storing cache data + _cache: diskcache.Cache instance Example: class MyCache(BaseDiskCache): @@ -147,58 +148,64 @@ def get_or_compute(self, key: str, compute_fn) -> Any: return result """ - def __init__(self, cache_dir: str | Path): + # Default size limit: 1GB + DEFAULT_SIZE_LIMIT = 2**30 + + def __init__(self, cache_dir: str | Path, size_limit: int | None = None): """ Initialize cache with specified directory. 
Args: cache_dir: Directory to store cache files (created if not exists) + size_limit: Maximum cache size in bytes (default: 1GB) """ self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) - self._memory_cache: dict[str, dict[str, Any]] = {} - def get_from_memory(self, cache_key: str) -> dict[str, Any] | None: + # Initialize diskcache + self._cache = diskcache.Cache( + str(self.cache_dir), + size_limit=size_limit or self.DEFAULT_SIZE_LIMIT, + ) + + def get(self, cache_key: str) -> dict[str, Any] | None: """ - Retrieve cached data from in-memory cache. + Retrieve cached data. Args: cache_key: The cache key Returns: Cached data dict or None if not found + + Raises: + CacheError: If cache is corrupted """ - return self._memory_cache.get(cache_key) + try: + result = self._cache.get(cache_key) + return result + except Exception as e: + raise CacheError(f"Error reading from cache: {e}") from e - def get_from_disk(self, cache_key: str) -> dict[str, Any] | None: + def get_from_memory(self, cache_key: str) -> dict[str, Any] | None: """ - Retrieve cached data from disk. + Retrieve cached data (alias for get, kept for backward compatibility). + + diskcache handles its own memory caching internally. Args: cache_key: The cache key Returns: Cached data dict or None if not found - - Raises: - CacheError: If cache file is corrupted """ - cache_file = self.cache_dir / f"{cache_key}.json" - - if not cache_file.exists(): - return None + return self.get(cache_key) - try: - with open(cache_file, encoding="utf-8") as f: - return json.load(f) - except json.JSONDecodeError as e: - raise CacheError(f"Corrupted cache file: {cache_file}") from e - except Exception as e: - raise CacheError(f"Error reading cache file: {e}") from e - - def get(self, cache_key: str) -> dict[str, Any] | None: + def get_from_disk(self, cache_key: str) -> dict[str, Any] | None: """ - Retrieve cached data, checking memory first, then disk. + Retrieve cached data (alias for get, kept for backward compatibility). + + diskcache uses SQLite, not individual files. Args: cache_key: The cache key @@ -206,79 +213,60 @@ def get(self, cache_key: str) -> dict[str, Any] | None: Returns: Cached data dict or None if not found """ - # Check memory cache first - cached = self.get_from_memory(cache_key) - if cached is not None: - return cached - - # Check disk cache - cached = self.get_from_disk(cache_key) - if cached is not None: - # Populate memory cache for faster future access - self._memory_cache[cache_key] = cached - - return cached + return self.get(cache_key) - def set(self, cache_key: str, data: dict[str, Any]) -> None: + def set( + self, cache_key: str, data: dict[str, Any], expire: float | None = None + ) -> None: """ - Store data in both memory and disk cache. + Store data in cache. 
Args: cache_key: The cache key - data: Dictionary to cache (must be JSON-serializable) + data: Dictionary to cache + expire: Optional TTL in seconds Raises: - CacheError: If unable to write to disk + CacheError: If unable to write to cache """ - # Store in memory - self._memory_cache[cache_key] = data - - # Store on disk - cache_file = self.cache_dir / f"{cache_key}.json" try: - with open(cache_file, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, default=str) + self._cache.set(cache_key, data, expire=expire) except Exception as e: - raise CacheError(f"Error writing cache file: {e}") from e + raise CacheError(f"Error writing to cache: {e}") from e def invalidate(self, cache_key: str) -> None: """ - Remove entry from both memory and disk cache. + Remove entry from cache. Args: cache_key: The cache key to invalidate """ - # Remove from memory - self._memory_cache.pop(cache_key, None) - - # Remove from disk - cache_file = self.cache_dir / f"{cache_key}.json" - if cache_file.exists(): - try: - cache_file.unlink() - except Exception as e: - raise CacheError(f"Error deleting cache file: {e}") from e + try: + self._cache.delete(cache_key) + except Exception as e: + raise CacheError(f"Error deleting cache entry: {e}") from e def clear_memory(self) -> None: - """Clear the in-memory cache.""" - self._memory_cache.clear() + """Cull expired and over-limit entries (kept as clear_memory for backward compatibility).""" + try: + self._cache.cull() + except Exception: + pass # cull() is an optional cleanup; failures are non-fatal def clear_disk(self) -> None: """ - Clear all cache files from disk. + Clear all cache entries. Raises: - CacheError: If unable to delete cache files + CacheError: If unable to clear cache """ try: - for cache_file in self.cache_dir.glob("*.json"): - cache_file.unlink() + self._cache.clear() except Exception as e: - raise CacheError(f"Error clearing disk cache: {e}") from e + raise CacheError(f"Error clearing cache: {e}") from e def clear_all(self) -> None: - """Clear both memory and disk caches.""" - self.clear_memory() + """Clear the entire cache.""" self.clear_disk() def get_stats(self) -> dict[str, Any]: @@ -287,31 +275,98 @@ def get_stats(self) -> dict[str, Any]: Returns: Dictionary with: - - memory_entries: Number of entries in memory - - disk_entries: Number of files on disk - - disk_size_bytes: Total size of cache files - - disk_size_mb: Total size in megabytes + - entries: Number of entries in cache + - size_bytes: Total size of cache + - size_mb: Total size in megabytes - cache_dir: Path to cache directory + - memory_entries / disk_entries / disk_size_*: kept for backward compat """ - disk_files = list(self.cache_dir.glob("*.json")) - total_size = sum(f.stat().st_size for f in disk_files) + try: + volume = self._cache.volume() + except Exception: + volume = 0 + + entry_count = len(self._cache) return { - "memory_entries": len(self._memory_cache), - "disk_entries": len(disk_files), - "disk_size_bytes": total_size, - "disk_size_mb": round(total_size / (1024 * 1024), 2), + "memory_entries": entry_count, # Kept for backward compat + "disk_entries": entry_count, # Kept for backward compat + "entries": entry_count, + "disk_size_bytes": volume, + "disk_size_mb": round(volume / (1024 * 1024), 2), + "size_bytes": volume, + "size_mb": round(volume / (1024 * 1024), 2), "cache_dir": str(self.cache_dir), } def keys(self) -> list[str]: """ - Get all cache keys (from disk). 
Returns: List of cache keys """ - return [f.stem for f in self.cache_dir.glob("*.json")] + return list(self._cache.iterkeys()) + + def export_to_json( + self, output_path: str | Path, include_values: bool = True + ) -> int: + """ + Export cache contents to JSON for auditing. + + Args: + output_path: Path to write JSON file + include_values: If True, include cached values; if False, keys only + + Returns: + Number of entries exported + + Example: + cache = BaseDiskCache("./my_cache") + count = cache.export_to_json("./cache_audit.json") + print(f"Exported {count} entries") + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + entries = [] + for key in self._cache.iterkeys(): + entry = {"key": key} + if include_values: + try: + entry["value"] = self._cache[key] + except KeyError: + entry["value"] = None + entry["error"] = "Key expired or deleted during export" + entries.append(entry) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump( + { + "cache_dir": str(self.cache_dir), + "entry_count": len(entries), + "entries": entries, + }, + f, + indent=2, + default=str, + ) + + return len(entries) + + def close(self) -> None: + """Close the cache connection.""" + try: + self._cache.close() + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False __all__ = [ diff --git a/deriva/common/logging.py b/deriva/common/logging.py index e8c0e48..6fc5f33 100644 --- a/deriva/common/logging.py +++ b/deriva/common/logging.py @@ -1,5 +1,5 @@ """ -Logging module - JSON Lines based logging for pipeline runs. +Logging module - Structured logging for pipeline runs using structlog. Logging Levels: - Level 1: High-level phases (classification, extraction, derivation, validation) @@ -13,11 +13,13 @@ import json import logging -from dataclasses import asdict, dataclass +from contextlib import contextmanager from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any +from typing import IO, Any + +import structlog __all__ = [ "LogLevel", @@ -50,36 +52,112 @@ class LogStatus(str, Enum): SKIPPED = "skipped" -@dataclass class LogEntry: - """A single log entry.""" - - level: int - phase: str - status: str - timestamp: str - message: str - step: str | None = None - sequence: int | None = None - duration_ms: int | None = None - items_processed: int | None = None - items_created: int | None = None - items_failed: int | None = None - stats: dict[str, Any] | None = None - error: str | None = None + """A single log entry - compatibility wrapper for structlog output.""" + + def __init__( + self, + level: int, + phase: str, + status: str, + timestamp: str, + message: str, + step: str | None = None, + sequence: int | None = None, + duration_ms: int | None = None, + items_processed: int | None = None, + items_created: int | None = None, + items_failed: int | None = None, + stats: dict[str, Any] | None = None, + error: str | None = None, + ): + self.level = level + self.phase = phase + self.status = status + self.timestamp = timestamp + self.message = message + self.step = step + self.sequence = sequence + self.duration_ms = duration_ms + self.items_processed = items_processed + self.items_created = items_created + self.items_failed = items_failed + self.stats = stats + self.error = error def to_dict(self) -> dict[str, Any]: """Convert to dictionary, excluding None values.""" - return {k: v for k, v in asdict(self).items() if v is not None} + 
result = { + "level": self.level, + "phase": self.phase, + "status": self.status, + "timestamp": self.timestamp, + "message": self.message, + } + if self.step is not None: + result["step"] = self.step + if self.sequence is not None: + result["sequence"] = self.sequence + if self.duration_ms is not None: + result["duration_ms"] = self.duration_ms + if self.items_processed is not None: + result["items_processed"] = self.items_processed + if self.items_created is not None: + result["items_created"] = self.items_created + if self.items_failed is not None: + result["items_failed"] = self.items_failed + if self.stats is not None: + result["stats"] = self.stats + if self.error is not None: + result["error"] = self.error + return result def to_json(self) -> str: """Convert to JSON string.""" return json.dumps(self.to_dict()) +def _jsonl_renderer( + logger: structlog.types.WrappedLogger, + method_name: str, + event_dict: structlog.types.EventDict, +) -> str: + """Custom renderer that outputs JSONL format compatible with existing logs.""" + return json.dumps(event_dict, default=str) + + +def _create_structlog_logger(file_handle: IO[str]) -> structlog.stdlib.BoundLogger: + """Create a structlog logger configured for JSONL file output.""" + + def file_writer( + logger: structlog.types.WrappedLogger, + method_name: str, + event_dict: structlog.types.EventDict, + ) -> str: + """Write to file and return the JSON string.""" + json_str = json.dumps(event_dict, default=str) + file_handle.write(json_str + "\n") + file_handle.flush() + return json_str + + structlog.configure( + processors=[ + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + file_writer, + ], + wrapper_class=structlog.stdlib.BoundLogger, + context_class=dict, + logger_factory=structlog.PrintLoggerFactory(), + cache_logger_on_first_use=False, + ) + + return structlog.get_logger() + + class RunLogger: """ - Logger for a single pipeline run. + Logger for a single pipeline run using structlog. Creates and appends to a JSONL file in workspace/logs/run_{id}/. 
""" @@ -97,7 +175,6 @@ def __init__(self, run_id: int, logs_dir: str = "workspace/logs"): # Resolve logs_dir relative to project root logs_path = Path(logs_dir) if not logs_path.is_absolute(): - # Resolve relative to project root (2 levels up from logging.py: common -> src -> root) project_root = Path(__file__).parent.parent.parent logs_path = project_root / logs_path @@ -112,11 +189,25 @@ def __init__(self, run_id: int, logs_dir: str = "workspace/logs"): datetime_str = self.start_time.strftime("%Y%m%d_%H%M%S") self.log_file = self.run_dir / f"log_{datetime_str}.jsonl" + # Open file handle for structlog + self._file_handle: IO[str] | None = None + # Track current phase for step logging self._current_phase: str | None = None self._phase_start: datetime | None = None self._step_sequence: int = 0 + # Bound logger with run context + self._logger: structlog.stdlib.BoundLogger | None = None + + def _ensure_logger(self) -> structlog.stdlib.BoundLogger: + """Ensure the structlog logger is initialized.""" + if self._logger is None: + self._file_handle = open(self.log_file, "a", encoding="utf-8") + self._logger = _create_structlog_logger(self._file_handle) + self._logger = self._logger.bind(run_id=self.run_id) + return self._logger + def _write_entry(self, entry: LogEntry) -> None: """Write a log entry to the JSONL file.""" with open(self.log_file, "a", encoding="utf-8") as f: @@ -130,6 +221,16 @@ def _elapsed_ms(self, start: datetime) -> int: """Calculate elapsed milliseconds since start.""" return int((datetime.now() - start).total_seconds() * 1000) + def close(self) -> None: + """Close the file handle.""" + if hasattr(self, "_file_handle") and self._file_handle: + self._file_handle.close() + self._file_handle = None + + def __del__(self) -> None: + """Cleanup on deletion.""" + self.close() + # ==================== Level 1: Phase Logging ==================== def phase_start(self, phase: str, message: str = "") -> None: @@ -327,20 +428,42 @@ def step_skipped(self, step: str, message: str = "") -> None: ) self._write_entry(entry) - # ==================== Level 3: Detail Logging ==================== + # ==================== Level 2: Step Context Manager ==================== - def detail_file_classified( - self, file_path: str, file_type: str, subtype: str, extension: str - ) -> None: + @contextmanager + def step(self, step_name: str, message: str = ""): """ - Log a successfully classified file (Level 3). + Context manager for step logging with automatic timing. 
Args: - file_path: Path to the file - file_type: Classified file type (e.g., 'documentation', 'code') - subtype: File subtype (e.g., 'markdown', 'python') - extension: File extension + step_name: Name of the step + message: Optional start message + + Yields: + StepContext for tracking items processed/created + + Example: + with logger.step("Repository") as ctx: + # do work + ctx.items_processed = 10 + ctx.items_created = 5 """ + ctx = self.step_start(step_name, message) + try: + yield ctx + except Exception as e: + ctx.error(str(e)) + raise + else: + if not ctx._completed: + ctx.complete() + + # ==================== Level 3: Detail Logging ==================== + + def detail_file_classified( + self, file_path: str, file_type: str, subtype: str, extension: str + ) -> None: + """Log a successfully classified file (Level 3).""" entry = LogEntry( level=LogLevel.DETAIL, phase=self._current_phase or "classification", @@ -357,13 +480,7 @@ def detail_file_classified( self._write_entry(entry) def detail_file_unclassified(self, file_path: str, extension: str) -> None: - """ - Log an unclassified file (Level 3). - - Args: - file_path: Path to the file - extension: Unknown file extension - """ + """Log an unclassified file (Level 3).""" entry = LogEntry( level=LogLevel.DETAIL, phase=self._current_phase or "classification", @@ -392,22 +509,7 @@ def detail_extraction( success: bool = True, error: str | None = None, ) -> None: - """ - Log an LLM extraction detail (Level 3). - - Args: - file_path: Path to the source file - node_type: Type of node being extracted (e.g., 'BusinessConcept') - prompt: The prompt sent to LLM - response: The LLM response - tokens_in: Input tokens used - tokens_out: Output tokens generated - cache_used: Whether cache was used - retries: Number of retries needed - concepts_extracted: Number of concepts/nodes extracted - success: Whether extraction succeeded - error: Error message if failed - """ + """Log an LLM extraction detail (Level 3).""" entry = LogEntry( level=LogLevel.DETAIL, phase=self._current_phase or "extraction", @@ -436,15 +538,7 @@ def detail_node_created( source_file: str, properties: dict[str, Any] | None = None, ) -> None: - """ - Log a node creation detail (Level 3). - - Args: - node_id: ID of the created node - node_type: Type of node (Repository, Directory, File, BusinessConcept) - source_file: Source file path - properties: Optional node properties - """ + """Log a node creation detail (Level 3).""" entry = LogEntry( level=LogLevel.DETAIL, phase=self._current_phase or "extraction", @@ -463,15 +557,7 @@ def detail_node_created( def detail_edge_created( self, edge_id: str, relationship_type: str, from_node: str, to_node: str ) -> None: - """ - Log an edge creation detail (Level 3). - - Args: - edge_id: ID of the created edge - relationship_type: Type of relationship - from_node: Source node ID - to_node: Target node ID - """ + """Log an edge creation detail (Level 3).""" entry = LogEntry( level=LogLevel.DETAIL, phase=self._current_phase or "extraction", @@ -495,21 +581,7 @@ def detail_node_deactivated( algorithm: str | None = None, properties: dict[str, Any] | None = None, ) -> None: - """ - Log a node deactivation detail (Level 3). - - Used during graph preparation phase when nodes are marked as inactive - (active=False) rather than deleted. The node remains in Neo4j but won't - be queried for further derivation. 
- - Args: - node_id: ID of the deactivated node - node_type: Type of node (e.g., 'Directory', 'File', 'BusinessConcept') - reason: Human-readable reason for deactivation - algorithm: Optional algorithm that triggered deactivation - (e.g., 'k-core', 'articulation_points', 'scc') - properties: Optional additional metadata about the deactivation - """ + """Log a node deactivation detail (Level 3).""" stats = { "node_id": node_id, "node_type": node_type, @@ -541,23 +613,7 @@ def detail_edge_deactivated( algorithm: str | None = None, properties: dict[str, Any] | None = None, ) -> None: - """ - Log an edge deactivation detail (Level 3). - - Used during graph preparation phase when edges are marked as inactive - (active=False) rather than deleted. The edge remains in Neo4j but won't - be queried for further derivation. - - Args: - edge_id: ID of the deactivated edge - relationship_type: Type of relationship - from_node: Source node ID - to_node: Target node ID - reason: Human-readable reason for deactivation - algorithm: Optional algorithm that triggered deactivation - (e.g., 'cycle_detection', 'redundant_edges') - properties: Optional additional metadata about the deactivation - """ + """Log an edge deactivation detail (Level 3).""" stats = { "edge_id": edge_id, "relationship_type": relationship_type, @@ -590,17 +646,7 @@ def detail_element_created( confidence: float | None = None, properties: dict[str, Any] | None = None, ) -> None: - """ - Log an ArchiMate element creation detail (Level 3). - - Args: - element_id: ID of the created element - element_type: ArchiMate element type (e.g., 'ApplicationComponent') - name: Human-readable element name - source_node: Optional reference to source graph node - confidence: Optional confidence score from LLM derivation - properties: Optional additional element properties - """ + """Log an ArchiMate element creation detail (Level 3).""" stats: dict[str, Any] = { "element_id": element_id, "element_type": element_type, @@ -632,17 +678,7 @@ def detail_relationship_created( confidence: float | None = None, properties: dict[str, Any] | None = None, ) -> None: - """ - Log an ArchiMate relationship creation detail (Level 3). - - Args: - relationship_id: ID of the created relationship - relationship_type: ArchiMate relationship type (e.g., 'Composition') - source_element: Source element ID - target_element: Target element ID - confidence: Optional confidence score from LLM derivation - properties: Optional additional relationship properties - """ + """Log an ArchiMate relationship creation detail (Level 3).""" stats: dict[str, Any] = { "relationship_id": relationship_id, "relationship_type": relationship_type, @@ -797,7 +833,6 @@ def read_run_logs( Returns: List of log entries, or empty list if no logs found """ - # Resolve logs_dir relative to project root logs_path = Path(logs_dir) if not logs_path.is_absolute(): project_root = Path(__file__).parent.parent.parent @@ -827,7 +862,7 @@ def read_run_logs( # ============================================================================= -# Standard Logging Bridge +# Standard Logging Bridge (integrates Python logging with structlog) # ============================================================================= @@ -837,14 +872,6 @@ class RunLoggerHandler(logging.Handler): This bridges the standard logging module with the structured RunLogger, allowing warnings/errors from adapters to appear in pipeline logs. 
- - Usage: - logger = RunLogger(run_id=1) - handler = RunLoggerHandler(logger) - logging.getLogger().addHandler(handler) - - # Or use the convenience function: - setup_logging_bridge(logger) """ def __init__(self, run_logger: RunLogger, min_level: int = logging.WARNING): @@ -909,14 +936,6 @@ def setup_logging_bridge( Returns: The handler (for later removal if needed) - - Example: - with PipelineSession() as session: - logger = get_logger_for_active_run(session._engine) - if logger: - handler = setup_logging_bridge(logger) - # ... run pipeline ... - # Warnings from adapters now appear in run logs """ handler = RunLoggerHandler(run_logger, min_level) handler.setFormatter(logging.Formatter("%(name)s: %(message)s")) diff --git a/deriva/modules/derivation/application_component.py b/deriva/modules/derivation/application_component.py index b5a5361..3d18e5e 100644 --- a/deriva/modules/derivation/application_component.py +++ b/deriva/modules/derivation/application_component.py @@ -138,78 +138,3 @@ def filter_candidates( ) return combined - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -# Create singleton instance for module-level function calls -_instance = ApplicationComponentDerivation() - -# Export module-level constants (for services/derivation.py compatibility) -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - max_candidates: int, -) -> list[Candidate]: - """ - Backward-compatible filter_candidates function. - - Delegates to ApplicationComponentDerivation.filter_candidates(). - """ - return _instance.filter_candidates(candidates, enrichments, max_candidates) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """ - Backward-compatible generate function. - - Delegates to ApplicationComponentDerivation.generate(). 
- """ - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - # Backward-compatible exports - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - # New class export - "ApplicationComponentDerivation", -] diff --git a/deriva/modules/derivation/application_interface.py b/deriva/modules/derivation/application_interface.py index d1b2aee..ff23abd 100644 --- a/deriva/modules/derivation/application_interface.py +++ b/deriva/modules/derivation/application_interface.py @@ -106,12 +106,12 @@ def filter_candidates( likely_interfaces = [ c for c in filtered - if self._is_likely_interface(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_interface(c.name, include_patterns, exclude_patterns) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_interfaces = filter_by_pagerank( @@ -131,95 +131,3 @@ def filter_candidates( ) return likely_interfaces[:max_candidates] - - def _is_likely_interface( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a method name suggests an application interface.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for interface patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = ApplicationInterfaceDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - 
"ApplicationInterfaceDerivation", -] diff --git a/deriva/modules/derivation/application_service.py b/deriva/modules/derivation/application_service.py index c844fe2..8607829 100644 --- a/deriva/modules/derivation/application_service.py +++ b/deriva/modules/derivation/application_service.py @@ -98,12 +98,12 @@ def filter_candidates( likely_services = [ c for c in filtered - if self._is_likely_service(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_service(c.name, include_patterns, exclude_patterns) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] # Sort by PageRank @@ -121,95 +121,3 @@ def filter_candidates( ) return likely_services[:max_candidates] - - def _is_likely_service( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a method name suggests an application service.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if name_lower.startswith(pattern) or pattern in name_lower: - return False - - # Check for service patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = ApplicationServiceDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "ApplicationServiceDerivation", -] diff --git a/deriva/modules/derivation/business_actor.py b/deriva/modules/derivation/business_actor.py index 4c91bfd..f27a65c 100644 --- a/deriva/modules/derivation/business_actor.py +++ b/deriva/modules/derivation/business_actor.py @@ -104,12 +104,12 @@ def filter_candidates( likely_actors = [ c for c in filtered - if self._is_likely_actor(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_actor(c.name, include_patterns, exclude_patterns) + if not 
self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_actors = filter_by_pagerank( @@ -131,95 +131,3 @@ def filter_candidates( ) return likely_actors[:max_candidates] - - def _is_likely_actor( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a type name suggests a business actor.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for actor patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = BusinessActorDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "BusinessActorDerivation", -] diff --git a/deriva/modules/derivation/business_event.py b/deriva/modules/derivation/business_event.py index b7cc581..e84a879 100644 --- a/deriva/modules/derivation/business_event.py +++ b/deriva/modules/derivation/business_event.py @@ -101,12 +101,12 @@ def filter_candidates( likely_events = [ c for c in filtered - if self._is_likely_event(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_event(c.name, include_patterns, exclude_patterns) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_events = filter_by_pagerank(likely_events, top_n=max_candidates // 2) @@ -124,95 +124,3 @@ def filter_candidates( ) return likely_events[:max_candidates] - - def _is_likely_event( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a name suggests a business event.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for event patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - 
return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = BusinessEventDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "BusinessEventDerivation", -] diff --git a/deriva/modules/derivation/business_function.py b/deriva/modules/derivation/business_function.py index 2e792b7..d7499e9 100644 --- a/deriva/modules/derivation/business_function.py +++ b/deriva/modules/derivation/business_function.py @@ -107,12 +107,12 @@ def filter_candidates( likely_functions = [ c for c in filtered - if self._is_likely_function(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_function(c.name, include_patterns, exclude_patterns) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_functions = filter_by_pagerank( @@ -132,95 +132,3 @@ def filter_candidates( ) return likely_functions[:max_candidates] - - def _is_likely_function( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a module name suggests a business function.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for function patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = BusinessFunctionDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible 
filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "BusinessFunctionDerivation", -] diff --git a/deriva/modules/derivation/business_object.py b/deriva/modules/derivation/business_object.py index ea39cd7..f6464e9 100644 --- a/deriva/modules/derivation/business_object.py +++ b/deriva/modules/derivation/business_object.py @@ -170,74 +170,3 @@ def _is_likely_business_object( # Default: include if it looks like a noun (starts with capital, no underscores) return name[0].isupper() and "_" not in name - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = BusinessObjectDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "BusinessObjectDerivation", -] diff --git a/deriva/modules/derivation/business_process.py b/deriva/modules/derivation/business_process.py index 2dade24..aaff5a1 100644 --- a/deriva/modules/derivation/business_process.py +++ b/deriva/modules/derivation/business_process.py @@ -91,12 +91,12 @@ def filter_candidates( likely_processes = [ c for c in filtered - if self._is_likely_process(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, 
exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_process(c.name, include_patterns, exclude_patterns) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_processes = filter_by_pagerank( @@ -115,93 +115,3 @@ def filter_candidates( ) return likely_processes[:max_candidates] - - def _is_likely_process( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a method name suggests a business process.""" - if not name: - return False - - name_lower = name.lower() - - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = BusinessProcessDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "BusinessProcessDerivation", -] diff --git a/deriva/modules/derivation/data_object.py b/deriva/modules/derivation/data_object.py index a2e8559..ab353f6 100644 --- a/deriva/modules/derivation/data_object.py +++ b/deriva/modules/derivation/data_object.py @@ -115,14 +115,12 @@ def filter_candidates( likely_data = [ c for c in filtered - if self._is_likely_data_object(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_data_object( - c.name, include_patterns, exclude_patterns - ) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_data = filter_by_pagerank( @@ -144,95 +142,3 @@ def filter_candidates( ) return likely_data[:max_candidates] - - def _is_likely_data_object( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a file name suggests a data object.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for data file patterns - for pattern in 
include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = DataObjectDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "DataObjectDerivation", -] diff --git a/deriva/modules/derivation/device.py b/deriva/modules/derivation/device.py index 0cdbd1b..5a98742 100644 --- a/deriva/modules/derivation/device.py +++ b/deriva/modules/derivation/device.py @@ -144,74 +144,3 @@ def _is_likely_device( return True return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = DeviceDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - 
"filter_candidates", - "generate", - "DeviceDerivation", -] diff --git a/deriva/modules/derivation/element_base.py b/deriva/modules/derivation/element_base.py index b8e92cf..cf66060 100644 --- a/deriva/modules/derivation/element_base.py +++ b/deriva/modules/derivation/element_base.py @@ -429,8 +429,49 @@ class PatternBasedDerivation(ElementDerivationBase): Modules using this base class receive include_patterns and exclude_patterns as kwargs to their filter_candidates() method. + + Subclasses can override PATTERN_MATCH_DEFAULT to control the default + return value when no patterns match (default is False). """ + # Override in subclass to change default behavior when no patterns match + PATTERN_MATCH_DEFAULT: bool = False + + def matches_patterns( + self, name: str, include_patterns: set[str], exclude_patterns: set[str] + ) -> bool: + """ + Check if name matches include patterns and not exclude patterns. + + This is a common utility method that consolidates the pattern matching + logic previously duplicated across all element modules. + + Args: + name: The name to check + include_patterns: Patterns that indicate a match + exclude_patterns: Patterns that indicate exclusion + + Returns: + True if name matches include patterns and not exclude patterns, + otherwise returns PATTERN_MATCH_DEFAULT + """ + if not name: + return False + + name_lower = name.lower() + + # Check exclusion patterns first + for pattern in exclude_patterns: + if pattern in name_lower: + return False + + # Check for include patterns + for pattern in include_patterns: + if pattern in name_lower: + return True + + return self.PATTERN_MATCH_DEFAULT + def get_filter_kwargs(self, engine: Any) -> dict[str, Any]: """ Load include/exclude patterns from config database. diff --git a/deriva/modules/derivation/node.py b/deriva/modules/derivation/node.py index 98385dc..59bfb92 100644 --- a/deriva/modules/derivation/node.py +++ b/deriva/modules/derivation/node.py @@ -149,74 +149,3 @@ def _is_likely_node( return True return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = NodeDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - 
"OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "NodeDerivation", -] diff --git a/deriva/modules/derivation/system_software.py b/deriva/modules/derivation/system_software.py index 894bdf7..e5a3d88 100644 --- a/deriva/modules/derivation/system_software.py +++ b/deriva/modules/derivation/system_software.py @@ -111,16 +111,12 @@ def filter_candidates( likely_system = [ c for c in filtered - if self._is_likely_system_software( - c.name, include_patterns, exclude_patterns - ) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_system_software( - c.name, include_patterns, exclude_patterns - ) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] likely_system = filter_by_pagerank(likely_system, top_n=max_candidates // 2) @@ -138,95 +134,3 @@ def filter_candidates( ) return likely_system[:max_candidates] - - def _is_likely_system_software( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """Check if a name suggests system software.""" - if not name: - return False - - name_lower = name.lower() - - # Check exclusion patterns first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check for system software patterns - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -_instance = SystemSoftwareDerivation() - -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """Backward-compatible filter_candidates function.""" - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """Backward-compatible generate function.""" - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - "SystemSoftwareDerivation", -] diff --git a/deriva/modules/derivation/technology_service.py b/deriva/modules/derivation/technology_service.py index c93b1c0..de368f8 100644 --- a/deriva/modules/derivation/technology_service.py +++ b/deriva/modules/derivation/technology_service.py @@ -98,17 +98,6 @@ def filter_candidates( 2. Identify likely tech services using patterns 3. Prioritize by PageRank 4. 
Fill remaining slots from non-matches - - Args: - candidates: Raw candidates from graph query - enrichments: Graph enrichment data - max_candidates: Maximum to return - include_patterns: Patterns indicating likely tech services - exclude_patterns: Patterns to exclude - **kwargs: Additional unused kwargs - - Returns: - Filtered list of candidates """ include_patterns = include_patterns or set() exclude_patterns = exclude_patterns or set() @@ -124,14 +113,12 @@ def filter_candidates( likely_tech = [ c for c in filtered - if self._is_likely_tech_service(c.name, include_patterns, exclude_patterns) + if self.matches_patterns(c.name, include_patterns, exclude_patterns) ] others = [ c for c in filtered - if not self._is_likely_tech_service( - c.name, include_patterns, exclude_patterns - ) + if not self.matches_patterns(c.name, include_patterns, exclude_patterns) ] # Prioritize likely matches by PageRank @@ -154,117 +141,3 @@ def filter_candidates( ) return likely_tech[:max_candidates] - - def _is_likely_tech_service( - self, name: str, include_patterns: set[str], exclude_patterns: set[str] - ) -> bool: - """ - Check if a dependency suggests a technology service. - - Args: - name: Dependency name - include_patterns: Patterns that indicate a tech service - exclude_patterns: Patterns to exclude - - Returns: - True if name matches include patterns and not exclude patterns - """ - if not name: - return False - - name_lower = name.lower() - - # Check exclusions first - for pattern in exclude_patterns: - if pattern in name_lower: - return False - - # Check inclusions - for pattern in include_patterns: - if pattern in name_lower: - return True - - return False - - -# ============================================================================= -# Backward Compatibility - Module-level exports -# ============================================================================= - -# Create singleton instance for module-level function calls -_instance = TechnologyServiceDerivation() - -# Export module-level constants (for services/derivation.py compatibility) -ELEMENT_TYPE = _instance.ELEMENT_TYPE -OUTBOUND_RULES = _instance.OUTBOUND_RULES -INBOUND_RULES = _instance.INBOUND_RULES - - -def filter_candidates( - candidates: list[Candidate], - enrichments: dict[str, dict[str, Any]], - include_patterns: set[str], - exclude_patterns: set[str], - max_candidates: int, -) -> list[Candidate]: - """ - Backward-compatible filter_candidates function. - - Delegates to TechnologyServiceDerivation.filter_candidates(). - """ - return _instance.filter_candidates( - candidates, - enrichments, - max_candidates, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - -def generate( - graph_manager, - archimate_manager, - engine, - llm_query_fn, - query, - instruction, - example, - max_candidates, - batch_size, - existing_elements, - temperature=None, - max_tokens=None, - defer_relationships=False, -): - """ - Backward-compatible generate function. - - Delegates to TechnologyServiceDerivation.generate(). 
- """ - return _instance.generate( - graph_manager=graph_manager, - archimate_manager=archimate_manager, - engine=engine, - llm_query_fn=llm_query_fn, - query=query, - instruction=instruction, - example=example, - max_candidates=max_candidates, - batch_size=batch_size, - existing_elements=existing_elements, - temperature=temperature, - max_tokens=max_tokens, - defer_relationships=defer_relationships, - ) - - -__all__ = [ - # Backward-compatible exports - "ELEMENT_TYPE", - "OUTBOUND_RULES", - "INBOUND_RULES", - "filter_candidates", - "generate", - # New class export - "TechnologyServiceDerivation", -] diff --git a/deriva/services/__init__.py b/deriva/services/__init__.py index 1f8cea7..61fff07 100644 --- a/deriva/services/__init__.py +++ b/deriva/services/__init__.py @@ -37,13 +37,16 @@ from __future__ import annotations -from . import config, derivation, extraction, pipeline +from . import config, config_models, derivation, extraction, pipeline +from .config import get_settings from .session import PipelineSession __all__ = [ "PipelineSession", "config", + "config_models", "extraction", "derivation", "pipeline", + "get_settings", ] diff --git a/deriva/services/config.py b/deriva/services/config.py index bb444a1..37e93be 100644 --- a/deriva/services/config.py +++ b/deriva/services/config.py @@ -24,11 +24,39 @@ # Get derivation configs by phase generate_configs = config.get_derivation_configs(engine, phase="generate") + + # Load environment settings (pydantic-settings) + from deriva.services.config_models import DerivaSettings + settings = DerivaSettings() + print(settings.llm.temperature) """ from __future__ import annotations -from typing import Any +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from deriva.services.config_models import DerivaSettings + + +@lru_cache +def get_settings() -> DerivaSettings: + """ + Get cached application settings from environment. + + Returns: + DerivaSettings instance with all environment configuration. + + Usage: + settings = get_settings() + print(settings.llm.temperature) + print(settings.neo4j.uri) + """ + from deriva.services.config_models import DerivaSettings + + return DerivaSettings() + # ============================================================================= # Type Definitions diff --git a/deriva/services/config_models.py b/deriva/services/config_models.py new file mode 100644 index 0000000..388106c --- /dev/null +++ b/deriva/services/config_models.py @@ -0,0 +1,265 @@ +""" +Pydantic models for Deriva configuration. + +Uses pydantic-settings for environment variable validation and type coercion. 
+""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +# ============================================================================= +# Environment Settings (from .env file) +# ============================================================================= + + +class Neo4jSettings(BaseSettings): + """Neo4j connection settings.""" + + model_config = SettingsConfigDict(env_prefix="NEO4J_", env_file=".env", extra="ignore") + + uri: str = "bolt://localhost:7687" + username: str = "" + password: str = "" + database: str = "neo4j" + encrypted: bool = False + + # Connection pool settings + max_connection_lifetime: int = 3600 + max_connection_pool_size: int = 50 + connection_acquisition_timeout: int = 60 + + # Logging + log_level: str = "INFO" + log_queries: bool = False + suppress_notifications: bool = True + + # Namespaces + namespace_graph: str = "Graph" + namespace_archimate: str = "Model" + + +class LLMSettings(BaseSettings): + """LLM configuration settings.""" + + model_config = SettingsConfigDict(env_prefix="LLM_", env_file=".env", extra="ignore") + + default_model: str | None = None + temperature: float = Field(default=0.6, ge=0.0, le=2.0) + max_retries: int = Field(default=3, ge=0) + timeout: int = Field(default=60, ge=1) + max_tokens: int | None = None + + # Cache settings + cache_dir: str = "workspace/cache/llm" + cache_ttl: int = 0 + nocache: bool = False + + # Rate limiting + rate_limit_rpm: int = 0 + rate_limit_delay: float = 0.0 + rate_limit_retries: int = 3 + + # Token limits + token_limit_default: int = 32000 + + +class GraphSettings(BaseSettings): + """Graph manager settings.""" + + model_config = SettingsConfigDict(env_prefix="GRAPH_", env_file=".env", extra="ignore") + + namespace: str = "Graph" + cache_dir: str = "workspace/cache/graph" + log_level: str = "INFO" + + +class ArchimateSettings(BaseSettings): + """ArchiMate manager settings.""" + + model_config = SettingsConfigDict(env_prefix="ARCHIMATE_", env_file=".env", extra="ignore") + + namespace: str = "Model" + version: str = "3.1" + identifier_prefix: str = "id-" + + # Validation + validation_strict_mode: bool = False + validation_allow_custom_properties: bool = True + + # Export + export_pretty_print: bool = True + export_encoding: str = "UTF-8" + export_xml_declaration: bool = True + export_validate_on_export: bool = True + export_include_metadata: bool = True + + +class AppSettings(BaseSettings): + """Application-level settings.""" + + model_config = SettingsConfigDict(env_prefix="APP_", env_file=".env", extra="ignore") + + env: str = "development" + log_level: str = "INFO" + log_dir: str = "logs" + + +class DerivaSettings(BaseSettings): + """ + Master settings class that aggregates all settings. 
+ + Usage: + settings = DerivaSettings() + print(settings.neo4j.uri) + print(settings.llm.temperature) + """ + + model_config = SettingsConfigDict(env_file=".env", extra="ignore") + + # Repository settings + repository_workspace_dir: str = Field(default="workspace/repositories", alias="REPOSITORY_WORKSPACE_DIR") + output_dir: str = Field(default="workspace/output/model.xml", alias="OUTPUT_DIR") + + # Nested settings are loaded separately + @property + def neo4j(self) -> Neo4jSettings: + return Neo4jSettings() + + @property + def llm(self) -> LLMSettings: + return LLMSettings() + + @property + def graph(self) -> GraphSettings: + return GraphSettings() + + @property + def archimate(self) -> ArchimateSettings: + return ArchimateSettings() + + @property + def app(self) -> AppSettings: + return AppSettings() + + +# ============================================================================= +# Pipeline Configuration Models (from DuckDB) +# ============================================================================= + + +class ExtractionConfigModel(BaseModel): + """Extraction step configuration with validation.""" + + node_type: str = Field(..., min_length=1, description="The node type to extract (e.g., 'BusinessConcept')") + sequence: int = Field(default=0, ge=0, description="Execution order") + enabled: bool = Field(default=True, description="Whether this extraction step is enabled") + input_sources: str | None = Field(default=None, description="JSON string of input sources") + instruction: str | None = Field(default=None, description="LLM instruction prompt") + example: str | None = Field(default=None, description="Example output for LLM") + extraction_method: Literal["llm", "ast", "structural"] = Field(default="llm", description="Extraction method to use") + temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="LLM temperature override") + max_tokens: int | None = Field(default=None, ge=1, description="LLM max_tokens override") + + +class DerivationConfigModel(BaseModel): + """Unified derivation step configuration with validation.""" + + step_name: str = Field(..., min_length=1, description="The derivation step name") + phase: Literal["prep", "generate", "refine", "relationship"] = Field(..., description="Derivation phase") + sequence: int = Field(default=0, ge=0, description="Execution order within phase") + enabled: bool = Field(default=True, description="Whether this derivation step is enabled") + llm: bool = Field(default=False, description="True = uses LLM, False = pure graph algorithm") + input_graph_query: str | None = Field(default=None, description="Cypher query for graph input") + input_model_query: str | None = Field(default=None, description="Cypher query for model input") + instruction: str | None = Field(default=None, description="LLM instruction prompt") + example: str | None = Field(default=None, description="Example output for LLM") + params: str | None = Field(default=None, description="JSON parameters for graph algorithms") + temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="LLM temperature override") + max_tokens: int | None = Field(default=None, ge=1, description="LLM max_tokens override") + max_candidates: int | None = Field(default=None, ge=1, description="Max candidates to send to LLM") + batch_size: int | None = Field(default=None, ge=1, description="Batch size for LLM processing") + + @property + def element_type(self) -> str: + """Backward compatibility: element_type maps to step_name.""" + return self.step_name + 
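A minimal usage sketch for the settings classes above, assuming a process environment (or `.env` file) that follows the documented prefixes (`NEO4J_*`, `LLM_*`, `GRAPH_*`, `ARCHIMATE_*`, `APP_*`); the variable values shown here are hypothetical:

    import os

    # Hypothetical overrides; any variable left unset falls back to the field default.
    os.environ["NEO4J_URI"] = "bolt://db.internal:7687"
    os.environ["LLM_TEMPERATURE"] = "0.2"

    from deriva.services.config import get_settings

    settings = get_settings()           # cached DerivaSettings (functools.lru_cache)
    print(settings.neo4j.uri)           # "bolt://db.internal:7687"
    print(settings.llm.temperature)     # 0.2, coerced to float and checked against ge=0.0 / le=2.0
    print(settings.llm.max_retries)     # 3 (field default, since LLM_MAX_RETRIES is not set here)

Because the nested groups are exposed as properties that construct a fresh `Neo4jSettings()`, `LLMSettings()`, etc. on each access, they re-read the environment at access time rather than freezing values when `get_settings()` is first cached.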
+ +class FileTypeModel(BaseModel): + """File type registry entry with validation.""" + + extension: str = Field(..., min_length=1, description="File extension (e.g., '.py')") + file_type: str = Field(..., min_length=1, description="File type category (e.g., 'code')") + subtype: str = Field(..., min_length=1, description="File subtype (e.g., 'python')") + chunk_delimiter: str | None = Field(default=None, description="Delimiter for chunking") + chunk_max_tokens: int | None = Field(default=None, ge=1, description="Max tokens per chunk") + chunk_overlap: int = Field(default=0, ge=0, description="Token overlap between chunks") + + +# ============================================================================= +# Threshold and Limit Models +# ============================================================================= + + +class ConfidenceThresholds(BaseModel): + """Confidence threshold configuration.""" + + min_relationship: float = Field(default=0.6, ge=0.0, le=1.0, description="Minimum confidence for relationships") + community_rel: float = Field(default=0.95, ge=0.0, le=1.0, description="Confidence for community-based relationships") + name_match: float = Field(default=0.95, ge=0.0, le=1.0, description="Confidence for name-based matches") + file_match: float = Field(default=0.85, ge=0.0, le=1.0, description="Confidence for file-based matches") + fuzzy_match: float = Field(default=0.85, ge=0.0, le=1.0, description="Threshold for fuzzy string matching") + semantic: float = Field(default=0.95, ge=0.0, le=1.0, description="Confidence for semantic similarity matches") + pagerank_min: float = Field(default=0.001, ge=0.0, description="Minimum PageRank to consider") + + +class DerivationLimits(BaseModel): + """Derivation processing limits.""" + + max_relationships_per_derivation: int = Field(default=500, ge=1, description="Max relationships per derivation step") + default_batch_size: int = Field(default=10, ge=1, description="Default batch size for LLM processing") + default_max_candidates: int = Field(default=30, ge=1, description="Default max candidates for LLM derivation") + high_pagerank_non_roots: int = Field(default=10, ge=1, description="For ApplicationComponent filtering") + + +class PageRankConfig(BaseModel): + """PageRank algorithm configuration.""" + + damping: float = Field(default=0.85, ge=0.0, le=1.0, description="Damping factor") + max_iter: int = Field(default=100, ge=1, description="Maximum iterations") + tol: float = Field(default=1e-6, gt=0.0, description="Convergence tolerance") + + +class LouvainConfig(BaseModel): + """Louvain algorithm configuration.""" + + resolution: float = Field(default=1.0, gt=0.0, description="Resolution parameter") + + +# ============================================================================= +# Benchmark Model Configuration +# ============================================================================= + + +class BenchmarkModelConfigModel(BaseModel): + """Benchmark model configuration with validation.""" + + name: str = Field(..., min_length=1, description="Model identifier") + provider: Literal["azure", "openai", "anthropic", "ollama", "mistral", "lmstudio"] = Field(..., description="LLM provider") + model: str = Field(..., min_length=1, description="Model name") + api_url: str | None = Field(default=None, description="API endpoint URL") + api_key: str | None = Field(default=None, description="API key (direct)") + api_key_env: str | None = Field(default=None, description="Env var name for API key") + structured_output: bool = 
Field(default=False, description="Enable structured output at API level") + + @field_validator("provider", mode="before") + @classmethod + def normalize_provider(cls, v: str) -> str: + """Normalize provider to lowercase before Literal validation.""" + if isinstance(v, str): + return v.lower() + return v diff --git a/deriva/services/derivation.py b/deriva/services/derivation.py index 59bf3a8..4148788 100644 --- a/deriva/services/derivation.py +++ b/deriva/services/derivation.py @@ -47,67 +47,62 @@ from deriva.common.types import ProgressReporter, RunLoggerProtocol from deriva.adapters.graph import GraphManager from deriva.modules.derivation import prep +from deriva.modules.derivation.application_component import ApplicationComponentDerivation +from deriva.modules.derivation.application_interface import ApplicationInterfaceDerivation +from deriva.modules.derivation.application_service import ApplicationServiceDerivation from deriva.modules.derivation.base import derive_consolidated_relationships +from deriva.modules.derivation.business_actor import BusinessActorDerivation +from deriva.modules.derivation.business_event import BusinessEventDerivation +from deriva.modules.derivation.business_function import BusinessFunctionDerivation + +# Derivation class imports +from deriva.modules.derivation.business_object import BusinessObjectDerivation +from deriva.modules.derivation.business_process import BusinessProcessDerivation +from deriva.modules.derivation.data_object import DataObjectDerivation +from deriva.modules.derivation.device import DeviceDerivation +from deriva.modules.derivation.element_base import ElementDerivationBase +from deriva.modules.derivation.node import NodeDerivation from deriva.modules.derivation.refine import run_refine_step +from deriva.modules.derivation.system_software import SystemSoftwareDerivation +from deriva.modules.derivation.technology_service import TechnologyServiceDerivation from deriva.services import config -# Element generation module registry -# Maps element_type to module with generate() function -_ELEMENT_MODULES: dict[str, Any] = {} - - -def _load_element_module(element_type: str) -> Any: - """Lazily load element generation module.""" - if element_type in _ELEMENT_MODULES: - return _ELEMENT_MODULES[element_type] - - module = None - # Business Layer - if element_type == "BusinessObject": - from deriva.modules.derivation import business_object as module - elif element_type == "BusinessProcess": - from deriva.modules.derivation import business_process as module - elif element_type == "BusinessActor": - from deriva.modules.derivation import business_actor as module - elif element_type == "BusinessEvent": - from deriva.modules.derivation import business_event as module - elif element_type == "BusinessFunction": - from deriva.modules.derivation import business_function as module - # Application Layer - elif element_type == "ApplicationComponent": - from deriva.modules.derivation import application_component as module - elif element_type == "ApplicationService": - from deriva.modules.derivation import application_service as module - elif element_type == "ApplicationInterface": - from deriva.modules.derivation import application_interface as module - elif element_type == "DataObject": - from deriva.modules.derivation import data_object as module - # Technology Layer - elif element_type == "TechnologyService": - from deriva.modules.derivation import technology_service as module - elif element_type == "Node": - from deriva.modules.derivation import node as module - 
elif element_type == "Device": - from deriva.modules.derivation import device as module - elif element_type == "SystemSoftware": - from deriva.modules.derivation import system_software as module - - _ELEMENT_MODULES[element_type] = module - return module +# Registry: element_type -> derivation class +DERIVATION_REGISTRY: dict[str, type[ElementDerivationBase]] = { + "BusinessObject": BusinessObjectDerivation, + "BusinessProcess": BusinessProcessDerivation, + "BusinessActor": BusinessActorDerivation, + "BusinessEvent": BusinessEventDerivation, + "BusinessFunction": BusinessFunctionDerivation, + "ApplicationComponent": ApplicationComponentDerivation, + "ApplicationService": ApplicationServiceDerivation, + "ApplicationInterface": ApplicationInterfaceDerivation, + "DataObject": DataObjectDerivation, + "TechnologyService": TechnologyServiceDerivation, + "Node": NodeDerivation, + "Device": DeviceDerivation, + "SystemSoftware": SystemSoftwareDerivation, +} +# Instance cache for reuse +_DERIVATION_INSTANCES: dict[str, ElementDerivationBase] = {} -def _collect_relationship_rules() -> dict[str, tuple[list[Any], list[Any]]]: - """Collect relationship rules from all loaded element modules. - Returns: - Dict mapping element_type to (outbound_rules, inbound_rules) tuple - """ +def _get_derivation(element_type: str) -> ElementDerivationBase | None: + """Get or create a derivation instance for an element type.""" + if element_type not in DERIVATION_REGISTRY: + return None + if element_type not in _DERIVATION_INSTANCES: + _DERIVATION_INSTANCES[element_type] = DERIVATION_REGISTRY[element_type]() + return _DERIVATION_INSTANCES[element_type] + + +def _collect_relationship_rules() -> dict[str, tuple[list[Any], list[Any]]]: + """Collect relationship rules from all derivation classes.""" rules: dict[str, tuple[list[Any], list[Any]]] = {} - for element_type, module in _ELEMENT_MODULES.items(): - if module is None: - continue - outbound = getattr(module, "OUTBOUND_RULES", []) - inbound = getattr(module, "INBOUND_RULES", []) + for element_type, cls in DERIVATION_REGISTRY.items(): + outbound = cls.OUTBOUND_RULES + inbound = cls.INBOUND_RULES if outbound or inbound: rules[element_type] = (outbound, inbound) return rules @@ -157,18 +152,18 @@ def generate_element( Returns: Dict with success, elements_created, relationships_created, created_elements, errors """ - module = _load_element_module(element_type) + derivation = _get_derivation(element_type) - if module is None: + if derivation is None: return { "success": False, "elements_created": 0, "relationships_created": 0, - "errors": [f"No generation module for element type: {element_type}"], + "errors": [f"No derivation class for element type: {element_type}"], } try: - result = module.generate( + result = derivation.generate( graph_manager=graph_manager, archimate_manager=archimate_manager, engine=engine, diff --git a/pyproject.toml b/pyproject.toml index 13a5975..293fedd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,17 +35,17 @@ classifiers = [ dependencies = [ # Core "pydantic>=2.12.5", - "pydantic-ai>=0.1.0", + "pydantic-settings>=2.0", "python-dotenv>=1.2.1", "pyyaml>=6.0.3", "rich>=14.2.0", - # Database & Graph + "lxml>=6.0.2", + # Databases "duckdb>=1.4.3", "neo4j>=6.1.0", "sqlglot>=28.6.0", - # UI & XML - "marimo>=0.19.2", - "lxml>=6.0.2", + # UI + "marimo>=0.19.3", # Code Analysis (tree-sitter) "tree-sitter>=0.24.6", "tree-sitter-python>=0.23.0", @@ -57,6 +57,13 @@ dependencies = [ "python-docx>=1.2.0", # Graph Algorithms "solvor>=0.5.3", + # LLM + 
"backoff>=2.2.1", + "pydantic-ai>=1.42.0", + # Utilities + "diskcache>=5.6.3", + "structlog>=24.0", + "typer>=0.21.1", ] [project.optional-dependencies] @@ -64,7 +71,7 @@ dev = [ "pre-commit>=4.5.1", "pytest>=9.0.2", "pytest-cov>=7.0.0", - "ruff>=0.14.11", + "ruff>=0.14.13", "ty>=0.0.12", "types-lxml>=2025.3.30", ] @@ -178,7 +185,7 @@ omit = [ ] [tool.coverage.report] -fail_under = 79 +fail_under = 80 exclude_lines = [ "pragma: no cover", "if TYPE_CHECKING:", diff --git a/tests/conftest.py b/tests/conftest.py index c9450cd..ad1ddfa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,6 +123,120 @@ def _query(prompt, schema): return _query +# ============================================================================= +# TreeSitter Fixtures +# ============================================================================= + + +@pytest.fixture +def treesitter_manager(): + """Provide a TreeSitterManager instance for tests.""" + from deriva.adapters.treesitter import TreeSitterManager + + return TreeSitterManager() + + +# ============================================================================= +# Repository Mock Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_repository(tmp_path): + """Factory fixture for creating mock repository objects. + + Usage: + def test_something(mock_repository): + repo = mock_repository(name="my_repo", files={"src/main.py": "def main(): pass"}) + """ + from unittest.mock import MagicMock + + def _make(name="test_repo", files=None, branch="main"): + repo = MagicMock() + repo.name = name + repo.path = str(tmp_path / name) + repo.url = f"https://example.com/{name}.git" + repo.branch = branch + + # Create the repo directory + repo_dir = tmp_path / name + repo_dir.mkdir(exist_ok=True) + + # Create any specified files + if files: + for file_path, content in files.items(): + full_path = repo_dir / file_path + full_path.parent.mkdir(parents=True, exist_ok=True) + full_path.write_text(content) + + return repo + + return _make + + +# ============================================================================= +# Session/Pipeline Mock Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_session_dependencies(): + """Mock all PipelineSession external dependencies. + + Usage: + def test_something(mock_session_dependencies): + mocks = mock_session_dependencies + mocks["db"].return_value = ... + """ + from unittest.mock import patch + + with ( + patch("deriva.services.session.get_connection") as mock_db, + patch("deriva.services.session.GraphManager") as mock_graph, + patch("deriva.services.session.ArchimateManager") as mock_archimate, + patch("deriva.services.session.RepoManager") as mock_repo, + patch("deriva.services.session.Neo4jConnection") as mock_neo4j, + ): + yield { + "db": mock_db, + "graph": mock_graph, + "archimate": mock_archimate, + "repo": mock_repo, + "neo4j": mock_neo4j, + } + + +# ============================================================================= +# Assertion Helpers (registered as pytest helpers) +# ============================================================================= + + +def assert_success(result, expected=True): + """Assert operation result success/failure. 
+ + Usage: + from tests.conftest import assert_success + assert_success(result) # Expects success + assert_success(result, expected=False) # Expects failure + """ + assert result["success"] is expected, f"Expected success={expected}, got {result.get('success')}" + if expected: + errors = result.get("errors", []) + assert not errors, f"Unexpected errors: {errors}" + + +def assert_error_contains(result, message): + """Assert result failed with specific error message. + + Usage: + from tests.conftest import assert_error_contains + assert_error_contains(result, "Missing required") + """ + assert result["success"] is False, "Expected failure, got success" + errors = result.get("errors", []) + assert any(message in err for err in errors), f"Expected '{message}' in errors: {errors}" + + # ============================================================================= # Markers # ============================================================================= diff --git a/tests/test_adapters/llm/test_cache.py b/tests/test_adapters/llm/test_cache.py index f05e433..55ed85c 100644 --- a/tests/test_adapters/llm/test_cache.py +++ b/tests/test_adapters/llm/test_cache.py @@ -7,7 +7,6 @@ import pytest from deriva.adapters.llm.cache import CacheManager -from deriva.adapters.llm.models import CacheError class TestCacheManager: @@ -50,84 +49,74 @@ def test_generate_cache_key_with_schema(self): key2 = CacheManager.generate_cache_key("test", "gpt-4", None) assert key1 != key2 - def test_set_and_get_from_memory(self, cache_manager): - """Should store and retrieve from memory cache.""" + def test_generate_cache_key_with_bench_hash(self): + """Should include bench_hash in cache key generation.""" + key1 = CacheManager.generate_cache_key("test", "gpt-4", bench_hash="repo:model:1") + key2 = CacheManager.generate_cache_key("test", "gpt-4", bench_hash="repo:model:2") + key3 = CacheManager.generate_cache_key("test", "gpt-4", bench_hash=None) + assert key1 != key2 + assert key1 != key3 + + def test_set_and_get(self, cache_manager): + """Should store and retrieve from cache.""" cache_key = CacheManager.generate_cache_key("test", "gpt-4") cache_manager.set_response(cache_key, "response content", "test", "gpt-4") - cached = cache_manager.get_from_memory(cache_key) + cached = cache_manager.get(cache_key) assert cached is not None assert cached["content"] == "response content" assert cached["model"] == "gpt-4" - def test_set_and_get_from_disk(self, cache_manager): - """Should store and retrieve from disk cache.""" + def test_get_from_memory_alias(self, cache_manager): + """get_from_memory should work as alias for get (backward compat).""" cache_key = CacheManager.generate_cache_key("test", "gpt-4") - cache_manager.set_response(cache_key, "disk content", "test", "gpt-4") - - # Clear memory to force disk read - cache_manager.clear_memory() + cache_manager.set_response(cache_key, "response content", "test", "gpt-4") - cached = cache_manager.get_from_disk(cache_key) + cached = cache_manager.get_from_memory(cache_key) assert cached is not None - assert cached["content"] == "disk content" + assert cached["content"] == "response content" - def test_get_checks_memory_first_then_disk(self, cache_manager): - """Should check memory cache first, then disk.""" + def test_get_from_disk_alias(self, cache_manager): + """get_from_disk should work as alias for get (backward compat).""" cache_key = CacheManager.generate_cache_key("test", "gpt-4") - cache_manager.set_response(cache_key, "original content", "test", "gpt-4") - - # Clear memory - 
cache_manager.clear_memory() - assert cache_manager.get_from_memory(cache_key) is None + cache_manager.set_response(cache_key, "disk content", "test", "gpt-4") - # get() should load from disk and populate memory - cached = cache_manager.get(cache_key) + cached = cache_manager.get_from_disk(cache_key) assert cached is not None - assert cached["content"] == "original content" - - # Now it should be in memory - assert cache_manager.get_from_memory(cache_key) is not None + assert cached["content"] == "disk content" def test_get_returns_none_for_missing_key(self, cache_manager): """Should return None for non-existent cache key.""" cached = cache_manager.get("nonexistent_key") assert cached is None - def test_clear_memory(self, cache_manager): - """Should clear memory cache.""" + def test_clear_all(self, cache_manager): + """Should clear all cache entries.""" cache_key = CacheManager.generate_cache_key("test", "gpt-4") cache_manager.set_response(cache_key, "content", "test", "gpt-4") - assert cache_manager.get_from_memory(cache_key) is not None + # Verify entry exists + assert cache_manager.get(cache_key) is not None - cache_manager.clear_memory() - assert cache_manager.get_from_memory(cache_key) is None + cache_manager.clear_all() - def test_clear_disk(self, cache_manager, temp_cache_dir): - """Should clear disk cache.""" + # Entry should be gone + assert cache_manager.get(cache_key) is None + + def test_clear_disk(self, cache_manager): + """Should clear all disk cache entries.""" cache_key = CacheManager.generate_cache_key("test", "gpt-4") cache_manager.set_response(cache_key, "content", "test", "gpt-4") - # Verify file exists - cache_files = list(Path(temp_cache_dir).glob("*.json")) - assert len(cache_files) == 1 + # Verify entry exists + stats_before = cache_manager.get_cache_stats() + assert stats_before["entries"] >= 1 cache_manager.clear_disk() - # Verify file is deleted - cache_files = list(Path(temp_cache_dir).glob("*.json")) - assert len(cache_files) == 0 - - def test_clear_all(self, cache_manager, temp_cache_dir): - """Should clear both memory and disk cache.""" - cache_key = CacheManager.generate_cache_key("test", "gpt-4") - cache_manager.set_response(cache_key, "content", "test", "gpt-4") - - cache_manager.clear_all() - - assert cache_manager.get_from_memory(cache_key) is None - assert len(list(Path(temp_cache_dir).glob("*.json"))) == 0 + # All entries should be gone + stats_after = cache_manager.get_cache_stats() + assert stats_after["entries"] == 0 def test_get_cache_stats(self, cache_manager, temp_cache_dir): """Should return accurate cache statistics.""" @@ -138,9 +127,10 @@ def test_get_cache_stats(self, cache_manager, temp_cache_dir): stats = cache_manager.get_cache_stats() - assert stats["memory_entries"] == 3 - assert stats["disk_entries"] == 3 - assert stats["disk_size_bytes"] > 0 + assert stats["entries"] == 3 + assert stats["memory_entries"] == 3 # Backward compat + assert stats["disk_entries"] == 3 # Backward compat + assert stats["size_bytes"] > 0 assert stats["cache_dir"] == temp_cache_dir def test_cache_with_usage_data(self, cache_manager): @@ -162,34 +152,36 @@ def test_cache_includes_timestamp(self, cache_manager): assert "cached_at" in cached assert cached["cached_at"] is not None + def test_invalidate(self, cache_manager): + """Should remove specific cache entry.""" + cache_key = CacheManager.generate_cache_key("test", "gpt-4") + cache_manager.set_response(cache_key, "content", "test", "gpt-4") -class TestCacheManagerCorruptedCache: - """Tests for handling 
corrupted cache files.""" - - @pytest.fixture - def temp_cache_dir(self): - """Create a temporary cache directory.""" - temp_dir = tempfile.mkdtemp() - yield temp_dir - shutil.rmtree(temp_dir, ignore_errors=True) + # Verify it exists + assert cache_manager.get(cache_key) is not None - def test_corrupted_cache_file_raises_error(self, temp_cache_dir): - """Should raise CacheError for corrupted cache file.""" - cache_manager = CacheManager(temp_cache_dir) + # Invalidate + cache_manager.invalidate(cache_key) - # Create corrupted cache file - cache_key = "corrupted_key" - cache_file = Path(temp_cache_dir) / f"{cache_key}.json" - cache_file.write_text("not valid json {{{") + # Should be gone + assert cache_manager.get(cache_key) is None - with pytest.raises(CacheError) as exc_info: - cache_manager.get_from_disk(cache_key) + def test_keys(self, cache_manager): + """Should return all cache keys.""" + keys_to_add = [] + for i in range(3): + key = CacheManager.generate_cache_key(f"test{i}", "gpt-4") + keys_to_add.append(key) + cache_manager.set_response(key, f"content {i}", f"test{i}", "gpt-4") - assert "Corrupted cache file" in str(exc_info.value) + stored_keys = cache_manager.keys() + assert len(stored_keys) == 3 + for key in keys_to_add: + assert key in stored_keys -class TestCacheManagerErrors: - """Tests for error handling in CacheManager.""" +class TestCacheManagerExport: + """Tests for cache export functionality.""" @pytest.fixture def temp_cache_dir(self): @@ -198,52 +190,48 @@ def temp_cache_dir(self): yield temp_dir shutil.rmtree(temp_dir, ignore_errors=True) - def test_get_from_disk_generic_error(self, temp_cache_dir): - """Should raise CacheError for generic read errors.""" - from unittest.mock import patch + def test_export_to_json(self, temp_cache_dir): + """Should export cache contents to JSON.""" + import json cache_manager = CacheManager(temp_cache_dir) - # Create a valid cache file first - cache_key = "test_key" - cache_file = Path(temp_cache_dir) / f"{cache_key}.json" - cache_file.write_text('{"content": "test"}') - - # Mock open to raise a generic exception - with patch("builtins.open", side_effect=PermissionError("Access denied")): - with pytest.raises(CacheError) as exc_info: - cache_manager.get_from_disk(cache_key) - - assert "Error reading cache file" in str(exc_info.value) + # Add some entries + for i in range(3): + key = CacheManager.generate_cache_key(f"test{i}", "gpt-4") + cache_manager.set_response(key, f"content {i}", f"test{i}", "gpt-4") - def test_set_write_error(self, temp_cache_dir): - """Should raise CacheError when write fails.""" - from unittest.mock import patch + # Export + export_path = Path(temp_cache_dir) / "export.json" + count = cache_manager.export_to_json(export_path) - cache_manager = CacheManager(temp_cache_dir) + assert count == 3 + assert export_path.exists() - # Mock open to raise exception during write - with patch("builtins.open", side_effect=PermissionError("Access denied")): - with pytest.raises(CacheError) as exc_info: - cache_manager.set_response("key", "content", "prompt", "model") + # Verify JSON contents + with open(export_path) as f: + data = json.load(f) - assert "Error writing cache file" in str(exc_info.value) + assert data["entry_count"] == 3 + assert len(data["entries"]) == 3 + assert all("key" in entry for entry in data["entries"]) + assert all("value" in entry for entry in data["entries"]) - def test_clear_disk_error(self, temp_cache_dir): - """Should raise CacheError when delete fails.""" - from unittest.mock import patch + def 
test_export_keys_only(self, temp_cache_dir): + """Should export only keys when include_values=False.""" + import json cache_manager = CacheManager(temp_cache_dir) + key = CacheManager.generate_cache_key("test", "gpt-4") + cache_manager.set_response(key, "content", "test", "gpt-4") - # Add a cache entry - cache_manager.set_response("key", "content", "prompt", "model") + export_path = Path(temp_cache_dir) / "keys_only.json" + cache_manager.export_to_json(export_path, include_values=False) - # Mock unlink to fail - with patch.object(Path, "unlink", side_effect=PermissionError("Access denied")): - with pytest.raises(CacheError) as exc_info: - cache_manager.clear_disk() + with open(export_path) as f: + data = json.load(f) - assert "Error clearing disk cache" in str(exc_info.value) + assert "value" not in data["entries"][0] class TestCachedLLMCallDecorator: diff --git a/tests/test_adapters/llm/test_manager.py b/tests/test_adapters/llm/test_manager.py index 1a6837b..eb8db44 100644 --- a/tests/test_adapters/llm/test_manager.py +++ b/tests/test_adapters/llm/test_manager.py @@ -19,7 +19,6 @@ LiveResponse, ) - # ============================================================================= # load_benchmark_models() Tests # ============================================================================= @@ -528,7 +527,7 @@ def test_query_returns_cached_response(self, tmp_path, monkeypatch): with patch("deriva.adapters.llm.manager.load_dotenv"): with patch.dict("os.environ", env_vars, clear=True): manager = LLMManager() - manager.cache.get = MagicMock(return_value=cached_data) + manager.cache.get = MagicMock(return_value=cached_data) # type: ignore[method-assign] response = manager.query("Hello") @@ -595,8 +594,8 @@ def test_query_caches_successful_response(self, tmp_path, monkeypatch): with patch("deriva.adapters.llm.manager.Agent") as mock_agent_class: mock_agent_class.return_value.run_sync.return_value = mock_result manager = LLMManager() - manager.cache.get = MagicMock(return_value=None) - manager.cache.set_response = MagicMock() + manager.cache.get = MagicMock(return_value=None) # type: ignore[method-assign] + manager.cache.set_response = MagicMock() # type: ignore[method-assign] manager.query("Hello") @@ -636,7 +635,7 @@ def test_clear_cache(self, tmp_path, monkeypatch): with patch("deriva.adapters.llm.manager.load_dotenv"): with patch.dict("os.environ", env_vars, clear=True): manager = LLMManager() - manager.cache.clear_all = MagicMock() + manager.cache.clear_all = MagicMock() # type: ignore[method-assign] manager.clear_cache() @@ -655,7 +654,7 @@ def test_get_cache_stats(self, tmp_path, monkeypatch): with patch("deriva.adapters.llm.manager.load_dotenv"): with patch.dict("os.environ", env_vars, clear=True): manager = LLMManager() - manager.cache.get_cache_stats = MagicMock(return_value=mock_stats) + manager.cache.get_cache_stats = MagicMock(return_value=mock_stats) # type: ignore[method-assign] stats = manager.get_cache_stats() diff --git a/tests/test_adapters/llm/test_models.py b/tests/test_adapters/llm/test_models.py index aaaf3f5..2803dc8 100644 --- a/tests/test_adapters/llm/test_models.py +++ b/tests/test_adapters/llm/test_models.py @@ -1,7 +1,6 @@ """Tests for managers.llm.models module.""" import pytest -from pydantic import BaseModel, Field from deriva.adapters.llm.models import ( BaseResponse, diff --git a/tests/test_adapters/llm/test_rate_limiter.py b/tests/test_adapters/llm/test_rate_limiter.py index 02344f5..4e43a6c 100644 --- a/tests/test_adapters/llm/test_rate_limiter.py +++ 
b/tests/test_adapters/llm/test_rate_limiter.py @@ -22,24 +22,15 @@ def test_default_values(self): config = RateLimitConfig() assert config.requests_per_minute == 60 assert config.min_request_delay == 0.0 - assert config.backoff_base == 2.0 - assert config.backoff_max == 60.0 - assert config.backoff_jitter == 0.1 def test_custom_values(self): """Should accept custom values.""" config = RateLimitConfig( requests_per_minute=100, min_request_delay=0.5, - backoff_base=3.0, - backoff_max=120.0, - backoff_jitter=0.2, ) assert config.requests_per_minute == 100 assert config.min_request_delay == 0.5 - assert config.backoff_base == 3.0 - assert config.backoff_max == 120.0 - assert config.backoff_jitter == 0.2 class TestGetDefaultRateLimit: @@ -52,7 +43,7 @@ def test_known_providers(self): assert get_default_rate_limit("anthropic") == 60 assert get_default_rate_limit("ollama") == 0 assert get_default_rate_limit("lmstudio") == 0 - assert get_default_rate_limit("claudecode") == 30 + assert get_default_rate_limit("mistral") == 24 def test_case_insensitive(self): """Should be case insensitive.""" @@ -67,7 +58,7 @@ def test_unknown_provider_returns_default(self): def test_all_providers_in_dict(self): """All expected providers should be in DEFAULT_RATE_LIMITS.""" - expected = {"azure", "openai", "anthropic", "mistral", "ollama", "lmstudio", "claudecode"} + expected = {"azure", "openai", "anthropic", "mistral", "ollama", "lmstudio"} assert set(DEFAULT_RATE_LIMITS.keys()) == expected @@ -78,9 +69,9 @@ def test_default_initialization(self): """Should initialize with default config.""" limiter = RateLimiter() assert limiter.config.requests_per_minute == 60 - assert limiter._consecutive_rate_limits == 0 assert limiter._last_request_time == 0.0 assert len(limiter._request_times) == 0 + assert limiter._successful_requests == 0 def test_custom_config(self): """Should accept custom config.""" @@ -175,19 +166,19 @@ def test_old_timestamps_cleaned_up(self): class TestRateLimiterRecordSuccess: """Tests for RateLimiter.record_success method.""" - def test_resets_consecutive_rate_limits(self): - """Should reset consecutive rate limit counter on success.""" + def test_increments_successful_requests(self): + """Should increment successful requests counter.""" limiter = RateLimiter() - limiter._consecutive_rate_limits = 5 limiter.record_success() + assert limiter._successful_requests == 1 - assert limiter._consecutive_rate_limits == 0 + limiter.record_success() + assert limiter._successful_requests == 2 def test_thread_safe(self): """Should be thread-safe.""" limiter = RateLimiter() - limiter._consecutive_rate_limits = 10 # Call from multiple threads threads = [threading.Thread(target=limiter.record_success) for _ in range(5)] @@ -196,64 +187,7 @@ def test_thread_safe(self): for t in threads: t.join() - assert limiter._consecutive_rate_limits == 0 - - -class TestRateLimiterRecordRateLimit: - """Tests for RateLimiter.record_rate_limit method.""" - - def test_increments_consecutive_count(self): - """Should increment consecutive rate limit counter.""" - limiter = RateLimiter() - - limiter.record_rate_limit() - assert limiter._consecutive_rate_limits == 1 - - limiter.record_rate_limit() - assert limiter._consecutive_rate_limits == 2 - - def test_returns_exponential_backoff(self): - """Should return exponentially increasing backoff times.""" - config = RateLimitConfig(backoff_base=2.0, backoff_jitter=0.0) - limiter = RateLimiter(config=config) - - delay1 = limiter.record_rate_limit() # 2^1 = 2 - delay2 = 
limiter.record_rate_limit() # 2^2 = 4 - delay3 = limiter.record_rate_limit() # 2^3 = 8 - - assert delay1 == 2.0 - assert delay2 == 4.0 - assert delay3 == 8.0 - - def test_respects_max_backoff(self): - """Should cap backoff at configured maximum.""" - config = RateLimitConfig(backoff_base=2.0, backoff_max=10.0, backoff_jitter=0.0) - limiter = RateLimiter(config=config) - - # Force high consecutive count - limiter._consecutive_rate_limits = 10 # Would be 2^11 = 2048 - - delay = limiter.record_rate_limit() - - assert delay == 10.0 # Capped at max - - def test_adds_jitter(self): - """Should add jitter when configured.""" - config = RateLimitConfig(backoff_base=2.0, backoff_jitter=0.5) - - # Collect delays with same consecutive count to test jitter variation - delays = [] - for _ in range(10): - limiter = RateLimiter(config=config) # Fresh limiter each time - delay = limiter.record_rate_limit() # First rate limit = 2^1 = 2 - delays.append(delay) - - # With 50% jitter, delay should be 2 + random(0, 1.0) = 2.0 to 3.0 - for delay in delays: - assert 2.0 <= delay <= 3.0 - - # Should have some variation (not all exactly the same) - assert len(set(delays)) > 1, "Jitter should cause variation in delays" + assert limiter._successful_requests == 5 class TestRateLimiterGetStats: @@ -267,13 +201,13 @@ def test_returns_current_state(self): # Make some requests limiter.wait_if_needed() limiter.wait_if_needed() - limiter._consecutive_rate_limits = 3 + limiter.record_success() stats = limiter.get_stats() assert stats["requests_last_minute"] == 2 assert stats["rpm_limit"] == 100 - assert stats["consecutive_rate_limits"] == 3 + assert stats["successful_requests"] == 1 assert stats["min_request_delay"] == 0.5 def test_excludes_old_requests(self): @@ -315,23 +249,15 @@ def make_request(): assert len(limiter._request_times) == 10 assert len(results) == 10 - def test_concurrent_rate_limit_recording(self): - """Should handle concurrent rate limit recordings safely.""" - config = RateLimitConfig(backoff_jitter=0.0) - limiter = RateLimiter(config=config) - - delays = [] - - def record_limit(): - delay = limiter.record_rate_limit() - delays.append(delay) + def test_concurrent_success_recording(self): + """Should handle concurrent success recordings safely.""" + limiter = RateLimiter() - threads = [threading.Thread(target=record_limit) for _ in range(5)] + threads = [threading.Thread(target=limiter.record_success) for _ in range(5)] for t in threads: t.start() for t in threads: t.join() # Should have incremented 5 times - assert limiter._consecutive_rate_limits == 5 - assert len(delays) == 5 + assert limiter._successful_requests == 5 diff --git a/tests/test_adapters/llm/test_retry.py b/tests/test_adapters/llm/test_retry.py new file mode 100644 index 0000000..71a9318 --- /dev/null +++ b/tests/test_adapters/llm/test_retry.py @@ -0,0 +1,276 @@ +"""Tests for adapters.llm.retry module.""" + +from __future__ import annotations + +from typing import Any, cast + +import pytest +from backoff._typing import Details + +from deriva.adapters.llm.retry import ( + RETRIABLE_EXCEPTIONS, + create_retry_decorator, + on_backoff, + on_giveup, + retry_on_rate_limit, +) + + +class TestRetriableExceptions: + """Tests for RETRIABLE_EXCEPTIONS constant.""" + + def test_contains_connection_error(self): + """Should include ConnectionError.""" + assert ConnectionError in RETRIABLE_EXCEPTIONS + + def test_contains_timeout_error(self): + """Should include TimeoutError.""" + assert TimeoutError in RETRIABLE_EXCEPTIONS + + def 
test_contains_os_error(self): + """Should include OSError for network errors.""" + assert OSError in RETRIABLE_EXCEPTIONS + + +class TestOnBackoff: + """Tests for on_backoff callback function.""" + + def test_logs_backoff_event(self, caplog): + """Should log warning with backoff details.""" + import logging + + caplog.set_level(logging.WARNING) + + def dummy_func(): + pass + + # Cast to Any because backoff adds 'exception' at runtime but it's not in the TypedDict + details = cast( + Details, + { + "target": dummy_func, + "args": (), + "kwargs": {}, + "tries": 2, + "elapsed": 0.0, + "wait": 2.5, + "exception": ConnectionError("Connection refused"), + }, + ) + + on_backoff(details) + + assert "Retry 2" in caplog.text + assert "dummy_func" in caplog.text + assert "2.50s" in caplog.text + + def test_handles_missing_details(self, caplog): + """Should handle missing details gracefully.""" + import logging + + caplog.set_level(logging.WARNING) + + on_backoff(cast(Any, {})) + + # Should still log without crashing + assert "Retry 0" in caplog.text + + +class TestOnGiveup: + """Tests for on_giveup callback function.""" + + def test_logs_giveup_event(self, caplog): + """Should log error when retries exhausted.""" + import logging + + caplog.set_level(logging.ERROR) + + def dummy_func(): + pass + + # Cast to Details because backoff adds 'exception' at runtime but it's not in the TypedDict + details = cast( + Details, + { + "target": dummy_func, + "args": (), + "kwargs": {}, + "tries": 5, + "elapsed": 0.0, + "exception": TimeoutError("Request timed out"), + }, + ) + + on_giveup(details) + + assert "Giving up" in caplog.text + assert "dummy_func" in caplog.text + assert "5 attempts" in caplog.text + + def test_handles_missing_details(self, caplog): + """Should handle missing details gracefully.""" + import logging + + caplog.set_level(logging.ERROR) + + on_giveup(cast(Any, {})) + + # Should still log without crashing + assert "Giving up" in caplog.text + + +class TestCreateRetryDecorator: + """Tests for create_retry_decorator function.""" + + def test_returns_decorator(self): + """Should return a callable decorator.""" + decorator = create_retry_decorator() + assert callable(decorator) + + def test_decorated_function_works_on_success(self): + """Should allow successful function calls.""" + decorator = create_retry_decorator(max_retries=2) + + @decorator + def successful_func(): + return "success" + + result = successful_func() + assert result == "success" + + def test_retries_on_retriable_exception(self): + """Should retry on retriable exceptions.""" + call_count = 0 + + decorator = create_retry_decorator(max_retries=3, base_delay=0.01) + + @decorator + def flaky_func(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ConnectionError("Connection failed") + return "success" + + result = flaky_func() + assert result == "success" + assert call_count == 3 + + def test_gives_up_after_max_retries(self): + """Should give up after max retries exceeded.""" + decorator = create_retry_decorator(max_retries=2, base_delay=0.01) + + @decorator + def always_fails(): + raise TimeoutError("Always times out") + + with pytest.raises(TimeoutError): + always_fails() + + def test_does_not_retry_non_retriable_exceptions(self): + """Should not retry on non-retriable exceptions.""" + call_count = 0 + decorator = create_retry_decorator(max_retries=3) + + @decorator + def raises_value_error(): + nonlocal call_count + call_count += 1 + raise ValueError("Not retriable") + + with pytest.raises(ValueError): + 
raises_value_error() + + assert call_count == 1 # No retries + + def test_custom_exceptions(self): + """Should retry on custom exception types.""" + + class CustomError(Exception): + pass + + call_count = 0 + decorator = create_retry_decorator(max_retries=2, base_delay=0.01, exceptions=(CustomError,)) + + @decorator + def raises_custom(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise CustomError("Custom error") + return "success" + + result = raises_custom() + assert result == "success" + assert call_count == 2 + + +class TestRetryOnRateLimit: + """Tests for retry_on_rate_limit decorator.""" + + def test_returns_decorator(self): + """Should return a callable decorator.""" + decorator = retry_on_rate_limit() + assert callable(decorator) + + def test_decorated_function_works_on_success(self): + """Should allow successful function calls.""" + decorator = retry_on_rate_limit(max_retries=2) + + @decorator + def successful_func(): + return "success" + + result = successful_func() + assert result == "success" + + def test_retries_on_connection_error(self): + """Should retry on ConnectionError.""" + call_count = 0 + decorator = retry_on_rate_limit(max_retries=3, base_delay=0.01) + + @decorator + def flaky_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise ConnectionError("Rate limited") + return "success" + + result = flaky_func() + assert result == "success" + assert call_count == 2 + + def test_retries_on_timeout_error(self): + """Should retry on TimeoutError.""" + call_count = 0 + decorator = retry_on_rate_limit(max_retries=3, base_delay=0.01) + + @decorator + def flaky_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise TimeoutError("Timed out") + return "success" + + result = flaky_func() + assert result == "success" + assert call_count == 2 + + def test_handles_httpx_import_error(self): + """Should work even if httpx not installed.""" + # Just verify it creates a decorator without crashing + decorator = retry_on_rate_limit(max_retries=2) + assert callable(decorator) + + def test_gives_up_after_max_retries(self): + """Should give up after max retries.""" + decorator = retry_on_rate_limit(max_retries=2, base_delay=0.01) + + @decorator + def always_fails(): + raise ConnectionError("Always fails") + + with pytest.raises(ConnectionError): + always_fails() diff --git a/tests/test_adapters/treesitter/test_csharp.py b/tests/test_adapters/treesitter/test_csharp.py deleted file mode 100644 index 2419576..0000000 --- a/tests/test_adapters/treesitter/test_csharp.py +++ /dev/null @@ -1,388 +0,0 @@ -"""Tests for C# language extractor.""" - -from __future__ import annotations - -from deriva.adapters.treesitter import TreeSitterManager - - -class TestCSharpTypes: - """Tests for C# type extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_class(self): - """Should extract C# class.""" - source = """ -public class UserService -{ - private string _name; -} -""" - types = self.manager.extract_types(source, language="csharp") - - assert len(types) >= 1 - service = next((t for t in types if t.name == "UserService"), None) - assert service is not None - assert service.kind == "class" - - def test_extracts_class_with_inheritance(self): - """Should extract class with base class.""" - source = """ -public class Admin : User -{ - public int Level { get; set; } -} -""" - types = self.manager.extract_types(source, language="csharp") - - admin = next((t for t in types if 
t.name == "Admin"), None) - assert admin is not None - - def test_extracts_class_with_interface(self): - """Should extract class implementing interface.""" - source = """ -public class UserRepository : IRepository, IDisposable -{ - public void Dispose() { } -} -""" - types = self.manager.extract_types(source, language="csharp") - - repo = next((t for t in types if t.name == "UserRepository"), None) - assert repo is not None - - def test_extracts_interface(self): - """Should extract C# interface.""" - source = """ -public interface IRepository -{ - T GetById(int id); - void Save(T entity); -} -""" - types = self.manager.extract_types(source, language="csharp") - - repo = next((t for t in types if t.name == "IRepository"), None) - assert repo is not None - assert repo.kind == "interface" - - def test_extracts_struct(self): - """Should extract C# struct.""" - source = """ -public struct Point -{ - public int X; - public int Y; -} -""" - types = self.manager.extract_types(source, language="csharp") - - point = next((t for t in types if t.name == "Point"), None) - assert point is not None - - def test_extracts_enum(self): - """Should extract C# enum.""" - source = """ -public enum Status -{ - Pending, - Active, - Completed -} -""" - types = self.manager.extract_types(source, language="csharp") - - status = next((t for t in types if t.name == "Status"), None) - assert status is not None - - def test_extracts_abstract_class(self): - """Should extract abstract class.""" - source = """ -public abstract class BaseEntity -{ - protected int Id { get; set; } - public abstract void Validate(); -} -""" - types = self.manager.extract_types(source, language="csharp") - - entity = next((t for t in types if t.name == "BaseEntity"), None) - assert entity is not None - - def test_extracts_record(self): - """Should extract C# record.""" - source = """ -public record User(string Name, int Age); -""" - types = self.manager.extract_types(source, language="csharp") - user = next((t for t in types if t.name == "User"), None) - - assert user is not None - - -class TestCSharpMethods: - """Tests for C# method extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_public_method(self): - """Should extract public method.""" - source = """ -public class Calculator -{ - public int Add(int a, int b) - { - return a + b; - } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - - add = next((m for m in methods if m.name == "Add"), None) - assert add is not None - - def test_extracts_private_method(self): - """Should extract private method.""" - source = """ -public class Service -{ - private void Helper() - { - // internal logic - } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - - helper = next((m for m in methods if m.name == "Helper"), None) - assert helper is not None - - def test_extracts_static_method(self): - """Should extract static method.""" - source = """ -public class Utils -{ - public static string Format(string input) - { - return input.Trim(); - } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - - format_method = next((m for m in methods if m.name == "Format"), None) - assert format_method is not None - - def test_extracts_async_method(self): - """Should extract async method.""" - source = """ -public class ApiClient -{ - public async Task GetUserAsync(int id) - { - return await _repository.FindAsync(id); - } -} -""" - methods = 
self.manager.extract_methods(source, language="csharp") - - get_user = next((m for m in methods if m.name == "GetUserAsync"), None) - assert get_user is not None - - def test_extracts_generic_method(self): - """Should handle methods with generics.""" - source = """ -public class Repository -{ - public T FindById(int id) where T : class - { - return null; - } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - - find_by_id = next((m for m in methods if m.name == "FindById"), None) - assert find_by_id is not None - - def test_extracts_constructor(self): - """Should extract constructor.""" - source = """ -public class Person -{ - public Person(string name) - { - Name = name; - } - - public string Name { get; } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - constructor = next((m for m in methods if m.name == "Person"), None) - - assert constructor is not None - assert constructor.class_name == "Person" - - def test_extracts_class_with_property_accessors(self): - """Should handle class with property accessors without crashing.""" - source = """ -public class User -{ - private string _name; - - public string Name - { - get { return _name; } - set { _name = value; } - } - - public void SetName(string name) { _name = name; } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - set_name = next((m for m in methods if m.name == "SetName"), None) - - # Regular method should be extracted even with properties in class - assert set_name is not None - - -class TestCSharpImports: - """Tests for C# using statement extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_using(self): - """Should extract using statement.""" - source = """ -using System; -""" - imports = self.manager.extract_imports(source, language="csharp") - - assert len(imports) >= 1 - - def test_extracts_multiple_usings(self): - """Should extract multiple using statements.""" - source = """ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -""" - imports = self.manager.extract_imports(source, language="csharp") - - assert len(imports) >= 1 - - def test_extracts_using_with_alias(self): - """Should extract using with alias.""" - source = """ -using Console = System.Console; -using Dict = System.Collections.Generic.Dictionary; -""" - imports = self.manager.extract_imports(source, language="csharp") - - assert len(imports) >= 1 - - def test_extracts_global_using(self): - """Should handle global using.""" - source = """ -global using System; -global using System.Collections.Generic; -""" - imports = self.manager.extract_imports(source, language="csharp") - - assert len(imports) >= 2 - - -class TestCSharpEdgeCases: - """Tests for edge cases in C# extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_handles_empty_file(self): - """Should handle empty file.""" - types = self.manager.extract_types("", language="csharp") - assert types == [] - - def test_handles_namespace(self): - """Should handle namespace declaration.""" - source = """ -namespace MyApp.Services -{ - public class UserService - { - } -} -""" - types = self.manager.extract_types(source, language="csharp") - - service = next((t for t in types if t.name == "UserService"), None) - assert service is not None - - def test_handles_file_scoped_namespace(self): - """Should handle file-scoped namespace.""" - source = 
""" -namespace MyApp.Services; - -public class UserService -{ -} -""" - types = self.manager.extract_types(source, language="csharp") - service = next((t for t in types if t.name == "UserService"), None) - - assert service is not None - - def test_handles_attributes(self): - """Should handle classes with attributes.""" - source = """ -[Serializable] -[Table("users")] -public class User -{ - [Key] - public int Id { get; set; } -} -""" - types = self.manager.extract_types(source, language="csharp") - - user = next((t for t in types if t.name == "User"), None) - assert user is not None - - def test_handles_nullable_types(self): - """Should handle nullable reference types.""" - source = """ -public class Service -{ - public string? GetValue(int? id) - { - return null; - } -} -""" - methods = self.manager.extract_methods(source, language="csharp") - get_value = next((m for m in methods if m.name == "GetValue"), None) - - assert get_value is not None - assert get_value.class_name == "Service" - - def test_handles_partial_class(self): - """Should handle partial class.""" - source = """ -public partial class User -{ - public string Name { get; set; } -} -""" - types = self.manager.extract_types(source, language="csharp") - - user = next((t for t in types if t.name == "User"), None) - assert user is not None diff --git a/tests/test_adapters/treesitter/test_java.py b/tests/test_adapters/treesitter/test_java.py deleted file mode 100644 index 87908ff..0000000 --- a/tests/test_adapters/treesitter/test_java.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Tests for Java language extractor.""" - -from __future__ import annotations - -from deriva.adapters.treesitter import TreeSitterManager - - -class TestJavaTypes: - """Tests for Java type extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_class(self): - """Should extract Java class.""" - source = """ -public class UserService { - private String name; -} -""" - types = self.manager.extract_types(source, language="java") - - assert len(types) >= 1 - service = next((t for t in types if t.name == "UserService"), None) - assert service is not None - assert service.kind == "class" - - def test_extracts_class_with_extends(self): - """Should extract class inheritance.""" - source = """ -public class Admin extends User { - private int level; -} -""" - types = self.manager.extract_types(source, language="java") - - admin = next((t for t in types if t.name == "Admin"), None) - assert admin is not None - - def test_extracts_class_with_implements(self): - """Should extract interface implementation.""" - source = """ -public class UserServiceImpl implements UserService, Serializable { - public void save(User user) {} -} -""" - types = self.manager.extract_types(source, language="java") - - impl = next((t for t in types if t.name == "UserServiceImpl"), None) - assert impl is not None - - def test_extracts_interface(self): - """Should extract Java interface.""" - source = """ -public interface Repository { - T findById(Long id); - void save(T entity); -} -""" - types = self.manager.extract_types(source, language="java") - - repo = next((t for t in types if t.name == "Repository"), None) - assert repo is not None - assert repo.kind == "interface" - - def test_extracts_enum(self): - """Should extract Java enum.""" - source = """ -public enum Status { - PENDING, - ACTIVE, - COMPLETED -} -""" - types = self.manager.extract_types(source, language="java") - - status = next((t for t in types if t.name == 
"Status"), None) - assert status is not None - - def test_extracts_abstract_class(self): - """Should extract abstract class.""" - source = """ -public abstract class BaseEntity { - protected Long id; - - public abstract void validate(); -} -""" - types = self.manager.extract_types(source, language="java") - - entity = next((t for t in types if t.name == "BaseEntity"), None) - assert entity is not None - - def test_extracts_inner_class(self): - """Should handle inner classes.""" - source = """ -public class Outer { - public class Inner { - private int value; - } -} -""" - types = self.manager.extract_types(source, language="java") - - # Should extract at least outer class - assert len(types) >= 1 - - -class TestJavaMethods: - """Tests for Java method extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_public_method(self): - """Should extract public method.""" - source = """ -public class Calculator { - public int add(int a, int b) { - return a + b; - } -} -""" - methods = self.manager.extract_methods(source, language="java") - - add = next((m for m in methods if m.name == "add"), None) - assert add is not None - - def test_extracts_private_method(self): - """Should extract private method.""" - source = """ -public class Service { - private void helper() { - // internal logic - } -} -""" - methods = self.manager.extract_methods(source, language="java") - - helper = next((m for m in methods if m.name == "helper"), None) - assert helper is not None - - def test_extracts_static_method(self): - """Should extract static method.""" - source = """ -public class Utils { - public static String format(String input) { - return input.trim(); - } -} -""" - methods = self.manager.extract_methods(source, language="java") - - format_method = next((m for m in methods if m.name == "format"), None) - assert format_method is not None - - def test_extracts_method_with_generics(self): - """Should handle methods with generics.""" - source = """ -public class Repository { - public List findAll(Class type) { - return new ArrayList<>(); - } -} -""" - methods = self.manager.extract_methods(source, language="java") - - find_all = next((m for m in methods if m.name == "findAll"), None) - assert find_all is not None - - def test_extracts_constructor(self): - """Should extract constructor.""" - source = """ -public class Person { - private String name; - - public Person(String name) { - this.name = name; - } -} -""" - methods = self.manager.extract_methods(source, language="java") - constructor = next((m for m in methods if m.name == "Person"), None) - - assert constructor is not None - assert constructor.class_name == "Person" - - def test_extracts_overloaded_methods(self): - """Should extract overloaded methods.""" - source = """ -public class Printer { - public void print(String s) {} - public void print(int i) {} - public void print(String s, int count) {} -} -""" - methods = self.manager.extract_methods(source, language="java") - - print_methods = [m for m in methods if m.name == "print"] - assert len(print_methods) == 3 - - -class TestJavaImports: - """Tests for Java import extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_single_import(self): - """Should extract single class import.""" - source = """ -import java.util.List; -""" - imports = self.manager.extract_imports(source, language="java") - - assert len(imports) == 1 - - def 
test_extracts_wildcard_import(self): - """Should extract wildcard import.""" - source = """ -import java.util.*; -""" - imports = self.manager.extract_imports(source, language="java") - - assert len(imports) == 1 - - def test_extracts_static_import(self): - """Should extract static import.""" - source = """ -import static java.lang.Math.PI; -import static java.lang.Math.*; -""" - imports = self.manager.extract_imports(source, language="java") - - assert len(imports) >= 1 - - def test_extracts_multiple_imports(self): - """Should extract multiple imports.""" - source = """ -import java.util.List; -import java.util.ArrayList; -import java.util.Map; -import java.io.IOException; -""" - imports = self.manager.extract_imports(source, language="java") - - assert len(imports) == 4 - - -class TestJavaEdgeCases: - """Tests for edge cases in Java extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_handles_empty_file(self): - """Should handle empty file.""" - types = self.manager.extract_types("", language="java") - assert types == [] - - def test_handles_annotations(self): - """Should handle annotated classes.""" - source = """ -@Entity -@Table(name = "users") -public class User { - @Id - private Long id; -} -""" - types = self.manager.extract_types(source, language="java") - - user = next((t for t in types if t.name == "User"), None) - assert user is not None - - def test_handles_package_declaration(self): - """Should handle package declaration.""" - source = """ -package com.example.service; - -public class MyService { -} -""" - types = self.manager.extract_types(source, language="java") - - assert len(types) >= 1 - - def test_handles_record(self): - """Should handle Java records.""" - source = """ -public record User(String name, int age) { -} -""" - types = self.manager.extract_types(source, language="java") - user = next((t for t in types if t.name == "User"), None) - - assert user is not None diff --git a/tests/test_adapters/treesitter/test_javascript.py b/tests/test_adapters/treesitter/test_javascript.py deleted file mode 100644 index 3269185..0000000 --- a/tests/test_adapters/treesitter/test_javascript.py +++ /dev/null @@ -1,381 +0,0 @@ -"""Tests for JavaScript language extractor.""" - -from __future__ import annotations - -from deriva.adapters.treesitter import TreeSitterManager - - -class TestJavaScriptTypes: - """Tests for JavaScript type extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_class(self): - """Should extract ES6 class definition.""" - source = """ -class UserService { - constructor() { - this.users = []; - } -} -""" - types = self.manager.extract_types(source, language="javascript") - - assert len(types) >= 1 - service = next((t for t in types if t.name == "UserService"), None) - assert service is not None - assert service.kind == "class" - - def test_extracts_class_with_extends(self): - """Should extract class inheritance.""" - source = """ -class Admin extends User { - constructor() { - super(); - } -} -""" - types = self.manager.extract_types(source, language="javascript") - - admin = next((t for t in types if t.name == "Admin"), None) - assert admin is not None - assert admin.kind == "class" - - def test_extracts_function_declaration(self): - """Should extract function declaration.""" - source = """ -function processData(data) { - return data.map(x => x * 2); -} -""" - types = self.manager.extract_types(source, 
language="javascript") - - func = next((t for t in types if t.name == "processData"), None) - assert func is not None - assert func.kind == "function" - - def test_extracts_async_function(self): - """Should detect async function declarations.""" - source = """ -async function fetchUser(id) { - const response = await fetch(`/api/users/${id}`); - return response.json(); -} -""" - types = self.manager.extract_types(source, language="javascript") - - func = next((t for t in types if t.name == "fetchUser"), None) - assert func is not None - assert func.is_async is True - - def test_extracts_exported_function(self): - """Should extract exported functions.""" - source = """ -export function helper() { - return 42; -} - -export async function asyncHelper() { - return await Promise.resolve(42); -} -""" - types = self.manager.extract_types(source, language="javascript") - - names = {t.name for t in types} - assert "helper" in names - assert "asyncHelper" in names - - def test_extracts_exported_class(self): - """Should extract exported class.""" - source = """ -export class Service { - run() { - return 'running'; - } -} -""" - types = self.manager.extract_types(source, language="javascript") - - service = next((t for t in types if t.name == "Service"), None) - assert service is not None - - def test_extracts_arrow_function_variable(self): - """Should extract const arrow functions as types.""" - source = """ -const add = (a, b) => a + b; - -const multiply = (x, y) => { - return x * y; -}; -""" - types = self.manager.extract_types(source, language="javascript") - - # Arrow functions assigned to const should be extracted - multiply = next((t for t in types if t.name == "multiply"), None) - assert multiply is not None - assert multiply.kind == "function" - - def test_extracts_async_arrow_function(self): - """Should extract async arrow functions.""" - source = """ -const fetchData = async (url) => { - const res = await fetch(url); - return res.json(); -}; -""" - types = self.manager.extract_types(source, language="javascript") - fetch_func = next((t for t in types if t.name == "fetchData"), None) - - assert fetch_func is not None - assert fetch_func.is_async is True - - -class TestJavaScriptMethods: - """Tests for JavaScript method extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_class_methods(self): - """Should extract methods from class.""" - source = """ -class Calculator { - add(a, b) { - return a + b; - } - - subtract(a, b) { - return a - b; - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - names = {m.name for m in methods} - assert "add" in names - assert "subtract" in names - - def test_extracts_constructor(self): - """Should extract constructor method.""" - source = """ -class Service { - constructor(config) { - this.config = config; - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - constructor = next((m for m in methods if m.name == "constructor"), None) - assert constructor is not None - - def test_extracts_async_method(self): - """Should detect async class methods.""" - source = """ -class ApiClient { - async get(url) { - return await fetch(url); - } - - async post(url, data) { - return await fetch(url, { method: 'POST', body: data }); - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - async_methods = [m for m in methods if m.is_async] - assert len(async_methods) >= 1 - - def 
test_extracts_static_method(self): - """Should extract static methods.""" - source = """ -class MathUtils { - static square(x) { - return x * x; - } - - static cube(x) { - return x * x * x; - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - names = {m.name for m in methods} - assert "square" in names or "cube" in names - - def test_extracts_getter_setter(self): - """Should extract getters and setters.""" - source = """ -class Person { - get name() { - return this._name; - } - - set name(value) { - this._name = value; - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - names = {m.name for m in methods} - assert "name" in names - - def test_extracts_top_level_function(self): - """Should extract top-level functions.""" - source = """ -function helper(x) { - return x * 2; -} -""" - methods = self.manager.extract_methods(source, language="javascript") - - helper = next((m for m in methods if m.name == "helper"), None) - assert helper is not None - assert helper.class_name is None - - def test_extracts_arrow_function_method(self): - """Should extract arrow function class fields.""" - source = """ -class Handler { - handleClick = () => { - console.log('clicked'); - }; - - handleSubmit = async (data) => { - await this.submit(data); - }; -} -""" - methods = self.manager.extract_methods(source, language="javascript") - names = {m.name for m in methods} - - # At minimum, one of the handlers should be extracted - assert "handleClick" in names or "handleSubmit" in names - - -class TestJavaScriptImports: - """Tests for JavaScript import extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_extracts_default_import(self): - """Should extract default import.""" - source = """ -import React from 'react'; -""" - imports = self.manager.extract_imports(source, language="javascript") - - assert len(imports) >= 1 - react_import = next((i for i in imports if "react" in i.module), None) - assert react_import is not None - - def test_extracts_named_imports(self): - """Should extract named imports.""" - source = """ -import { useState, useEffect, useCallback } from 'react'; -""" - imports = self.manager.extract_imports(source, language="javascript") - - assert len(imports) >= 1 - - def test_extracts_namespace_import(self): - """Should extract namespace imports.""" - source = """ -import * as utils from './utils'; -""" - imports = self.manager.extract_imports(source, language="javascript") - - assert len(imports) >= 1 - - def test_extracts_side_effect_import(self): - """Should extract side-effect imports.""" - source = """ -import './styles.css'; -import 'polyfill'; -""" - imports = self.manager.extract_imports(source, language="javascript") - - assert len(imports) >= 1 - - def test_extracts_mixed_imports(self): - """Should handle mixed import styles.""" - source = """ -import React, { Component, useState } from 'react'; -""" - imports = self.manager.extract_imports(source, language="javascript") - - assert len(imports) >= 1 - - def test_extracts_require_statement(self): - """Should extract CommonJS require statements.""" - source = """ -const express = require('express'); -const { Router } = require('express'); -""" - imports = self.manager.extract_imports(source, language="javascript") - - # CommonJS require should be captured as imports - modules = [i.module for i in imports] - assert any("express" in m for m in modules) - - -class TestJavaScriptEdgeCases: - """Tests for 
edge cases in JavaScript extraction.""" - - def setup_method(self): - """Set up test fixtures.""" - self.manager = TreeSitterManager() - - def test_handles_empty_file(self): - """Should handle empty file.""" - types = self.manager.extract_types("", language="javascript") - assert types == [] - - def test_handles_jsx_syntax(self): - """Should handle JSX syntax.""" - source = """ -function Component() { - return
<div>Hello</div>
; -} -""" - types = self.manager.extract_types(source, language="javascript") - component = next((t for t in types if t.name == "Component"), None) - - assert component is not None - assert component.kind == "function" - - def test_handles_template_literals(self): - """Should handle template literals in code.""" - source = """ -function greet(name) { - return `Hello, ${name}!`; -} -""" - types = self.manager.extract_types(source, language="javascript") - - assert len(types) >= 1 - - def test_handles_private_fields(self): - """Should handle private class fields.""" - source = """ -class Counter { - #count = 0; - - increment() { - this.#count++; - } -} -""" - methods = self.manager.extract_methods(source, language="javascript") - increment = next((m for m in methods if m.name == "increment"), None) - - assert increment is not None - assert increment.class_name == "Counter" diff --git a/tests/test_adapters/treesitter/test_languages.py b/tests/test_adapters/treesitter/test_languages.py new file mode 100644 index 0000000..c22b53f --- /dev/null +++ b/tests/test_adapters/treesitter/test_languages.py @@ -0,0 +1,543 @@ +"""Consolidated tests for all language extractors (Java, JavaScript, C#, Python). + +This file replaces the individual test_java.py, test_javascript.py, test_csharp.py, +and test_python.py files with parameterized tests that reduce duplication. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from deriva.adapters.treesitter import TreeSitterManager + +# ============================================================================= +# FIXTURES +# ============================================================================= + + +@pytest.fixture +def manager(): + """Provide a TreeSitterManager instance for all tests.""" + return TreeSitterManager() + + +# ============================================================================= +# LANGUAGE-SPECIFIC CODE SAMPLES +# ============================================================================= + +SAMPLES: dict[str, dict[str, Any]] = { + "java": { + "class": ("public class UserService { private String name; }", "UserService", "class"), + "class_inheritance": ("public class Admin extends User { private int level; }", "Admin"), + "interface": ("public interface Repository { T findById(Long id); }", "Repository", "interface"), + "enum": ("public enum Status { PENDING, ACTIVE, COMPLETED }", "Status"), + "abstract_class": ("public abstract class BaseEntity { public abstract void validate(); }", "BaseEntity"), + "inner_class": ("public class Outer { public class Inner { private int value; } }", "Outer"), + "method_public": ("public class Calc { public int add(int a, int b) { return a + b; } }", "add"), + "method_private": ("public class Svc { private void helper() { } }", "helper"), + "method_static": ("public class Utils { public static String format(String s) { return s; } }", "format"), + "constructor": ("public class Person { public Person(String name) { } }", "Person", "Person"), + "overloaded": ("public class P { public void print(String s) {} public void print(int i) {} }", "print", 2), + "import_single": ("import java.util.List;", 1), + "import_wildcard": ("import java.util.*;", 1), + "import_multiple": ("import java.util.List;\nimport java.util.Map;\nimport java.io.IOException;", 3), + "annotated_class": ('@Entity\n@Table(name = "users")\npublic class User { }', "User"), + "record": ("public record User(String name, int age) { }", "User"), + }, + "javascript": { + "class": ("class UserService { 
constructor() { this.users = []; } }", "UserService", "class"), + "class_inheritance": ("class Admin extends User { constructor() { super(); } }", "Admin"), + "function": ("function processData(data) { return data; }", "processData", "function"), + "async_function": ("async function fetchUser(id) { return await fetch(id); }", "fetchUser", True), + "arrow_function": ("const multiply = (x, y) => { return x * y; };", "multiply", "function"), + "async_arrow": ("const fetchData = async (url) => { return await fetch(url); };", "fetchData", True), + "exported_class": ("export class Service { run() { return true; } }", "Service"), + "method_class": ("class Calc { add(a, b) { return a + b; } subtract(a, b) { return a - b; } }", ["add", "subtract"]), + "constructor": ("class Service { constructor(config) { this.config = config; } }", "constructor"), + "async_method": ("class Api { async get(url) { return await fetch(url); } }", "get", True), + "static_method": ("class MathUtils { static square(x) { return x * x; } }", "square"), + "import_default": ("import React from 'react';", "react"), + "import_named": ("import { useState, useEffect } from 'react';", 1), + "import_namespace": ("import * as utils from './utils';", 1), + "import_require": ("const express = require('express');", "express"), + "jsx_function": ("function Component() { return
<div>Hello</div>
; }", "Component", "function"), + "private_fields": ("class Counter { #count = 0; increment() { this.#count++; } }", "increment", "Counter"), + }, + "csharp": { + "class": ("public class UserService { private string _name; }", "UserService", "class"), + "class_inheritance": ("public class Admin : User { public int Level { get; set; } }", "Admin"), + "interface": ("public interface IRepository { T GetById(int id); }", "IRepository", "interface"), + "struct": ("public struct Point { public int X; public int Y; }", "Point"), + "enum": ("public enum Status { Pending, Active, Completed }", "Status"), + "abstract_class": ("public abstract class BaseEntity { public abstract void Validate(); }", "BaseEntity"), + "record": ("public record User(string Name, int Age);", "User"), + "method_public": ("public class Calc { public int Add(int a, int b) { return a + b; } }", "Add"), + "method_private": ("public class Svc { private void Helper() { } }", "Helper"), + "method_static": ("public class Utils { public static string Format(string s) { return s; } }", "Format"), + "method_async": ("public class Api { public async Task GetUserAsync(int id) { return null; } }", "GetUserAsync"), + "constructor": ("public class Person { public Person(string name) { } public string Name { get; } }", "Person", "Person"), + "import_using": ("using System;", 1), + "import_multiple": ("using System;\nusing System.Collections.Generic;\nusing System.Linq;", 3), + "import_alias": ("using Console = System.Console;", 1), + "namespace_class": ("namespace MyApp.Services { public class UserService { } }", "UserService"), + "file_scoped_ns": ("namespace MyApp.Services;\npublic class UserService { }", "UserService"), + "attributes": ('[Serializable]\n[Table("users")]\npublic class User { }', "User"), + "partial_class": ("public partial class User { public string Name { get; set; } }", "User"), + }, + "python": { + "class": ("class User: pass", "User", "class"), + "class_inheritance": ("class Admin(User, PermissionMixin): pass", "Admin", {"User", "PermissionMixin"}), + "class_docstring": ('class Service:\n """A service class."""\n pass', "Service", "service"), + "decorated_class": ("@dataclass\nclass Config:\n name: str", "Config"), + "function": ("def process_data(items): return items", "process_data", "function"), + "async_function": ("async def fetch_user(user_id): pass", "fetch_user", True), + "type_alias": ("type UserId = int\ntype Callback = Callable[[int], str]", 2), + "multiple_classes": ("class First: pass\nclass Second: pass\nclass Third: pass", {"First", "Second", "Third"}), + "method_class": ("class Calc:\n def add(self, a, b): return a + b\n def sub(self, a, b): return a - b", ["add", "sub"], "Calc"), + "standalone_fn": ("def helper_function(x): return x * 2", "helper_function", None), + "async_method": ("class Svc:\n async def fetch(self, url): return url", "fetch", True), + "staticmethod": ("class Utils:\n @staticmethod\n def format_date(date): return date", "format_date", "static"), + "classmethod": ("class Factory:\n @classmethod\n def create(cls): return cls()", "create", "classmethod"), + "property": ("class Person:\n @property\n def full_name(self): return self.name", "full_name", "property"), + "import_simple": ("import os", "os", False), + "import_alias": ("import numpy as np", "numpy", "np"), + "import_from": ("from pathlib import Path", "pathlib", ["Path"]), + "import_multiple": ("import os\nimport sys\nimport json\nfrom pathlib import Path\nfrom typing import Any", 5), + "nested_class": ("class Outer:\n class 
Inner:\n def method(self): pass", "Outer"), + "unicode": ("class Données:\n def traiter(self): pass", "Données"), + }, +} + + +# ============================================================================= +# TYPE EXTRACTION TESTS +# ============================================================================= + + +class TestTypeExtraction: + """Tests for type extraction across all languages.""" + + @pytest.mark.parametrize("language", ["java", "javascript", "csharp", "python"]) + def test_extracts_class(self, manager, language): + """Should extract basic class definition.""" + source, name, kind = SAMPLES[language]["class"] + types = manager.extract_types(source, language=language) + + assert len(types) >= 1 + target = next((t for t in types if t.name == name), None) + assert target is not None, f"Expected to find {name} in {[t.name for t in types]}" + assert target.kind == kind + + @pytest.mark.parametrize("language", ["java", "javascript", "csharp", "python"]) + def test_extracts_class_with_inheritance(self, manager, language): + """Should extract class with inheritance/extends.""" + sample = SAMPLES[language].get("class_inheritance") + if not sample: + pytest.skip(f"No inheritance sample for {language}") + assert sample is not None # For type checker after pytest.skip + + source, name = sample[0], sample[1] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None, f"Expected to find {name}" + # Python has bases attribute to check + if language == "python" and len(sample) > 2: + assert set(target.bases) == sample[2] + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("java", "interface"), + ("csharp", "interface"), + ], + ) + def test_extracts_interface(self, manager, language, sample_key): + """Should extract interface definition.""" + source, name, kind = SAMPLES[language][sample_key] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + assert target.kind == kind + + @pytest.mark.parametrize("language", ["java", "csharp"]) + def test_extracts_enum(self, manager, language): + """Should extract enum definition.""" + source, name = SAMPLES[language]["enum"][:2] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + + @pytest.mark.parametrize("language", ["java", "csharp"]) + def test_extracts_abstract_class(self, manager, language): + """Should extract abstract class.""" + source, name = SAMPLES[language]["abstract_class"][:2] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + + @pytest.mark.parametrize("language", ["java", "csharp"]) + def test_extracts_record(self, manager, language): + """Should extract record type.""" + source, name = SAMPLES[language]["record"][:2] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + + def test_csharp_extracts_struct(self, manager): + """Should extract C# struct.""" + source, name = SAMPLES["csharp"]["struct"][:2] + types = manager.extract_types(source, language="csharp") + + target = next((t for t in types if t.name == name), None) + assert target is not None + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("javascript", "function"), + ("javascript", 
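# Aside, not part of the patch: the consolidated tests above pair the per-language
# SAMPLES table with pytest.mark.parametrize so a single test body runs against Java,
# JavaScript, C#, and Python. A minimal, self-contained sketch of that pattern
# (DEMO_SAMPLES and the toy extractor below are hypothetical stand-ins for
# TreeSitterManager.extract_types):
import pytest

DEMO_SAMPLES = {
    "python": ("class User: pass", "User"),
    "java": ("public class User { }", "User"),
}

def toy_extract_class_name(source: str) -> str:
    # Naive stand-in: returns the identifier that follows the "class" keyword.
    words = source.replace("{", " ").split()
    return words[words.index("class") + 1].rstrip(":")

@pytest.mark.parametrize("language", sorted(DEMO_SAMPLES))
def test_extracts_class_demo(language):
    source, expected_name = DEMO_SAMPLES[language]
    assert toy_extract_class_name(source) == expected_name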
"async_function"), + ("javascript", "arrow_function"), + ("python", "function"), + ("python", "async_function"), + ], + ) + def test_extracts_function_as_type(self, manager, language, sample_key): + """Should extract top-level functions as types.""" + sample = SAMPLES[language][sample_key] + source, name = sample[0], sample[1] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + # Check async flag if provided + if len(sample) > 2 and sample[2] is True: + assert target.is_async is True + + def test_python_extracts_type_alias(self, manager): + """Should extract Python type aliases.""" + source, count = SAMPLES["python"]["type_alias"][:2] + types = manager.extract_types(source, language="python") + aliases = [t for t in types if t.kind == "type_alias"] + + assert len(aliases) == count + + def test_python_extracts_multiple_classes(self, manager): + """Should extract multiple class definitions.""" + source, expected_names = SAMPLES["python"]["multiple_classes"][:2] + types = manager.extract_types(source, language="python") + + assert {t.name for t in types} == expected_names + + +# ============================================================================= +# METHOD EXTRACTION TESTS +# ============================================================================= + + +class TestMethodExtraction: + """Tests for method extraction across all languages.""" + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("java", "method_public"), + ("javascript", "method_class"), + ("csharp", "method_public"), + ("python", "method_class"), + ], + ) + def test_extracts_class_methods(self, manager, language, sample_key): + """Should extract methods from a class.""" + sample = SAMPLES[language][sample_key] + source = sample[0] + expected = sample[1] # Either a name or list of names + methods = manager.extract_methods(source, language=language) + + if isinstance(expected, list): + names = {m.name for m in methods} + assert all(e in names for e in expected), f"Expected {expected} in {names}" + else: + target = next((m for m in methods if m.name == expected), None) + assert target is not None + + @pytest.mark.parametrize("language", ["java", "csharp"]) + def test_extracts_private_method(self, manager, language): + """Should extract private methods.""" + source, name = SAMPLES[language]["method_private"][:2] + methods = manager.extract_methods(source, language=language) + + target = next((m for m in methods if m.name == name), None) + assert target is not None + + @pytest.mark.parametrize("language", ["java", "javascript", "csharp"]) + def test_extracts_static_method(self, manager, language): + """Should extract static methods.""" + sample_key = "method_static" if language != "javascript" else "static_method" + source, name = SAMPLES[language][sample_key][:2] + methods = manager.extract_methods(source, language=language) + + target = next((m for m in methods if m.name == name), None) + assert target is not None + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("java", "constructor"), + ("javascript", "constructor"), + ("csharp", "constructor"), + ], + ) + def test_extracts_constructor(self, manager, language, sample_key): + """Should extract constructor.""" + sample = SAMPLES[language][sample_key] + source, name = sample[0], sample[1] + methods = manager.extract_methods(source, language=language) + + target = next((m for m in methods if m.name == name), None) + assert target is not None + # Check 
class_name if provided + if len(sample) > 2: + assert target.class_name == sample[2] + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("javascript", "async_method"), + ("csharp", "method_async"), + ("python", "async_method"), + ], + ) + def test_extracts_async_method(self, manager, language, sample_key): + """Should detect async methods.""" + sample = SAMPLES[language][sample_key] + source, name = sample[0], sample[1] + methods = manager.extract_methods(source, language=language) + + target = next((m for m in methods if m.name == name), None) + assert target is not None + if language == "python" or language == "javascript": + assert target.is_async is True + + def test_java_extracts_overloaded_methods(self, manager): + """Should extract overloaded methods in Java.""" + source, name, count = SAMPLES["java"]["overloaded"] + methods = manager.extract_methods(source, language="java") + + matching = [m for m in methods if m.name == name] + assert len(matching) == count + + def test_python_extracts_standalone_function(self, manager): + """Should extract standalone function with no class.""" + source, name, class_name = SAMPLES["python"]["standalone_fn"] + methods = manager.extract_methods(source, language="python") + + target = next((m for m in methods if m.name == name), None) + assert target is not None + assert target.class_name is class_name + + @pytest.mark.parametrize( + "sample_key,expected_attr", + [ + ("staticmethod", "is_static"), + ("classmethod", "is_classmethod"), + ("property", "is_property"), + ], + ) + def test_python_extracts_decorated_methods(self, manager, sample_key, expected_attr): + """Should detect Python method decorators.""" + source, name = SAMPLES["python"][sample_key][:2] + methods = manager.extract_methods(source, language="python") + + target = next((m for m in methods if m.name == name), None) + assert target is not None + assert getattr(target, expected_attr) is True + + +# ============================================================================= +# IMPORT EXTRACTION TESTS +# ============================================================================= + + +class TestImportExtraction: + """Tests for import extraction across all languages.""" + + @pytest.mark.parametrize( + "language,sample_key,expected_count", + [ + ("java", "import_single", 1), + ("java", "import_wildcard", 1), + ("java", "import_multiple", 3), + ("javascript", "import_named", 1), + ("javascript", "import_namespace", 1), + ("csharp", "import_using", 1), + ("csharp", "import_multiple", 3), + ("csharp", "import_alias", 1), + ("python", "import_multiple", 5), + ], + ) + def test_extracts_imports_count(self, manager, language, sample_key, expected_count): + """Should extract correct number of imports.""" + source = SAMPLES[language][sample_key][0] + imports = manager.extract_imports(source, language=language) + + assert len(imports) >= expected_count + + def test_python_extracts_simple_import(self, manager): + """Should extract simple Python import.""" + source, module, is_from = SAMPLES["python"]["import_simple"] + imports = manager.extract_imports(source, language="python") + + assert len(imports) == 1 + assert imports[0].module == module + assert imports[0].is_from_import is is_from + + def test_python_extracts_import_with_alias(self, manager): + """Should extract Python import with alias.""" + source, module, alias = SAMPLES["python"]["import_alias"] + imports = manager.extract_imports(source, language="python") + + assert imports[0].module == module + assert imports[0].alias == 
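# Aside, not part of the patch: SAMPLES entries are positional tuples, and optional
# extras (kind, async flag, expected class_name) are probed with len(sample) > 2 as
# above. A hedged alternative sketch: a NamedTuple with a defaulted field would name
# that extra slot while keeping the table literals compact.
from typing import Any, NamedTuple

class LangSample(NamedTuple):
    source: str
    name: str
    extra: Any = None  # kind, async flag, or expected class_name, depending on the test

java_ctor = LangSample("public class Person { public Person(String name) { } }", "Person", "Person")
assert java_ctor.extra == "Person"  # plays the role of the len(sample) > 2 branch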
alias + + def test_python_extracts_from_import(self, manager): + """Should extract Python from-import.""" + source, module, names = SAMPLES["python"]["import_from"] + imports = manager.extract_imports(source, language="python") + + assert imports[0].module == module + assert all(n in imports[0].names for n in names) + assert imports[0].is_from_import is True + + def test_javascript_extracts_default_import(self, manager): + """Should extract JavaScript default import.""" + source, module = SAMPLES["javascript"]["import_default"][:2] + imports = manager.extract_imports(source, language="javascript") + + assert len(imports) >= 1 + assert any(module in i.module for i in imports) + + def test_javascript_extracts_require(self, manager): + """Should extract CommonJS require statements.""" + source, module = SAMPLES["javascript"]["import_require"][:2] + imports = manager.extract_imports(source, language="javascript") + + modules = [i.module for i in imports] + assert any(module in m for m in modules) + + +# ============================================================================= +# EDGE CASES +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases across all languages.""" + + @pytest.mark.parametrize("language", ["java", "javascript", "csharp", "python"]) + def test_handles_empty_file(self, manager, language): + """Should handle empty file without error.""" + types = manager.extract_types("", language=language) + assert types == [] + + methods = manager.extract_methods("", language=language) + assert methods == [] + + imports = manager.extract_imports("", language=language) + assert imports == [] + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("java", "annotated_class"), + ("csharp", "attributes"), + ("python", "decorated_class"), + ], + ) + def test_handles_decorated_annotated_classes(self, manager, language, sample_key): + """Should handle classes with decorators/annotations/attributes.""" + source, name = SAMPLES[language][sample_key][:2] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + + @pytest.mark.parametrize( + "language,sample_key", + [ + ("csharp", "namespace_class"), + ("csharp", "file_scoped_ns"), + ], + ) + def test_handles_namespaces(self, manager, language, sample_key): + """Should handle namespace declarations.""" + source, name = SAMPLES[language][sample_key][:2] + types = manager.extract_types(source, language=language) + + target = next((t for t in types if t.name == name), None) + assert target is not None + + def test_java_handles_inner_class(self, manager): + """Should handle Java inner classes.""" + source, name = SAMPLES["java"]["inner_class"][:2] + types = manager.extract_types(source, language="java") + + # Should extract at least the outer class + assert len(types) >= 1 + + def test_javascript_handles_jsx(self, manager): + """Should handle JSX syntax.""" + source, name, kind = SAMPLES["javascript"]["jsx_function"] + types = manager.extract_types(source, language="javascript") + + target = next((t for t in types if t.name == name), None) + assert target is not None + assert target.kind == kind + + def test_javascript_handles_private_fields(self, manager): + """Should handle private class fields.""" + source, name, class_name = SAMPLES["javascript"]["private_fields"] + methods = manager.extract_methods(source, language="javascript") + + target = next((m for m in methods if 
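# Aside, not part of the patch: the import assertions in these tests rely on records
# exposing .module, .names, .alias and .is_from_import. A sketch of that shape, with
# field names taken from the assertions above (the real model lives in the treesitter
# adapter and may differ):
from dataclasses import dataclass, field

@dataclass
class ImportRecordSketch:
    module: str
    names: list[str] = field(default_factory=list)
    alias: str | None = None
    is_from_import: bool = False

# "from pathlib import Path" would then be represented roughly as:
path_import = ImportRecordSketch(module="pathlib", names=["Path"], is_from_import=True)
assert path_import.is_from_import and "Path" in path_import.names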
m.name == name), None) + assert target is not None + assert target.class_name == class_name + + def test_csharp_handles_partial_class(self, manager): + """Should handle partial classes.""" + source, name = SAMPLES["csharp"]["partial_class"][:2] + types = manager.extract_types(source, language="csharp") + + target = next((t for t in types if t.name == name), None) + assert target is not None + + def test_python_handles_nested_class(self, manager): + """Should handle nested class definitions.""" + source, name = SAMPLES["python"]["nested_class"][:2] + types = manager.extract_types(source, language="python") + + assert name in {t.name for t in types} + + def test_python_handles_unicode_identifiers(self, manager): + """Should handle unicode in identifiers.""" + source, name = SAMPLES["python"]["unicode"][:2] + types = manager.extract_types(source, language="python") + + assert len(types) >= 1 + + def test_python_handles_syntax_errors(self, manager): + """Should handle syntax errors gracefully.""" + source = "class Incomplete {\n def broken" + types = manager.extract_types(source, language="python") + assert isinstance(types, list) + + def test_python_handles_comments_only(self, manager): + """Should handle file with only comments.""" + types = manager.extract_types("# Comment\n# Another", language="python") + assert types == [] diff --git a/tests/test_adapters/treesitter/test_manager.py b/tests/test_adapters/treesitter/test_manager.py index 756684d..2c9f358 100644 --- a/tests/test_adapters/treesitter/test_manager.py +++ b/tests/test_adapters/treesitter/test_manager.py @@ -112,71 +112,65 @@ def test_includes_known_languages(self): class TestExtractTypes: """Tests for extract_types method.""" - def test_extracts_python_class(self): + def test_extracts_python_class(self, treesitter_manager): """Should extract Python class definitions.""" - manager = TreeSitterManager() source = ''' class MyClass: """A sample class.""" pass ''' - types = manager.extract_types(source, language="python") + types = treesitter_manager.extract_types(source, language="python") assert len(types) == 1 assert types[0].name == "MyClass" assert types[0].kind == "class" - def test_extracts_python_function(self): + def test_extracts_python_function(self, treesitter_manager): """Should extract Python function definitions.""" - manager = TreeSitterManager() source = ''' def my_function(): """A sample function.""" pass ''' - types = manager.extract_types(source, language="python") + types = treesitter_manager.extract_types(source, language="python") assert len(types) == 1 assert types[0].name == "my_function" assert types[0].kind == "function" - def test_extracts_class_with_bases(self): + def test_extracts_class_with_bases(self, treesitter_manager): """Should extract class inheritance.""" - manager = TreeSitterManager() source = """ class Child(Parent): pass """ - types = manager.extract_types(source, language="python") + types = treesitter_manager.extract_types(source, language="python") assert len(types) == 1 assert types[0].name == "Child" assert "Parent" in types[0].bases - def test_returns_empty_for_unknown_language(self): + def test_returns_empty_for_unknown_language(self, treesitter_manager): """Should return empty list for unknown language.""" - manager = TreeSitterManager() source = "class MyClass: pass" - types = manager.extract_types(source, language="unknown") + types = treesitter_manager.extract_types(source, language="unknown") assert types == [] - def test_returns_empty_for_no_language(self): + def 
test_returns_empty_for_no_language(self, treesitter_manager): """Should return empty list when language cannot be determined.""" - manager = TreeSitterManager() source = "class MyClass: pass" - types = manager.extract_types(source) + types = treesitter_manager.extract_types(source) assert types == [] - def test_detects_language_from_path(self): + def test_detects_language_from_path(self, treesitter_manager): """Should detect language from file path.""" - manager = TreeSitterManager() source = """ class MyClass: pass """ - types = manager.extract_types(source, file_path="test.py") + types = treesitter_manager.extract_types(source, file_path="test.py") assert len(types) == 1 assert types[0].name == "MyClass" @@ -185,78 +179,72 @@ class MyClass: class TestExtractMethods: """Tests for extract_methods method.""" - def test_extracts_class_methods(self): + def test_extracts_class_methods(self, treesitter_manager): """Should extract methods from classes.""" - manager = TreeSitterManager() source = """ class MyClass: def my_method(self): pass """ - methods = manager.extract_methods(source, language="python") + methods = treesitter_manager.extract_methods(source, language="python") assert len(methods) >= 1 method_names = [m.name for m in methods] assert "my_method" in method_names - def test_extracts_standalone_functions(self): + def test_extracts_standalone_functions(self, treesitter_manager): """Should extract standalone functions.""" - manager = TreeSitterManager() source = """ def standalone(): pass """ - methods = manager.extract_methods(source, language="python") + methods = treesitter_manager.extract_methods(source, language="python") assert len(methods) == 1 assert methods[0].name == "standalone" assert methods[0].class_name is None - def test_returns_empty_for_unknown_language(self): + def test_returns_empty_for_unknown_language(self, treesitter_manager): """Should return empty list for unknown language.""" - manager = TreeSitterManager() source = "def test(): pass" - methods = manager.extract_methods(source, language="unknown") + methods = treesitter_manager.extract_methods(source, language="unknown") assert methods == [] class TestExtractImports: """Tests for extract_imports method.""" - def test_extracts_simple_import(self): + def test_extracts_simple_import(self, treesitter_manager): """Should extract simple import statements.""" - manager = TreeSitterManager() source = """ import os """ - imports = manager.extract_imports(source, language="python") + imports = treesitter_manager.extract_imports(source, language="python") assert len(imports) == 1 assert imports[0].module == "os" - def test_extracts_from_import(self): + def test_extracts_from_import(self, treesitter_manager): """Should extract from import statements.""" - manager = TreeSitterManager() source = """ from pathlib import Path """ - imports = manager.extract_imports(source, language="python") + imports = treesitter_manager.extract_imports(source, language="python") assert len(imports) == 1 assert imports[0].module == "pathlib" assert "Path" in imports[0].names assert imports[0].is_from_import is True - def test_extracts_multiple_imports(self): + def test_extracts_multiple_imports(self, treesitter_manager): """Should extract multiple import statements.""" - manager = TreeSitterManager() source = """ import os import sys from typing import List, Dict """ - imports = manager.extract_imports(source, language="python") + imports = treesitter_manager.extract_imports(source, language="python") assert len(imports) == 3 modules = 
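# Aside, not part of the patch: these test_manager.py hunks swap per-test
# TreeSitterManager construction for a shared `treesitter_manager` fixture argument.
# The fixture definition is not visible in this portion of the diff; a minimal
# conftest.py sketch (assumed location and default function scope) would be:
import pytest

from deriva.adapters.treesitter import TreeSitterManager

@pytest.fixture
def treesitter_manager() -> TreeSitterManager:
    """Provide a TreeSitterManager; the parser-caching test still builds its own instance."""
    return TreeSitterManager()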
[i.module for i in imports] @@ -264,21 +252,19 @@ def test_extracts_multiple_imports(self): assert "sys" in modules assert "typing" in modules - def test_returns_empty_for_unknown_language(self): + def test_returns_empty_for_unknown_language(self, treesitter_manager): """Should return empty list for unknown language.""" - manager = TreeSitterManager() source = "import os" - imports = manager.extract_imports(source, language="unknown") + imports = treesitter_manager.extract_imports(source, language="unknown") assert imports == [] class TestExtractAll: """Tests for extract_all method.""" - def test_extracts_all_elements(self): + def test_extracts_all_elements(self, treesitter_manager): """Should extract types, methods, and imports.""" - manager = TreeSitterManager() source = """ import os @@ -286,7 +272,7 @@ class MyClass: def my_method(self): pass """ - result = manager.extract_all(source, language="python") + result = treesitter_manager.extract_all(source, language="python") assert "types" in result assert "methods" in result @@ -295,12 +281,11 @@ def my_method(self): assert len(result["methods"]) >= 1 assert len(result["imports"]) == 1 - def test_returns_empty_dicts_for_unknown_language(self): + def test_returns_empty_dicts_for_unknown_language(self, treesitter_manager): """Should return empty lists for unknown language.""" - manager = TreeSitterManager() source = "code" - result = manager.extract_all(source, language="unknown") + result = treesitter_manager.extract_all(source, language="unknown") assert result["types"] == [] assert result["methods"] == [] @@ -310,36 +295,28 @@ def test_returns_empty_dicts_for_unknown_language(self): class TestResolveLanguage: """Tests for _resolve_language private method.""" - def test_explicit_language_takes_precedence(self): + def test_explicit_language_takes_precedence(self, treesitter_manager): """Should use explicit language over file path.""" - manager = TreeSitterManager() - # Even though file is .js, explicit language should be used - lang = manager._resolve_language("test.js", "python") + lang = treesitter_manager._resolve_language("test.js", "python") assert lang == "python" - def test_uses_file_path_when_no_explicit(self): + def test_uses_file_path_when_no_explicit(self, treesitter_manager): """Should use file path when no explicit language.""" - manager = TreeSitterManager() - - lang = manager._resolve_language("test.py", None) + lang = treesitter_manager._resolve_language("test.py", None) assert lang == "python" - def test_returns_none_when_no_info(self): + def test_returns_none_when_no_info(self, treesitter_manager): """Should return None when no language info available.""" - manager = TreeSitterManager() - - lang = manager._resolve_language(None, None) + lang = treesitter_manager._resolve_language(None, None) assert lang is None - def test_typescript_maps_to_javascript(self): + def test_typescript_maps_to_javascript(self, treesitter_manager): """Should map TypeScript to JavaScript.""" - manager = TreeSitterManager() - - lang = manager._resolve_language(None, "typescript") + lang = treesitter_manager._resolve_language(None, "typescript") assert lang == "javascript" - lang = manager._resolve_language("test.ts", None) + lang = treesitter_manager._resolve_language("test.ts", None) assert lang == "javascript" @@ -348,6 +325,7 @@ class TestParserCaching: def test_parser_is_cached(self): """Should cache parsers for reuse.""" + # Need fresh instance to test caching manager = TreeSitterManager() source = "class Test: pass" diff --git 
a/tests/test_adapters/treesitter/test_python.py b/tests/test_adapters/treesitter/test_python.py deleted file mode 100644 index 812ef6e..0000000 --- a/tests/test_adapters/treesitter/test_python.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Tests for Python language extractor.""" - -from __future__ import annotations - -import pytest - -from deriva.adapters.treesitter import TreeSitterManager - - -@pytest.fixture -def manager(): - """Provide a TreeSitterManager for tests.""" - return TreeSitterManager() - - -class TestPythonTypes: - """Tests for Python type extraction.""" - - def test_extracts_simple_class(self, manager): - """Should extract basic class definition.""" - types = manager.extract_types("class User: pass", language="python") - - assert len(types) == 1 - assert types[0].name == "User" - assert types[0].kind == "class" - - def test_extracts_class_with_inheritance(self, manager): - """Should extract class with base classes.""" - types = manager.extract_types("class Admin(User, PermissionMixin): pass", language="python") - - assert len(types) == 1 - assert types[0].name == "Admin" - assert set(types[0].bases) == {"User", "PermissionMixin"} - - def test_extracts_class_docstring(self, manager): - """Should extract class docstring.""" - source = 'class Service:\n """A service class."""\n pass' - types = manager.extract_types(source, language="python") - - assert types[0].docstring is not None - assert "service" in types[0].docstring.lower() - - def test_extracts_decorated_class(self, manager): - """Should extract decorated class.""" - source = "@dataclass\nclass Config:\n name: str" - types = manager.extract_types(source, language="python") - - assert len(types) == 1 - assert types[0].name == "Config" - - def test_extracts_top_level_function(self, manager): - """Should extract top-level function as type.""" - types = manager.extract_types("def process_data(items): return items", language="python") - - assert len(types) == 1 - assert types[0].name == "process_data" - assert types[0].kind == "function" - - def test_extracts_async_function(self, manager): - """Should detect async functions.""" - types = manager.extract_types("async def fetch_user(user_id): pass", language="python") - - assert len(types) == 1 - assert types[0].name == "fetch_user" - assert types[0].is_async is True - - def test_extracts_decorated_function(self, manager): - """Should extract decorated top-level function.""" - types = manager.extract_types("@cache\ndef expensive_operation(): return 1", language="python") - - assert len(types) == 1 - assert types[0].name == "expensive_operation" - - def test_extracts_type_alias(self, manager): - """Should extract Python 3.12+ type aliases.""" - source = "type UserId = int\ntype Callback = Callable[[int], str]" - types = manager.extract_types(source, language="python") - type_aliases = [t for t in types if t.kind == "type_alias"] - - assert len(type_aliases) == 2 - - def test_extracts_multiple_classes(self, manager): - """Should extract multiple class definitions.""" - source = "class First: pass\nclass Second: pass\nclass Third: pass" - types = manager.extract_types(source, language="python") - - assert {t.name for t in types} == {"First", "Second", "Third"} - - -class TestPythonMethods: - """Tests for Python method extraction.""" - - def test_extracts_class_methods(self, manager): - """Should extract methods from class with correct class_name.""" - source = "class Calculator:\n def add(self, a, b): return a + b\n def subtract(self, a, b): return a - b" - methods = 
manager.extract_methods(source, language="python") - - assert {m.name for m in methods} == {"add", "subtract"} - assert all(m.class_name == "Calculator" for m in methods) - - def test_extracts_standalone_function(self, manager): - """Should extract standalone function with no class.""" - methods = manager.extract_methods("def helper_function(x): return x * 2", language="python") - - assert len(methods) == 1 - assert methods[0].name == "helper_function" - assert methods[0].class_name is None - - def test_extracts_method_parameters(self, manager): - """Should extract method parameters.""" - source = "class Service:\n def process(self, data: list, config: dict = None) -> bool: pass" - methods = manager.extract_methods(source, language="python") - - assert len(methods) == 1 - assert len(methods[0].parameters) >= 1 - - def test_extracts_return_annotation(self, manager): - """Should extract return type annotation.""" - source = "class Repository:\n def find_by_id(self, id: int) -> User: pass" - methods = manager.extract_methods(source, language="python") - - assert methods[0].return_annotation is not None - - def test_extracts_async_method(self, manager): - """Should detect async methods.""" - source = "class AsyncService:\n async def fetch(self, url): return url" - methods = manager.extract_methods(source, language="python") - - assert methods[0].is_async is True - - def test_extracts_staticmethod(self, manager): - """Should detect @staticmethod decorator.""" - source = "class Utils:\n @staticmethod\n def format_date(date): return date" - methods = manager.extract_methods(source, language="python") - - assert methods[0].is_static is True - - def test_extracts_classmethod(self, manager): - """Should detect @classmethod decorator.""" - source = "class Factory:\n @classmethod\n def create(cls): return cls()" - methods = manager.extract_methods(source, language="python") - - assert methods[0].is_classmethod is True - - def test_extracts_property(self, manager): - """Should detect @property decorator.""" - source = "class Person:\n @property\n def full_name(self): return self.name" - methods = manager.extract_methods(source, language="python") - - assert methods[0].is_property is True - - def test_extracts_method_docstring(self, manager): - """Should extract method docstring.""" - source = 'class Service:\n def process(self, data):\n """Process the data."""\n return data' - methods = manager.extract_methods(source, language="python") - - assert "Process" in methods[0].docstring - - def test_extracts_decorated_standalone_function(self, manager): - """Should extract decorated top-level functions as methods.""" - source = "@cache\ndef expensive_computation(data): return data" - methods = manager.extract_methods(source, language="python") - - assert methods[0].name == "expensive_computation" - assert methods[0].class_name is None - - -class TestPythonImports: - """Tests for Python import extraction.""" - - def test_extracts_simple_import(self, manager): - """Should extract simple import statement.""" - imports = manager.extract_imports("import os", language="python") - - assert len(imports) == 1 - assert imports[0].module == "os" - assert imports[0].is_from_import is False - - def test_extracts_import_with_alias(self, manager): - """Should extract import with alias.""" - imports = manager.extract_imports("import numpy as np", language="python") - - assert imports[0].module == "numpy" - assert imports[0].alias == "np" - - def test_extracts_from_import(self, manager): - """Should extract from-import 
statement.""" - imports = manager.extract_imports("from pathlib import Path", language="python") - - assert imports[0].module == "pathlib" - assert "Path" in imports[0].names - assert imports[0].is_from_import is True - - def test_extracts_multiple_names_from_import(self, manager): - """Should extract multiple names from single import.""" - imports = manager.extract_imports("from typing import List, Dict, Optional", language="python") - - assert imports[0].module == "typing" - assert len(imports[0].names) >= 1 - - def test_extracts_star_import(self, manager): - """Should extract wildcard import.""" - imports = manager.extract_imports("from module import *", language="python") - - assert imports[0].module == "module" - assert imports[0].is_from_import is True - - def test_extracts_dotted_import(self, manager): - """Should extract dotted module path.""" - imports = manager.extract_imports("import os.path\nfrom collections.abc import Mapping", language="python") - - assert len(imports) == 2 - modules = {i.module for i in imports} - assert "os.path" in modules or "collections.abc" in modules - - def test_extracts_relative_import(self, manager): - """Should extract relative imports.""" - source = "from . import utils\nfrom ..core import base" - imports = manager.extract_imports(source, language="python") - - assert len(imports) >= 1 - - def test_extracts_multiple_import_statements(self, manager): - """Should extract all import statements.""" - source = "import os\nimport sys\nimport json\nfrom pathlib import Path\nfrom typing import Any" - imports = manager.extract_imports(source, language="python") - - assert len(imports) == 5 - - -class TestPythonEdgeCases: - """Tests for edge cases in Python extraction.""" - - def test_handles_empty_file(self, manager): - """Should handle empty file.""" - assert manager.extract_types("", language="python") == [] - assert manager.extract_methods("", language="python") == [] - assert manager.extract_imports("", language="python") == [] - - def test_handles_comments_only(self, manager): - """Should handle file with only comments.""" - assert manager.extract_types("# This is a comment\n# Another comment", language="python") == [] - - def test_handles_nested_classes(self, manager): - """Should handle nested class definitions.""" - source = "class Outer:\n class Inner:\n def method(self): pass" - types = manager.extract_types(source, language="python") - - assert "Outer" in {t.name for t in types} - - def test_handles_unicode_identifiers(self, manager): - """Should handle unicode in identifiers.""" - source = "class Données:\n def traiter(self): pass" - types = manager.extract_types(source, language="python") - - assert len(types) >= 1 - - def test_handles_syntax_errors_gracefully(self, manager): - """Should handle syntax errors without crashing.""" - source = "class Incomplete {\n def broken" - types = manager.extract_types(source, language="python") - assert isinstance(types, list) diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index 833fa1a..dbc0cbd 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -1,94 +1,57 @@ -"""Tests for cli.cli module.""" +"""Tests for cli.cli module (typer-based).""" from __future__ import annotations -import argparse from unittest.mock import MagicMock, patch -from deriva.cli.cli import ( - _get_run_stats_from_ocel, +from typer.testing import CliRunner + +from deriva.cli.cli import app, main +from deriva.cli.commands.benchmark import _get_run_stats_from_ocel +from deriva.cli.commands.run 
import ( _print_derivation_result, _print_extraction_result, _print_pipeline_result, - cmd_clear, - cmd_config_disable, - cmd_config_enable, - cmd_config_list, - cmd_config_show, - cmd_config_versions, - cmd_export, - cmd_repo_clone, - cmd_repo_delete, - cmd_repo_info, - cmd_repo_list, - cmd_status, - create_parser, - main, ) +runner = CliRunner() + + +class TestCLIHelp: + """Tests for CLI help commands.""" + + def test_main_help(self): + """Should show main help.""" + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "Deriva CLI" in result.stdout -class TestCreateParser: - """Tests for argument parser creation.""" - - def test_creates_parser(self): - """Should create argument parser.""" - parser = create_parser() - assert isinstance(parser, argparse.ArgumentParser) - - def test_run_command_exists(self): - """Should parse run command.""" - parser = create_parser() - args = parser.parse_args(["run", "extraction"]) - assert args.command == "run" - assert args.stage == "extraction" - - def test_run_command_with_options(self): - """Should parse run command with options.""" - parser = create_parser() - args = parser.parse_args(["run", "derivation", "--phase", "prep", "-v"]) - assert args.stage == "derivation" - assert args.phase == "prep" - assert args.verbose is True - - def test_config_list_command(self): - """Should parse config list command.""" - parser = create_parser() - args = parser.parse_args(["config", "list", "extraction"]) - assert args.command == "config" - assert args.config_action == "list" - assert args.step_type == "extraction" - - def test_repo_clone_command(self): - """Should parse repo clone command.""" - parser = create_parser() - args = parser.parse_args(["repo", "clone", "https://github.com/user/repo"]) - assert args.command == "repo" - assert args.repo_action == "clone" - assert args.url == "https://github.com/user/repo" - - def test_clear_command(self): - """Should parse clear command.""" - parser = create_parser() - args = parser.parse_args(["clear", "graph"]) - assert args.command == "clear" - assert args.target == "graph" - - def test_export_command_with_options(self): - """Should parse export command with options.""" - parser = create_parser() - args = parser.parse_args(["export", "-o", "out.xml", "-n", "MyModel"]) - assert args.command == "export" - assert args.output == "out.xml" - assert args.name == "MyModel" - - def test_benchmark_run_command(self): - """Should parse benchmark run command.""" - parser = create_parser() - args = parser.parse_args(["benchmark", "run", "--repos", "repo1,repo2", "--models", "gpt4,claude"]) - assert args.command == "benchmark" - assert args.benchmark_action == "run" - assert args.repos == "repo1,repo2" - assert args.models == "gpt4,claude" + def test_run_help(self): + """Should show run command help.""" + result = runner.invoke(app, ["run", "--help"]) + assert result.exit_code == 0 + assert "Pipeline stage" in result.stdout + + def test_config_help(self): + """Should show config command help.""" + result = runner.invoke(app, ["config", "--help"]) + assert result.exit_code == 0 + assert "list" in result.stdout + assert "enable" in result.stdout + + def test_repo_help(self): + """Should show repo command help.""" + result = runner.invoke(app, ["repo", "--help"]) + assert result.exit_code == 0 + assert "clone" in result.stdout + assert "list" in result.stdout + + def test_benchmark_help(self): + """Should show benchmark command help.""" + result = runner.invoke(app, ["benchmark", "--help"]) + assert 
result.exit_code == 0 + assert "run" in result.stdout + assert "analyze" in result.stdout class TestPrintExtractionResult: @@ -124,6 +87,40 @@ def test_prints_errors(self, capsys): assert "Errors (2)" in output assert "Error 1" in output + def test_prints_warnings(self, capsys): + """Should print warnings when present.""" + result = { + "stats": {}, + "warnings": ["Warning 1", "Warning 2", "Warning 3"], + } + _print_extraction_result(result) + output = capsys.readouterr().out + + assert "Warnings (3)" in output + assert "Warning 1" in output + + def test_truncates_many_warnings(self, capsys): + """Should truncate when more than 5 warnings.""" + result = { + "stats": {}, + "warnings": [f"Warning {i}" for i in range(10)], + } + _print_extraction_result(result) + output = capsys.readouterr().out + + assert "... and 5 more" in output + + def test_truncates_many_errors(self, capsys): + """Should truncate when more than 5 errors.""" + result = { + "stats": {}, + "errors": [f"Error {i}" for i in range(10)], + } + _print_extraction_result(result) + output = capsys.readouterr().out + + assert "... and 5 more" in output + class TestPrintDerivationResult: """Tests for _print_derivation_result helper.""" @@ -160,6 +157,28 @@ def test_prints_issues(self, capsys): assert "Issues (1)" in output assert "[WARNING]" in output + def test_prints_errors(self, capsys): + """Should print errors when present.""" + result = { + "stats": {}, + "errors": ["Error 1", "Error 2"], + } + _print_derivation_result(result) + output = capsys.readouterr().out + + assert "Errors (2)" in output + + def test_truncates_many_issues(self, capsys): + """Should truncate when more than 10 issues.""" + result = { + "stats": {}, + "issues": [{"severity": "warning", "message": f"Issue {i}"} for i in range(15)], + } + _print_derivation_result(result) + output = capsys.readouterr().out + + assert "... and 5 more" in output + class TestPrintPipelineResult: """Tests for _print_pipeline_result helper.""" @@ -206,85 +225,29 @@ def test_prints_total_errors(self, capsys): assert "Total errors: 3" in output -class TestPrintExtractionResultAdditional: - """Additional tests for _print_extraction_result helper.""" - - def test_prints_warnings(self, capsys): - """Should print warnings when present.""" - result = { - "stats": {}, - "warnings": ["Warning 1", "Warning 2", "Warning 3"], - } - _print_extraction_result(result) - output = capsys.readouterr().out - - assert "Warnings (3)" in output - assert "Warning 1" in output - - def test_truncates_many_warnings(self, capsys): - """Should truncate when more than 5 warnings.""" - result = { - "stats": {}, - "warnings": [f"Warning {i}" for i in range(10)], - } - _print_extraction_result(result) - output = capsys.readouterr().out - - assert "... and 5 more" in output - - def test_truncates_many_errors(self, capsys): - """Should truncate when more than 5 errors.""" - result = { - "stats": {}, - "errors": [f"Error {i}" for i in range(10)], - } - _print_extraction_result(result) - output = capsys.readouterr().out - - assert "... 
and 5 more" in output - - -class TestPrintDerivationResultAdditional: - """Additional tests for _print_derivation_result helper.""" - - def test_prints_errors(self, capsys): - """Should print errors when present.""" - result = { - "stats": {}, - "errors": ["Error 1", "Error 2"], - } - _print_derivation_result(result) - output = capsys.readouterr().out - - assert "Errors (2)" in output - - def test_truncates_many_issues(self, capsys): - """Should truncate when more than 10 issues.""" - result = { - "stats": {}, - "issues": [{"severity": "warning", "message": f"Issue {i}"} for i in range(15)], - } - _print_derivation_result(result) - output = capsys.readouterr().out - - assert "... and 5 more" in output - - class TestMain: """Tests for main entry point.""" - def test_no_command_shows_help(self, capsys): - """Should show help when no command provided.""" - with patch("sys.argv", ["deriva"]): + def test_main_returns_zero_on_success(self): + """Should return 0 on success.""" + with patch("deriva.cli.cli.app") as mock_app: + mock_app.return_value = None result = main() assert result == 0 + def test_main_returns_exit_code(self): + """Should return exit code on SystemExit.""" + with patch("deriva.cli.cli.app") as mock_app: + mock_app.side_effect = SystemExit(1) + result = main() + assert result == 1 + -class TestCmdConfigList: - """Tests for cmd_config_list command.""" +class TestConfigListCommand: + """Tests for config list command.""" - @patch("deriva.cli.cli.PipelineSession") - def test_lists_extraction_configs(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_lists_extraction_configs(self, mock_session_class): """Should list extraction configurations.""" mock_session = MagicMock() mock_session.list_steps.return_value = [ @@ -293,234 +256,109 @@ def test_lists_extraction_configs(self, mock_session_class, capsys): ] mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(step_type="extraction", enabled=False) - result = cmd_config_list(args) + result = runner.invoke(app, ["config", "list", "extraction"]) - assert result == 0 - output = capsys.readouterr().out - assert "EXTRACTION CONFIGURATIONS" in output - assert "BusinessConcept" in output + assert result.exit_code == 0 + assert "EXTRACTION CONFIGURATIONS" in result.stdout + assert "BusinessConcept" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_shows_message_when_no_configs(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_shows_message_when_no_configs(self, mock_session_class): """Should show message when no configurations found.""" mock_session = MagicMock() mock_session.list_steps.return_value = [] mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(step_type="derivation", enabled=False) - result = cmd_config_list(args) - - assert result == 0 - output = capsys.readouterr().out - assert "No derivation configurations found" in output - - -class TestCmdConfigShow: - """Tests for cmd_config_show command.""" - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_shows_extraction_config(self, mock_session_class, mock_config, capsys): - """Should show extraction config details.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.get_extraction_config.return_value = MagicMock( - node_type="BusinessConcept", - sequence=1, - enabled=True, 
- input_sources="*.py", - instruction="Extract business concepts", - example="{}", - ) - - args = argparse.Namespace(step_type="extraction", name="BusinessConcept") - result = cmd_config_show(args) - - assert result == 0 - output = capsys.readouterr().out - assert "EXTRACTION CONFIG: BusinessConcept" in output - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_shows_derivation_config(self, mock_session_class, mock_config, capsys): - """Should show derivation config details.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.get_derivation_config.return_value = MagicMock( - element_type="ApplicationComponent", - sequence=1, - enabled=True, - input_graph_query="MATCH (n) RETURN n", - instruction="Derive components", - ) - - args = argparse.Namespace(step_type="derivation", name="ApplicationComponent") - result = cmd_config_show(args) - - assert result == 0 - output = capsys.readouterr().out - assert "DERIVATION CONFIG: ApplicationComponent" in output - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_returns_error_when_config_not_found(self, mock_session_class, mock_config, capsys): - """Should return error when config not found.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.get_extraction_config.return_value = None - - args = argparse.Namespace(step_type="extraction", name="NonExistent") - result = cmd_config_show(args) - - assert result == 1 - output = capsys.readouterr().out - assert "not found" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_returns_error_for_unknown_step_type(self, mock_session_class, capsys): - """Should return error for unknown step type.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + result = runner.invoke(app, ["config", "list", "derivation"]) - args = argparse.Namespace(step_type="unknown", name="test") - result = cmd_config_show(args) + assert result.exit_code == 0 + assert "No derivation configurations found" in result.stdout - assert result == 1 - output = capsys.readouterr().out - assert "Unknown step type" in output + def test_invalid_step_type(self): + """Should reject invalid step type.""" + result = runner.invoke(app, ["config", "list", "invalid"]) + assert result.exit_code == 1 + assert "Error" in result.output -class TestCmdConfigEnableDisable: - """Tests for cmd_config_enable and cmd_config_disable commands.""" +class TestConfigEnableDisableCommand: + """Tests for config enable and disable commands.""" - @patch("deriva.cli.cli.PipelineSession") - def test_enable_step_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_enable_step_success(self, mock_session_class): """Should enable step successfully.""" mock_session = MagicMock() mock_session.enable_step.return_value = True mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(step_type="extraction", name="BusinessConcept") - result = cmd_config_enable(args) + result = runner.invoke(app, ["config", "enable", "extraction", "BusinessConcept"]) - assert result == 0 - output = capsys.readouterr().out - assert "Enabled" in output + assert result.exit_code == 0 + assert "Enabled" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_enable_step_not_found(self, mock_session_class, capsys): + 
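# --- Illustrative sketch (editor's aside, not part of this patch) ---------------
# The pattern the rewritten tests rely on: typer's CliRunner drives a Typer app
# in-process and exposes exit code and captured output, replacing the old
# argparse.Namespace plumbing. The `demo` app and its `greet`/`version` commands
# below are hypothetical stand-ins, not part of the Deriva CLI.
import typer
from typer.testing import CliRunner


demo = typer.Typer(help="Demo CLI")


@demo.command()
def greet(name: str, shout: bool = typer.Option(False, "--shout")) -> None:
    """Print a greeting."""
    message = f"Hello {name}"
    typer.echo(message.upper() if shout else message)


@demo.command()
def version() -> None:
    """Print a placeholder version string."""
    typer.echo("0.0.0")


def test_greet_command() -> None:
    runner = CliRunner()
    result = runner.invoke(demo, ["greet", "World", "--shout"])
    # invoke() never raises SystemExit; failures surface via exit_code/output.
    assert result.exit_code == 0
    assert "HELLO WORLD" in result.output
# --------------------------------------------------------------------------------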
@patch("deriva.cli.commands.config.PipelineSession") + def test_enable_step_not_found(self, mock_session_class): """Should return error when step not found.""" mock_session = MagicMock() mock_session.enable_step.return_value = False mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(step_type="extraction", name="NonExistent") - result = cmd_config_enable(args) + result = runner.invoke(app, ["config", "enable", "extraction", "NonExistent"]) - assert result == 1 - output = capsys.readouterr().out - assert "not found" in output + assert result.exit_code == 1 + assert "not found" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_disable_step_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_disable_step_success(self, mock_session_class): """Should disable step successfully.""" mock_session = MagicMock() mock_session.disable_step.return_value = True mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(step_type="derivation", name="ApplicationComponent") - result = cmd_config_disable(args) - - assert result == 0 - output = capsys.readouterr().out - assert "Disabled" in output - - -class TestCmdConfigVersions: - """Tests for cmd_config_versions command.""" - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_shows_active_versions(self, mock_session_class, mock_config, capsys): - """Should show active config versions.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.get_active_config_versions.return_value = { - "extraction": {"BusinessConcept": 2, "TypeDefinition": 1}, - "derivation": {"ApplicationComponent": 3}, - } - - args = argparse.Namespace() - result = cmd_config_versions(args) + result = runner.invoke(app, ["config", "disable", "derivation", "ApplicationComponent"]) - assert result == 0 - output = capsys.readouterr().out - assert "ACTIVE CONFIG VERSIONS" in output + assert result.exit_code == 0 + assert "Disabled" in result.stdout -class TestCmdClear: - """Tests for cmd_clear command.""" +class TestClearCommand: + """Tests for clear command.""" @patch("deriva.cli.cli.PipelineSession") - def test_clear_graph_success(self, mock_session_class, capsys): + def test_clear_graph_success(self, mock_session_class): """Should clear graph successfully.""" mock_session = MagicMock() mock_session.clear_graph.return_value = {"success": True, "message": "Graph cleared"} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(target="graph") - result = cmd_clear(args) + result = runner.invoke(app, ["clear", "graph"]) - assert result == 0 - output = capsys.readouterr().out - assert "Graph cleared" in output + assert result.exit_code == 0 + assert "Graph cleared" in result.stdout @patch("deriva.cli.cli.PipelineSession") - def test_clear_model_success(self, mock_session_class, capsys): + def test_clear_model_success(self, mock_session_class): """Should clear model successfully.""" mock_session = MagicMock() mock_session.clear_model.return_value = {"success": True, "message": "Model cleared"} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(target="model") - result = cmd_clear(args) - - assert result == 0 - - @patch("deriva.cli.cli.PipelineSession") - def test_clear_unknown_target(self, mock_session_class, capsys): - """Should return error for 
unknown target.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace(target="unknown") - result = cmd_clear(args) - - assert result == 1 - - @patch("deriva.cli.cli.PipelineSession") - def test_clear_failure(self, mock_session_class, capsys): - """Should return error on failure.""" - mock_session = MagicMock() - mock_session.clear_graph.return_value = {"success": False, "error": "Connection failed"} - mock_session_class.return_value.__enter__.return_value = mock_session + result = runner.invoke(app, ["clear", "model"]) - args = argparse.Namespace(target="graph") - result = cmd_clear(args) + assert result.exit_code == 0 - assert result == 1 + def test_clear_invalid_target(self): + """Should reject invalid target.""" + result = runner.invoke(app, ["clear", "invalid"]) + assert result.exit_code == 1 + assert "Error" in result.output -class TestCmdStatus: - """Tests for cmd_status command.""" +class TestStatusCommand: + """Tests for status command.""" @patch("deriva.cli.cli.PipelineSession") - def test_shows_status(self, mock_session_class, capsys): + def test_shows_status(self, mock_session_class): """Should show pipeline status.""" mock_session = MagicMock() mock_session.list_steps.return_value = [{"enabled": True}, {"enabled": False}] @@ -529,16 +367,14 @@ def test_shows_status(self, mock_session_class, capsys): mock_session.get_archimate_stats.return_value = {"total_elements": 50} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace() - result = cmd_status(args) + result = runner.invoke(app, ["status"]) - assert result == 0 - output = capsys.readouterr().out - assert "DERIVA STATUS" in output - assert "1/2 steps enabled" in output + assert result.exit_code == 0 + assert "DERIVA STATUS" in result.stdout + assert "1/2 steps enabled" in result.stdout @patch("deriva.cli.cli.PipelineSession") - def test_handles_graph_connection_error(self, mock_session_class, capsys): + def test_handles_graph_connection_error(self, mock_session_class): """Should handle graph connection error gracefully.""" mock_session = MagicMock() mock_session.list_steps.return_value = [] @@ -547,19 +383,17 @@ def test_handles_graph_connection_error(self, mock_session_class, capsys): mock_session.get_archimate_stats.side_effect = Exception("Not connected") mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace() - result = cmd_status(args) + result = runner.invoke(app, ["status"]) - assert result == 0 - output = capsys.readouterr().out - assert "(not connected)" in output + assert result.exit_code == 0 + assert "(not connected)" in result.stdout -class TestCmdExport: - """Tests for cmd_export command.""" +class TestExportCommand: + """Tests for export command.""" @patch("deriva.cli.cli.PipelineSession") - def test_export_success(self, mock_session_class, capsys): + def test_export_success(self, mock_session_class): """Should export model successfully.""" mock_session = MagicMock() mock_session.export_model.return_value = { @@ -570,31 +404,28 @@ def test_export_success(self, mock_session_class, capsys): } mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(output="out.xml", name="MyModel", verbose=False) - result = cmd_export(args) + result = runner.invoke(app, ["export", "-o", "out.xml", "-n", "MyModel"]) - assert result == 0 - output = capsys.readouterr().out - assert "Elements exported: 50" in output + assert 
result.exit_code == 0 + assert "Elements exported: 50" in result.stdout @patch("deriva.cli.cli.PipelineSession") - def test_export_failure(self, mock_session_class, capsys): + def test_export_failure(self, mock_session_class): """Should return error on export failure.""" mock_session = MagicMock() - mock_session.export_model.return_value = {"success": False, "error": "No elements to export"} + mock_session.export_model.return_value = {"success": False, "error": "No elements"} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(output="out.xml", name="MyModel", verbose=False) - result = cmd_export(args) + result = runner.invoke(app, ["export"]) - assert result == 1 + assert result.exit_code == 1 -class TestCmdRepoClone: - """Tests for cmd_repo_clone command.""" +class TestRepoCommands: + """Tests for repo commands.""" - @patch("deriva.cli.cli.PipelineSession") - def test_clone_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.repo.PipelineSession") + def test_repo_clone_success(self, mock_session_class): """Should clone repository successfully.""" mock_session = MagicMock() mock_session.clone_repository.return_value = { @@ -605,41 +436,13 @@ def test_clone_success(self, mock_session_class, capsys): } mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace( - url="https://github.com/user/repo", - name=None, - branch=None, - overwrite=False, - ) - result = cmd_repo_clone(args) - - assert result == 0 - output = capsys.readouterr().out - assert "cloned successfully" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_clone_failure(self, mock_session_class, capsys): - """Should return error on clone failure.""" - mock_session = MagicMock() - mock_session.clone_repository.return_value = {"success": False, "error": "Repository not found"} - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace( - url="https://github.com/user/nonexistent", - name=None, - branch=None, - overwrite=False, - ) - result = cmd_repo_clone(args) - - assert result == 1 - + result = runner.invoke(app, ["repo", "clone", "https://github.com/user/repo"]) -class TestCmdRepoList: - """Tests for cmd_repo_list command.""" + assert result.exit_code == 0 + assert "cloned successfully" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_lists_repositories(self, mock_session_class, capsys): + @patch("deriva.cli.commands.repo.PipelineSession") + def test_repo_list(self, mock_session_class): """Should list repositories.""" mock_session = MagicMock() mock_session.workspace_dir = "/workspace" @@ -649,140 +452,162 @@ def test_lists_repositories(self, mock_session_class, capsys): ] mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(detailed=False) - result = cmd_repo_list(args) + result = runner.invoke(app, ["repo", "list"]) - assert result == 0 - output = capsys.readouterr().out - assert "REPOSITORIES" in output - assert "repo1" in output - assert "Total: 2 repositories" in output + assert result.exit_code == 0 + assert "REPOSITORIES" in result.stdout + assert "repo1" in result.stdout + assert "Total: 2 repositories" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_lists_detailed_repositories(self, mock_session_class, capsys): - """Should list repositories with details.""" + @patch("deriva.cli.commands.repo.PipelineSession") + def test_repo_delete_success(self, mock_session_class): + 
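# --- Illustrative sketch (editor's aside, not part of this patch) ---------------
# Why every test wires `mock_cls.return_value.__enter__.return_value`: the
# commands open the session as a context manager, so the patch must target the
# module where the name is looked up (e.g. deriva.cli.commands.repo) and the
# `as session` target comes from __enter__. `PipelineSession` here is a local
# stand-in and `list_repo_names` is a hypothetical command body.
from unittest.mock import MagicMock, patch


class PipelineSession:
    """Minimal stand-in; the real session talks to the graph and model stores."""

    def __enter__(self) -> "PipelineSession":
        return self

    def __exit__(self, *exc) -> None:
        return None

    def get_repositories(self) -> list[dict]:
        raise RuntimeError("not reached in the test; replaced by the mock")


def list_repo_names() -> list[str]:
    with PipelineSession() as session:
        return [repo["name"] for repo in session.get_repositories()]


@patch(f"{__name__}.PipelineSession")  # patch at the lookup site, not the definition site
def test_list_repo_names(mock_session_class: MagicMock) -> None:
    mock_session = MagicMock()
    mock_session.get_repositories.return_value = [{"name": "repo1"}]
    # `with PipelineSession() as session:` binds __enter__'s return value,
    # so that is where the configured double must be attached.
    mock_session_class.return_value.__enter__.return_value = mock_session

    assert list_repo_names() == ["repo1"]
# --------------------------------------------------------------------------------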
"""Should delete repository successfully.""" mock_session = MagicMock() - mock_session.workspace_dir = "/workspace" - mock_session.get_repositories.return_value = [ - { - "name": "repo1", - "url": "https://github.com/user/repo1", - "branch": "main", - "size_mb": 10.5, - "cloned_at": "2024-01-01", - "is_dirty": False, - }, - ] + mock_session.delete_repository.return_value = {"success": True} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(detailed=True) - result = cmd_repo_list(args) + result = runner.invoke(app, ["repo", "delete", "my_repo"]) - assert result == 0 - output = capsys.readouterr().out - assert "URL:" in output - assert "Branch:" in output + assert result.exit_code == 0 + assert "deleted successfully" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_shows_message_when_no_repos(self, mock_session_class, capsys): - """Should show message when no repositories.""" + @patch("deriva.cli.commands.repo.PipelineSession") + def test_repo_info(self, mock_session_class): + """Should show repository info.""" mock_session = MagicMock() - mock_session.workspace_dir = "/workspace" - mock_session.get_repositories.return_value = [] + mock_session.get_repository_info.return_value = { + "name": "my_repo", + "path": "/workspace/repos/my_repo", + "url": "https://github.com/user/repo", + "branch": "main", + "last_commit": "abc123", + "is_dirty": False, + "size_mb": 15.5, + "cloned_at": "2024-01-01", + } mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(detailed=False) - result = cmd_repo_list(args) + result = runner.invoke(app, ["repo", "info", "my_repo"]) - assert result == 0 - output = capsys.readouterr().out - assert "No repositories found" in output + assert result.exit_code == 0 + assert "REPOSITORY: my_repo" in result.stdout -class TestCmdRepoDelete: - """Tests for cmd_repo_delete command.""" +class TestRunCommand: + """Tests for run command.""" + @patch("deriva.cli.cli.create_progress_reporter") @patch("deriva.cli.cli.PipelineSession") - def test_delete_success(self, mock_session_class, capsys): - """Should delete repository successfully.""" + def test_run_extraction(self, mock_session_class, mock_progress): + """Should run extraction stage.""" mock_session = MagicMock() - mock_session.delete_repository.return_value = {"success": True} + mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} + mock_session.run_extraction.return_value = { + "success": True, + "stats": {"nodes_created": 100, "edges_created": 50}, + } mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(name="my_repo", force=False) - result = cmd_repo_delete(args) + mock_reporter = MagicMock() + mock_progress.return_value = mock_reporter + mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) + mock_reporter.__exit__ = MagicMock(return_value=False) - assert result == 0 - output = capsys.readouterr().out - assert "deleted successfully" in output + result = runner.invoke(app, ["run", "extraction"]) + + assert result.exit_code == 0 + assert "EXTRACTION" in result.stdout @patch("deriva.cli.cli.PipelineSession") - def test_delete_failure(self, mock_session_class, capsys): - """Should return error on delete failure.""" + def test_run_derivation_without_llm(self, mock_session_class): + """Should return error when running derivation without LLM.""" mock_session = MagicMock() - mock_session.delete_repository.return_value = {"success": False, "error": "Not 
found"} + mock_session.llm_info = None mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(name="nonexistent", force=False) - result = cmd_repo_delete(args) + result = runner.invoke(app, ["run", "derivation"]) - assert result == 1 + assert result.exit_code == 1 + assert "Error" in result.output - @patch("deriva.cli.cli.PipelineSession") - def test_delete_exception(self, mock_session_class, capsys): - """Should handle exception during delete.""" + def test_run_invalid_stage(self): + """Should reject invalid stage.""" + result = runner.invoke(app, ["run", "invalid"]) + assert result.exit_code == 1 + assert "Error" in result.output + + def test_run_invalid_phase_for_extraction(self): + """Should reject invalid phase for extraction.""" + result = runner.invoke(app, ["run", "extraction", "--phase", "generate"]) + assert result.exit_code == 1 + assert "not valid for extraction" in result.output + + def test_run_invalid_phase_for_derivation(self): + """Should reject invalid phase for derivation.""" + result = runner.invoke(app, ["run", "derivation", "--phase", "classify"]) + assert result.exit_code == 1 + assert "not valid for derivation" in result.output + + +class TestBenchmarkCommands: + """Tests for benchmark commands.""" + + @patch("deriva.cli.commands.benchmark.create_benchmark_progress_reporter") + @patch("deriva.cli.commands.benchmark.PipelineSession") + def test_benchmark_run_success(self, mock_session_class, mock_progress): + """Should run benchmark successfully.""" mock_session = MagicMock() - mock_session.delete_repository.side_effect = Exception("uncommitted changes detected") + mock_result = MagicMock() + mock_result.session_id = "bench_123" + mock_result.runs_completed = 3 + mock_result.runs_failed = 0 + mock_result.duration_seconds = 120.5 + mock_result.ocel_path = "workspace/benchmarks/bench_123/ocel.json" + mock_result.success = True + mock_result.errors = [] + mock_session.run_benchmark.return_value = mock_result mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(name="dirty_repo", force=False) - result = cmd_repo_delete(args) - - assert result == 1 - output = capsys.readouterr().out - assert "--force" in output + mock_reporter = MagicMock() + mock_progress.return_value = mock_reporter + mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) + mock_reporter.__exit__ = MagicMock(return_value=False) + result = runner.invoke(app, ["benchmark", "run", "--repos", "repo1", "--models", "gpt4"]) -class TestCmdRepoInfo: - """Tests for cmd_repo_info command.""" + assert result.exit_code == 0 + assert "BENCHMARK COMPLETE" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_shows_repo_info(self, mock_session_class, capsys): - """Should show repository info.""" + @patch("deriva.cli.commands.benchmark.PipelineSession") + def test_benchmark_list(self, mock_session_class): + """Should list benchmark sessions.""" mock_session = MagicMock() - mock_session.get_repository_info.return_value = { - "name": "my_repo", - "path": "/workspace/repos/my_repo", - "url": "https://github.com/user/repo", - "branch": "main", - "last_commit": "abc123", - "is_dirty": False, - "size_mb": 15.5, - "cloned_at": "2024-01-01", - } + mock_session.list_benchmarks.return_value = [ + { + "session_id": "bench_001", + "status": "completed", + "started_at": "2024-01-01T10:00:00", + "description": "Test run", + }, + ] mock_session_class.return_value.__enter__.return_value = mock_session - args = 
argparse.Namespace(name="my_repo") - result = cmd_repo_info(args) + result = runner.invoke(app, ["benchmark", "list"]) - assert result == 0 - output = capsys.readouterr().out - assert "REPOSITORY: my_repo" in output - assert "Path:" in output + assert result.exit_code == 0 + assert "bench_001" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_repo_not_found(self, mock_session_class, capsys): - """Should return error when repo not found.""" + @patch("deriva.cli.commands.benchmark.PipelineSession") + def test_benchmark_models_empty(self, mock_session_class): + """Should show message when no models configured.""" mock_session = MagicMock() - mock_session.get_repository_info.return_value = None + mock_session.list_benchmark_models.return_value = {} mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(name="nonexistent") - result = cmd_repo_info(args) + result = runner.invoke(app, ["benchmark", "models"]) - assert result == 1 - output = capsys.readouterr().out - assert "not found" in output + assert result.exit_code == 0 + assert "No benchmark model" in result.stdout class TestGetRunStatsFromOcel: @@ -792,7 +617,6 @@ def test_extracts_stats_from_complete_run_events(self): """Should extract node/edge counts from CompleteRun events.""" mock_analyzer = MagicMock() - # Create mock events event1 = MagicMock() event1.activity = "CompleteRun" event1.objects = {"Model": ["gpt4"]} @@ -845,273 +669,11 @@ def test_skips_events_without_model(self): assert result == {} -class TestCreateParserAdditional: - """Additional tests for argument parser creation.""" - - def test_status_command(self): - """Should parse status command.""" - parser = create_parser() - args = parser.parse_args(["status"]) - assert args.command == "status" - - def test_config_enable_command(self): - """Should parse config enable command.""" - parser = create_parser() - args = parser.parse_args(["config", "enable", "extraction", "BusinessConcept"]) - assert args.config_action == "enable" - assert args.step_type == "extraction" - assert args.name == "BusinessConcept" - - def test_config_disable_command(self): - """Should parse config disable command.""" - parser = create_parser() - args = parser.parse_args(["config", "disable", "derivation", "ApplicationComponent"]) - assert args.config_action == "disable" - assert args.name == "ApplicationComponent" - - def test_config_show_command(self): - """Should parse config show command.""" - parser = create_parser() - args = parser.parse_args(["config", "show", "extraction", "TypeDefinition"]) - assert args.config_action == "show" - assert args.name == "TypeDefinition" - - def test_config_update_command(self): - """Should parse config update command.""" - parser = create_parser() - args = parser.parse_args( - [ - "config", - "update", - "derivation", - "ApplicationComponent", - "-i", - "New instruction", - ] - ) - assert args.config_action == "update" - assert args.instruction == "New instruction" - - def test_config_versions_command(self): - """Should parse config versions command.""" - parser = create_parser() - args = parser.parse_args(["config", "versions"]) - assert args.config_action == "versions" - - def test_repo_list_detailed(self): - """Should parse repo list with detailed flag.""" - parser = create_parser() - args = parser.parse_args(["repo", "list", "-d"]) - assert args.detailed is True - - def test_repo_delete_force(self): - """Should parse repo delete with force flag.""" - parser = create_parser() - args = 
parser.parse_args(["repo", "delete", "my_repo", "-f"]) - assert args.force is True - - def test_benchmark_list_command(self): - """Should parse benchmark list command.""" - parser = create_parser() - args = parser.parse_args(["benchmark", "list", "-l", "20"]) - assert args.benchmark_action == "list" - assert args.limit == 20 - - def test_benchmark_analyze_command(self): - """Should parse benchmark analyze command.""" - parser = create_parser() - args = parser.parse_args(["benchmark", "analyze", "session_123", "-f", "markdown"]) - assert args.benchmark_action == "analyze" - assert args.session_id == "session_123" - assert args.format == "markdown" - - def test_benchmark_models_command(self): - """Should parse benchmark models command.""" - parser = create_parser() - args = parser.parse_args(["benchmark", "models"]) - assert args.benchmark_action == "models" - - def test_benchmark_deviations_command(self): - """Should parse benchmark deviations command.""" - parser = create_parser() - args = parser.parse_args( - [ - "benchmark", - "deviations", - "session_123", - "-s", - "consistency_score", - ] - ) - assert args.benchmark_action == "deviations" - assert args.sort_by == "consistency_score" - - def test_run_with_quiet_flag(self): - """Should parse run with quiet flag.""" - parser = create_parser() - args = parser.parse_args(["run", "extraction", "-q"]) - assert args.quiet is True - - def test_run_with_no_llm_flag(self): - """Should parse run with no-llm flag.""" - parser = create_parser() - args = parser.parse_args(["run", "extraction", "--no-llm"]) - assert args.no_llm is True - - def test_benchmark_run_with_all_options(self): - """Should parse benchmark run with all options.""" - parser = create_parser() - args = parser.parse_args( - [ - "benchmark", - "run", - "--repos", - "repo1,repo2", - "--models", - "gpt4,claude", - "-n", - "5", - "--stages", - "extraction,derivation", - "-d", - "Test benchmark", - "--no-cache", - "--nocache-configs", - "ApplicationComponent,DataObject", - ] - ) - assert args.runs == 5 - assert args.stages == "extraction,derivation" - assert args.description == "Test benchmark" - assert args.no_cache is True - assert args.nocache_configs == "ApplicationComponent,DataObject" - - -# Additional CLI imports for new tests -from deriva.cli.cli import ( - cmd_config_update, - cmd_filetype_add, - cmd_filetype_delete, - cmd_filetype_list, - cmd_filetype_stats, - cmd_run, -) - - -class TestCmdConfigUpdate: - """Tests for cmd_config_update command.""" - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_update_derivation_config_success(self, mock_session_class, mock_config, capsys): - """Should update derivation config successfully.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.create_derivation_config_version.return_value = { - "success": True, - "old_version": 1, - "new_version": 2, - } - - args = argparse.Namespace( - step_type="derivation", - name="ApplicationComponent", - instruction="New instruction", - example=None, - instruction_file=None, - example_file=None, - query=None, - sources=None, - ) - result = cmd_config_update(args) - - assert result == 0 - output = capsys.readouterr().out - assert "Updated" in output - assert "Version: 1 -> 2" in output - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_update_extraction_config_success(self, mock_session_class, mock_config, capsys): - """Should update extraction config 
successfully.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.create_extraction_config_version.return_value = { - "success": True, - "old_version": 1, - "new_version": 2, - } - - args = argparse.Namespace( - step_type="extraction", - name="BusinessConcept", - instruction="New instruction", - example=None, - instruction_file=None, - example_file=None, - query=None, - sources="*.py", - ) - result = cmd_config_update(args) - - assert result == 0 - - @patch("deriva.cli.cli.config") - @patch("deriva.cli.cli.PipelineSession") - def test_update_config_failure(self, mock_session_class, mock_config, capsys): - """Should return error on update failure.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_config.create_derivation_config_version.return_value = { - "success": False, - "error": "Config not found", - } - - args = argparse.Namespace( - step_type="derivation", - name="Unknown", - instruction="New instruction", - example=None, - instruction_file=None, - example_file=None, - query=None, - sources=None, - ) - result = cmd_config_update(args) - - assert result == 1 - output = capsys.readouterr().out - assert "Error" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_update_unknown_step_type(self, mock_session_class, capsys): - """Should return error for unknown step type.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace( - step_type="unknown", - name="test", - instruction="New instruction", - example=None, - instruction_file=None, - example_file=None, - query=None, - sources=None, - ) - result = cmd_config_update(args) - - assert result == 1 - +class TestFiletypeCommands: + """Tests for filetype commands.""" -class TestCmdFiletype: - """Tests for file type commands.""" - - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_list_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_filetype_list_success(self, mock_session_class): """Should list file types.""" mock_session = MagicMock() mock_session.get_file_types.return_value = [ @@ -1120,654 +682,47 @@ def test_filetype_list_success(self, mock_session_class, capsys): ] mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace() - result = cmd_filetype_list(args) - - assert result == 0 - output = capsys.readouterr().out - assert "FILE TYPE REGISTRY" in output - assert ".py" in output - assert "python" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_list_empty(self, mock_session_class, capsys): - """Should show message when no file types.""" - mock_session = MagicMock() - mock_session.get_file_types.return_value = [] - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace() - result = cmd_filetype_list(args) + result = runner.invoke(app, ["config", "filetype", "list"]) - assert result == 0 - output = capsys.readouterr().out - assert "No file types registered" in output + assert result.exit_code == 0 + assert "FILE TYPE REGISTRY" in result.stdout + assert ".py" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_add_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_filetype_add_success(self, mock_session_class): """Should add file type successfully.""" 
mock_session = MagicMock() mock_session.add_file_type.return_value = True mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace( - extension=".rs", - file_type="code", - subtype="rust", - ) - result = cmd_filetype_add(args) - - assert result == 0 - output = capsys.readouterr().out - assert "Added file type" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_add_failure(self, mock_session_class, capsys): - """Should return error when add fails.""" - mock_session = MagicMock() - mock_session.add_file_type.return_value = False - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace( - extension=".py", - file_type="code", - subtype="python", - ) - result = cmd_filetype_add(args) + result = runner.invoke(app, ["config", "filetype", "add", ".rs", "code", "rust"]) - assert result == 1 - output = capsys.readouterr().out - assert "Failed to add" in output + assert result.exit_code == 0 + assert "Added file type" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_delete_success(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_filetype_delete_success(self, mock_session_class): """Should delete file type successfully.""" mock_session = MagicMock() mock_session.delete_file_type.return_value = True mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace(extension=".rs") - result = cmd_filetype_delete(args) + result = runner.invoke(app, ["config", "filetype", "delete", ".rs"]) - assert result == 0 - output = capsys.readouterr().out - assert "Deleted file type" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_delete_not_found(self, mock_session_class, capsys): - """Should return error when file type not found.""" - mock_session = MagicMock() - mock_session.delete_file_type.return_value = False - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace(extension=".xyz") - result = cmd_filetype_delete(args) - - assert result == 1 - output = capsys.readouterr().out - assert "not found" in output + assert result.exit_code == 0 + assert "Deleted file type" in result.stdout - @patch("deriva.cli.cli.PipelineSession") - def test_filetype_stats(self, mock_session_class, capsys): + @patch("deriva.cli.commands.config.PipelineSession") + def test_filetype_stats(self, mock_session_class): """Should show file type statistics.""" mock_session = MagicMock() mock_session.get_file_type_stats.return_value = { "code": 50, "config": 10, - "docs": 5, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace() - result = cmd_filetype_stats(args) - - assert result == 0 - output = capsys.readouterr().out - assert "FILE TYPE STATISTICS" in output - assert "code" in output - assert "Total" in output - - -class TestCmdRun: - """Tests for cmd_run command.""" - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_extraction(self, mock_session_class, mock_progress, capsys): - """Should run extraction stage.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session.run_extraction.return_value = { - "success": True, - "stats": {"nodes_created": 100, "edges_created": 50}, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - 
mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - stage="extraction", - repo=None, - verbose=False, - no_llm=False, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 0 - output = capsys.readouterr().out - assert "EXTRACTION" in output - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_derivation(self, mock_session_class, mock_progress, capsys): - """Should run derivation stage.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session.run_derivation.return_value = { - "success": True, - "stats": {"elements_created": 10, "relationships_created": 5}, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - stage="derivation", - repo=None, - verbose=False, - no_llm=False, - phase="prep", - quiet=False, - ) - result = cmd_run(args) - - assert result == 0 - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_derivation_without_llm(self, mock_session_class, mock_progress, capsys): - """Should return error when running derivation without LLM.""" - mock_session = MagicMock() - mock_session.llm_info = None # No LLM configured - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace( - stage="derivation", - repo=None, - verbose=False, - no_llm=False, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 1 - output = capsys.readouterr().out - assert "Error" in output - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_all_stages(self, mock_session_class, mock_progress, capsys): - """Should run all pipeline stages.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session.run_pipeline.return_value = { - "success": True, - "results": { - "extraction": {"stats": {"nodes_created": 100}}, - "derivation": {"stats": {"elements_created": 10}}, - }, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - stage="all", - repo=None, - verbose=False, - no_llm=False, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 0 - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_unknown_stage(self, mock_session_class, mock_progress, capsys): - """Should return error for unknown stage.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace( - stage="unknown", - repo=None, - verbose=False, - no_llm=False, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 1 - output = capsys.readouterr().out - assert "Unknown stage" in output - - 
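# --- Illustrative sketch (editor's aside, not part of this patch) ---------------
# The capsys pattern the remaining _print_* helper tests still use: pytest's
# built-in capsys fixture captures stdout/stderr written during the test, and
# readouterr() drains whatever has been captured so far. `render_summary` is a
# hypothetical helper, not part of the Deriva CLI.
def render_summary(stats: dict[str, int]) -> None:
    print("SUMMARY")
    for key, value in sorted(stats.items()):
        print(f"  {key}: {value}")


def test_render_summary(capsys) -> None:
    render_summary({"edges_created": 50, "nodes_created": 100})

    captured = capsys.readouterr()  # namedtuple with .out and .err
    assert "SUMMARY" in captured.out
    assert "nodes_created: 100" in captured.out
    assert captured.err == ""
# --------------------------------------------------------------------------------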
@patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_with_repo_name(self, mock_session_class, mock_progress, capsys): - """Should run extraction with specific repository.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session.run_extraction.return_value = { - "success": True, - "stats": {}, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - stage="extraction", - repo="my_repo", - verbose=False, - no_llm=False, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 0 - output = capsys.readouterr().out - assert "Repository: my_repo" in output - - @patch("deriva.cli.cli.create_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_run_with_no_llm_flag(self, mock_session_class, mock_progress, capsys): - """Should show LLM disabled message with --no-llm flag.""" - mock_session = MagicMock() - mock_session.llm_info = {"provider": "openai", "model": "gpt-4"} - mock_session.run_extraction.return_value = { - "success": True, - "stats": {}, - } - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - stage="extraction", - repo=None, - verbose=False, - no_llm=True, - phase=None, - quiet=False, - ) - result = cmd_run(args) - - assert result == 0 - output = capsys.readouterr().out - assert "LLM disabled" in output - - -# ============================================================================= -# Benchmark Commands -# ============================================================================= - -from deriva.cli.cli import ( - cmd_benchmark_analyze, - cmd_benchmark_deviations, - cmd_benchmark_list, - cmd_benchmark_models, - cmd_benchmark_run, -) - - -class TestCmdBenchmarkRun: - """Tests for cmd_benchmark_run command.""" - - @patch("deriva.cli.cli.create_benchmark_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_run_success(self, mock_session_class, mock_progress, capsys): - """Should run benchmark successfully.""" - mock_session = MagicMock() - mock_result = MagicMock() - mock_result.session_id = "bench_123" - mock_result.runs_completed = 3 - mock_result.runs_failed = 0 - mock_result.duration_seconds = 120.5 - mock_result.ocel_path = "workspace/benchmarks/bench_123/ocel.json" - mock_result.success = True - mock_result.errors = [] - mock_session.run_benchmark.return_value = mock_result - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - repos="repo1,repo2", - models="gpt4,claude", - runs=3, - stages=None, - description="Test benchmark", - verbose=False, - quiet=False, - no_cache=False, - no_export_models=False, - no_clear=False, - bench_hash=False, - defer_relationships=False, - per_repo=False, - nocache_configs=None, - ) - result = cmd_benchmark_run(args) - - 
assert result == 0 - output = capsys.readouterr().out - assert "BENCHMARK COMPLETE" in output - assert "bench_123" in output - mock_session.run_benchmark.assert_called_once() - - @patch("deriva.cli.cli.create_benchmark_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_run_with_errors(self, mock_session_class, mock_progress, capsys): - """Should return error code when benchmark has failures.""" - mock_session = MagicMock() - mock_result = MagicMock() - mock_result.session_id = "bench_fail" - mock_result.runs_completed = 2 - mock_result.runs_failed = 1 - mock_result.duration_seconds = 60.0 - mock_result.ocel_path = "workspace/benchmarks/bench_fail/ocel.json" - mock_result.success = False - mock_result.errors = ["Run 3 failed: timeout"] - mock_session.run_benchmark.return_value = mock_result - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - repos="repo1", - models="gpt4", - runs=3, - stages=None, - description="", - verbose=False, - quiet=False, - no_cache=False, - no_export_models=False, - no_clear=False, - bench_hash=False, - defer_relationships=False, - per_repo=False, - nocache_configs=None, - ) - result = cmd_benchmark_run(args) - - assert result == 1 - output = capsys.readouterr().out - assert "Errors" in output - assert "timeout" in output - - @patch("deriva.cli.cli.create_benchmark_progress_reporter") - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_run_per_repo_mode(self, mock_session_class, mock_progress, capsys): - """Should show per-repo mode in output.""" - mock_session = MagicMock() - mock_result = MagicMock() - mock_result.session_id = "bench_per_repo" - mock_result.runs_completed = 6 - mock_result.runs_failed = 0 - mock_result.duration_seconds = 200.0 - mock_result.ocel_path = "path/ocel.json" - mock_result.success = True - mock_result.errors = [] - mock_session.run_benchmark.return_value = mock_result - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_reporter = MagicMock() - mock_progress.return_value = mock_reporter - mock_reporter.__enter__ = MagicMock(return_value=mock_reporter) - mock_reporter.__exit__ = MagicMock(return_value=False) - - args = argparse.Namespace( - repos="repo1,repo2", - models="gpt4", - runs=3, - stages=None, - description="", - verbose=False, - quiet=False, - no_cache=False, - no_export_models=False, - no_clear=False, - bench_hash=False, - defer_relationships=False, - per_repo=True, - nocache_configs=None, - ) - result = cmd_benchmark_run(args) - - assert result == 0 - output = capsys.readouterr().out - assert "per-repo" in output - assert "Total runs: 6" in output - - -class TestCmdBenchmarkList: - """Tests for cmd_benchmark_list command.""" - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_list_shows_sessions(self, mock_session_class, capsys): - """Should list benchmark sessions.""" - mock_session = MagicMock() - mock_session.list_benchmarks.return_value = [ - { - "session_id": "bench_001", - "status": "completed", - "started_at": "2024-01-01T10:00:00", - "description": "Test run", - }, - { - "session_id": "bench_002", - "status": "failed", - "started_at": "2024-01-02T10:00:00", - "description": "", - }, - ] - mock_session_class.return_value.__enter__.return_value = mock_session - - args = 
argparse.Namespace(limit=10) - result = cmd_benchmark_list(args) - - assert result == 0 - output = capsys.readouterr().out - assert "BENCHMARK SESSIONS" in output - assert "bench_001" in output - assert "bench_002" in output - assert "completed" in output - assert "failed" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_list_empty(self, mock_session_class, capsys): - """Should show message when no sessions found.""" - mock_session = MagicMock() - mock_session.list_benchmarks.return_value = [] - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace(limit=10) - result = cmd_benchmark_list(args) - - assert result == 0 - output = capsys.readouterr().out - assert "No benchmark sessions found" in output - - -class TestCmdBenchmarkModels: - """Tests for cmd_benchmark_models command.""" - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_models_lists_configs(self, mock_session_class, capsys): - """Should list available model configurations.""" - from deriva.adapters.llm.models import BenchmarkModelConfig - - mock_session = MagicMock() - mock_session.list_benchmark_models.return_value = { - "openai-gpt4": BenchmarkModelConfig( - name="openai-gpt4", - provider="openai", - model="gpt-4", - api_key_env="OPENAI_API_KEY", - ), - "anthropic-claude": BenchmarkModelConfig( - name="anthropic-claude", - provider="anthropic", - model="claude-3-sonnet", - api_key_env="ANTHROPIC_API_KEY", - ), } mock_session_class.return_value.__enter__.return_value = mock_session - args = argparse.Namespace() - result = cmd_benchmark_models(args) + result = runner.invoke(app, ["config", "filetype", "stats"]) - assert result == 0 - output = capsys.readouterr().out - assert "openai-gpt4" in output - assert "anthropic-claude" in output - assert "gpt-4" in output or "openai" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_models_empty(self, mock_session_class, capsys): - """Should show message when no models configured.""" - mock_session = MagicMock() - mock_session.list_benchmark_models.return_value = {} - mock_session_class.return_value.__enter__.return_value = mock_session - - args = argparse.Namespace() - result = cmd_benchmark_models(args) - - assert result == 0 - output = capsys.readouterr().out - assert "No benchmark model" in output - - -class TestCmdBenchmarkAnalyze: - """Tests for cmd_benchmark_analyze command.""" - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_analyze_success(self, mock_session_class, capsys): - """Should analyze benchmark session.""" - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - mock_analyzer = MagicMock() - mock_analyzer.session_id = "bench_123" - mock_analyzer.compute_intra_model_consistency.return_value = {"gpt4": {"score": 0.95}} - mock_analyzer.compute_inter_model_consistency.return_value = {"overall_agreement": 0.85} - mock_analyzer.localize_inconsistencies.return_value = [] - mock_session.analyze_benchmark.return_value = mock_analyzer - - args = argparse.Namespace( - session_id="bench_123", - output=None, - format="json", - ) - result = cmd_benchmark_analyze(args) - - assert result == 0 - output = capsys.readouterr().out - assert "ANALYZING BENCHMARK" in output or "bench_123" in output - - @patch("deriva.cli.cli.PipelineSession") - def test_benchmark_analyze_not_found(self, mock_session_class, capsys): - """Should handle session not found.""" - mock_session = MagicMock() - 
mock_session_class.return_value.__enter__.return_value = mock_session
-        mock_session.analyze_benchmark.side_effect = ValueError("Session not found")
-
-        args = argparse.Namespace(
-            session_id="nonexistent",
-            output=None,
-            format="json",
-        )
-        result = cmd_benchmark_analyze(args)
-
-        assert result == 1
-        output = capsys.readouterr().out
-        assert "Error" in output
-
-
-class TestCmdBenchmarkDeviations:
-    """Tests for cmd_benchmark_deviations command."""
-
-    @patch("deriva.cli.cli.PipelineSession")
-    def test_benchmark_deviations_success(self, mock_session_class, capsys):
-        """Should analyze config deviations."""
-        mock_session = MagicMock()
-        mock_session_class.return_value.__enter__.return_value = mock_session
-
-        mock_report = MagicMock()
-        mock_report.total_runs = 3
-        mock_report.total_deviations = 5
-        mock_report.overall_consistency = 0.85
-        mock_report.deviations = []
-
-        mock_analyzer = MagicMock()
-        mock_analyzer.analyze.return_value = mock_report
-        mock_session.analyze_config_deviations.return_value = mock_analyzer
-
-        args = argparse.Namespace(
-            session_id="bench_123",
-            sort_by="deviation_count",
-            output=None,
-        )
-        result = cmd_benchmark_deviations(args)
-
-        assert result == 0
-        output = capsys.readouterr().out
-        assert "bench_123" in output
-
-    @patch("deriva.cli.cli.PipelineSession")
-    def test_benchmark_deviations_not_found(self, mock_session_class, capsys):
-        """Should handle session not found."""
-        mock_session = MagicMock()
-        mock_session_class.return_value.__enter__.return_value = mock_session
-        mock_session.analyze_config_deviations.side_effect = ValueError("Session not found")
-
-        args = argparse.Namespace(
-            session_id="nonexistent",
-            sort_by="deviation_count",
-            output=None,
-        )
-        result = cmd_benchmark_deviations(args)
-
-        assert result == 1
-        output = capsys.readouterr().out
-        assert "Error" in output
+        assert result.exit_code == 0
+        assert "FILE TYPE STATISTICS" in result.stdout
diff --git a/tests/test_common/test_cache_utils.py b/tests/test_common/test_cache_utils.py
new file mode 100644
index 0000000..effb06a
--- /dev/null
+++ b/tests/test_common/test_cache_utils.py
@@ -0,0 +1,326 @@
+"""Tests for common.cache_utils module."""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from deriva.common.cache_utils import BaseDiskCache, dict_to_hashable, hash_inputs
+from deriva.common.exceptions import CacheError
+
+
+class TestHashInputs:
+    """Tests for hash_inputs function."""
+
+    def test_returns_hex_string(self):
+        """Should return a 64-character hex string."""
+        result = hash_inputs("test")
+        assert isinstance(result, str)
+        assert len(result) == 64
+        assert all(c in "0123456789abcdef" for c in result)
+
+    def test_same_input_same_hash(self):
+        """Should produce consistent hashes for same input."""
+        hash1 = hash_inputs("test", "value")
+        hash2 = hash_inputs("test", "value")
+        assert hash1 == hash2
+
+    def test_different_input_different_hash(self):
+        """Should produce different hashes for different input."""
+        hash1 = hash_inputs("test", "value1")
+        hash2 = hash_inputs("test", "value2")
+        assert hash1 != hash2
+
+    def test_handles_dict_input(self):
+        """Should handle dict inputs with consistent ordering."""
+        hash1 = hash_inputs({"b": 2, "a": 1})
+        hash2 = hash_inputs({"a": 1, "b": 2})
+        assert hash1 == hash2
+
+    def test_handles_list_input(self):
+        """Should handle list inputs."""
+        hash1 = hash_inputs([1, 2, 3])
+        hash2 = hash_inputs([1, 2, 3])
+        assert hash1 == hash2
+
+    def test_handles_none_input(self):
+        """Should skip None values."""
+        hash1 = hash_inputs("test", None, "value")
+        hash2 = hash_inputs("test", "value")
+        assert hash1 == hash2
+
+    def test_custom_separator(self):
+        """Should use custom separator."""
+        hash1 = hash_inputs("a", "b", separator="|")
+        hash2 = hash_inputs("a", "b", separator="::")
+        assert hash1 != hash2
+
+    def test_handles_nested_dict(self):
+        """Should handle nested dictionaries."""
+        result = hash_inputs({"outer": {"inner": "value"}})
+        assert isinstance(result, str)
+        assert len(result) == 64
+
+
+class TestDictToHashable:
+    """Tests for dict_to_hashable function."""
+
+    def test_returns_tuple(self):
+        """Should return a tuple."""
+        result = dict_to_hashable({"a": 1})
+        assert isinstance(result, tuple)
+
+    def test_simple_dict(self):
+        """Should convert simple dict to tuple."""
+        result = dict_to_hashable({"a": 1, "b": 2})
+        assert result == (("a", 1), ("b", 2))
+
+    def test_nested_dict(self):
+        """Should handle nested dicts."""
+        result = dict_to_hashable({"a": {"b": 1}})
+        assert result == (("a", (("b", 1),)),)
+
+    def test_dict_with_list(self):
+        """Should convert lists to tuples."""
+        result = dict_to_hashable({"a": [1, 2, 3]})
+        assert result == (("a", (1, 2, 3)),)
+
+    def test_dict_with_list_of_dicts(self):
+        """Should handle lists containing dicts."""
+        result = dict_to_hashable({"a": [{"b": 1}]})
+        expected = (("a", ((("b", 1),),)),)
+        assert result == expected
+
+    def test_sorted_keys(self):
+        """Should sort dict keys for consistency."""
+        result1 = dict_to_hashable({"b": 2, "a": 1})
+        result2 = dict_to_hashable({"a": 1, "b": 2})
+        assert result1 == result2
+
+
+class TestBaseDiskCache:
+    """Tests for BaseDiskCache class."""
+
+    @pytest.fixture
+    def cache(self, tmp_path):
+        """Create a temporary cache for testing."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+        yield cache
+        cache.close()
+
+    def test_creates_cache_dir(self, tmp_path):
+        """Should create cache directory if not exists."""
+        cache_dir = tmp_path / "new_cache"
+        cache = BaseDiskCache(cache_dir)
+        assert cache_dir.exists()
+        cache.close()
+
+    def test_set_and_get(self, cache):
+        """Should store and retrieve data."""
+        cache.set("key1", {"value": "test"})
+        result = cache.get("key1")
+        assert result == {"value": "test"}
+
+    def test_get_nonexistent_returns_none(self, cache):
+        """Should return None for nonexistent key."""
+        result = cache.get("nonexistent")
+        assert result is None
+
+    def test_get_from_memory(self, cache):
+        """Should retrieve from memory (alias for get)."""
+        cache.set("key1", {"value": "test"})
+        result = cache.get_from_memory("key1")
+        assert result == {"value": "test"}
+
+    def test_get_from_disk(self, cache):
+        """Should retrieve from disk (alias for get)."""
+        cache.set("key1", {"value": "test"})
+        result = cache.get_from_disk("key1")
+        assert result == {"value": "test"}
+
+    def test_invalidate(self, cache):
+        """Should remove entry from cache."""
+        cache.set("key1", {"value": "test"})
+        cache.invalidate("key1")
+        assert cache.get("key1") is None
+
+    def test_clear_memory(self, cache):
+        """Should clear memory cache."""
+        cache.set("key1", {"value": "test"})
+        cache.clear_memory()
+        # Entry should still be retrievable (cull is optimization)
+        result = cache.get("key1")
+        assert result == {"value": "test"}
+
+    def test_clear_disk(self, cache):
+        """Should clear all cache entries."""
+        cache.set("key1", {"value": "test1"})
+        cache.set("key2", {"value": "test2"})
+        cache.clear_disk()
+        assert cache.get("key1") is None
+        assert cache.get("key2") is None
+
+    def test_clear_all(self, cache):
+        """Should clear entire cache."""
+        cache.set("key1", {"value": "test"})
+        cache.clear_all()
+        assert cache.get("key1") is None
+
+    def test_get_stats(self, cache):
+        """Should return cache statistics."""
+        cache.set("key1", {"value": "test"})
+        stats = cache.get_stats()
+
+        assert "entries" in stats
+        assert "size_bytes" in stats
+        assert "size_mb" in stats
+        assert "cache_dir" in stats
+        assert stats["entries"] >= 1
+
+    def test_keys(self, cache):
+        """Should return list of cache keys."""
+        cache.set("key1", {"value": "test1"})
+        cache.set("key2", {"value": "test2"})
+        keys = cache.keys()
+
+        assert "key1" in keys
+        assert "key2" in keys
+
+    def test_export_to_json(self, cache, tmp_path):
+        """Should export cache to JSON file."""
+        cache.set("key1", {"value": "test1"})
+        cache.set("key2", {"value": "test2"})
+
+        output_path = tmp_path / "export.json"
+        count = cache.export_to_json(output_path)
+
+        assert count == 2
+        assert output_path.exists()
+
+        with open(output_path) as f:
+            data = json.load(f)
+
+        assert data["entry_count"] == 2
+        assert len(data["entries"]) == 2
+
+    def test_export_to_json_without_values(self, cache, tmp_path):
+        """Should export only keys when include_values=False."""
+        cache.set("key1", {"value": "test1"})
+
+        output_path = tmp_path / "export.json"
+        cache.export_to_json(output_path, include_values=False)
+
+        with open(output_path) as f:
+            data = json.load(f)
+
+        assert "value" not in data["entries"][0]
+
+    def test_export_creates_parent_dirs(self, cache, tmp_path):
+        """Should create parent directories for export."""
+        output_path = tmp_path / "nested" / "dir" / "export.json"
+        cache.set("key1", {"value": "test"})
+        cache.export_to_json(output_path)
+        assert output_path.exists()
+
+    def test_context_manager(self, tmp_path):
+        """Should work as context manager."""
+        cache_dir = tmp_path / "ctx_cache"
+
+        with BaseDiskCache(cache_dir) as cache:
+            cache.set("key1", {"value": "test"})
+            result = cache.get("key1")
+            assert result == {"value": "test"}
+
+    def test_set_with_expire(self, cache):
+        """Should accept expire parameter."""
+        # Just verify it doesn't crash - TTL testing requires time manipulation
+        cache.set("key1", {"value": "test"}, expire=3600)
+        result = cache.get("key1")
+        assert result == {"value": "test"}
+
+    def test_custom_size_limit(self, tmp_path):
+        """Should accept custom size limit."""
+        cache = BaseDiskCache(tmp_path / "sized_cache", size_limit=1024 * 1024)
+        cache.set("key1", {"value": "test"})
+        cache.close()
+
+    def test_close_is_safe(self, cache):
+        """Should allow multiple close calls."""
+        cache.close()
+        cache.close()  # Should not raise
+
+
+class TestBaseDiskCacheErrors:
+    """Tests for error handling in BaseDiskCache."""
+
+    def test_get_raises_cache_error_on_corruption(self, tmp_path):
+        """Should raise CacheError when cache is corrupted."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+
+        # Mock internal cache to simulate error
+        from unittest.mock import patch
+
+        with patch.object(cache, "_cache") as mock_cache:
+            mock_cache.get.side_effect = Exception("Corrupted")
+            with pytest.raises(CacheError, match="Error reading from cache"):
+                cache.get("key1")
+
+        cache.close()
+
+    def test_set_raises_cache_error_on_write_failure(self, tmp_path):
+        """Should raise CacheError when write fails."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+
+        from unittest.mock import patch
+
+        with patch.object(cache, "_cache") as mock_cache:
+            mock_cache.set.side_effect = Exception("Write failed")
+            with pytest.raises(CacheError, match="Error writing to cache"):
+                cache.set("key1", {"value": "test"})
+
+        cache.close()
+
+    def test_invalidate_raises_cache_error_on_delete_failure(self, tmp_path):
+        """Should raise CacheError when delete fails."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+
+        from unittest.mock import patch
+
+        with patch.object(cache, "_cache") as mock_cache:
+            mock_cache.delete.side_effect = Exception("Delete failed")
+            with pytest.raises(CacheError, match="Error deleting cache entry"):
+                cache.invalidate("key1")
+
+        cache.close()
+
+    def test_clear_disk_raises_cache_error_on_failure(self, tmp_path):
+        """Should raise CacheError when clear fails."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+
+        from unittest.mock import patch
+
+        with patch.object(cache, "_cache") as mock_cache:
+            mock_cache.clear.side_effect = Exception("Clear failed")
+            with pytest.raises(CacheError, match="Error clearing cache"):
+                cache.clear_disk()
+
+        cache.close()
+
+    def test_get_stats_handles_volume_error(self, tmp_path):
+        """Should handle volume() errors gracefully."""
+        cache = BaseDiskCache(tmp_path / "test_cache")
+
+        from unittest.mock import MagicMock
+
+        original_cache = cache._cache
+        mock_cache = MagicMock(wraps=original_cache)
+        mock_cache.volume.side_effect = Exception("Volume error")
+        mock_cache.__len__ = lambda self: 0
+        cache._cache = mock_cache
+
+        stats = cache.get_stats()
+        assert stats["size_bytes"] == 0
+
+        cache._cache = original_cache
+        cache.close()
diff --git a/tests/test_common/test_document_reader.py b/tests/test_common/test_document_reader.py
index 83c3cce..c7ce3d9 100644
--- a/tests/test_common/test_document_reader.py
+++ b/tests/test_common/test_document_reader.py
@@ -3,29 +3,19 @@
 from __future__ import annotations
 
 import tempfile
-
-# Check if optional dependencies are available
-from importlib.util import find_spec
 from pathlib import Path
 from unittest.mock import MagicMock, patch
 
-import pytest
-
-HAS_DOCX = find_spec("docx") is not None
-HAS_PYPDF = find_spec("pypdf") is not None
-
 
 class TestLibraryAvailability:
     """Tests for library availability checks."""
 
-    @pytest.mark.skipif(not HAS_DOCX, reason="python-docx not installed")
     def test_is_docx_available_returns_true(self):
         """Should return True when python-docx is installed."""
         from deriva.common.document_reader import is_docx_available
 
         assert is_docx_available() is True
 
-    @pytest.mark.skipif(not HAS_PYPDF, reason="pypdf not installed")
     def test_is_pdf_available_returns_true(self):
         """Should return True when pypdf is installed."""
         from deriva.common.document_reader import is_pdf_available
@@ -94,7 +84,6 @@ def test_returns_none_on_read_error(self):
         result = read_docx(Path("nonexistent.docx"))
         assert result is None
 
-    @pytest.mark.skipif(not HAS_DOCX, reason="python-docx not installed")
     def test_extracts_paragraphs(self):
         """Should extract paragraph text."""
         import docx
diff --git a/tests/test_modules/analysis/test_fit_analysis.py b/tests/test_modules/analysis/test_fit_analysis.py
index 270dd96..eead5d4 100644
--- a/tests/test_modules/analysis/test_fit_analysis.py
+++ b/tests/test_modules/analysis/test_fit_analysis.py
@@ -236,11 +236,7 @@ def test_finds_similar_names(self):
         pairs = _find_similar_names(names, threshold=0.8)
 
         assert len(pairs) >= 1
-        assert any(
-            ("UserService" in p and "user_service" in p)
-            or ("user_service" in p and "UserService" in p)
-            for p in pairs
-        )
+        assert any(("UserService" in p and "user_service" in p) or ("user_service" in p and "UserService" in p) for p in pairs)
 
     def test_ignores_exact_duplicates(self):
         """Should skip exact case-insensitive duplicates."""
@@ -296,9 +292,7 @@ def test_good_fit_produces_good_fit_message(self):
         """Should produce 'GOOD FIT' message when well calibrated."""
         # Create a scenario with good fit (similar element counts, good precision/recall)
         derived = [{"type": "ApplicationComponent", "name": f"Svc{i}"} for i in range(5)]
-        reference = [
-            make_reference(f"ref{i}", f"Ref{i}", "ApplicationComponent") for i in range(5)
-        ]
+        reference = [make_reference(f"ref{i}", f"Ref{i}", "ApplicationComponent") for i in range(5)]
         semantic = make_semantic_report(precision=0.9, recall=0.9)
 
         result = create_fit_analysis(
diff --git a/tests/test_modules/analysis/test_stability_analysis.py b/tests/test_modules/analysis/test_stability_analysis.py
index bf830d8..39c9795 100644
--- a/tests/test_modules/analysis/test_stability_analysis.py
+++ b/tests/test_modules/analysis/test_stability_analysis.py
@@ -9,8 +9,6 @@
 
 from __future__ import annotations
 
-import pytest
-
 from deriva.modules.analysis.stability_analysis import (
     aggregate_stability_metrics,
     compute_phase_stability,
diff --git a/tests/test_modules/derivation/test_base.py b/tests/test_modules/derivation/test_base.py
index 1069cea..de2bdc6 100644
--- a/tests/test_modules/derivation/test_base.py
+++ b/tests/test_modules/derivation/test_base.py
@@ -1804,8 +1804,8 @@ def test_handles_response_without_content_attribute(self):
 
 
 class TestSharedGenerateBehavior:
    """Tests for shared generate() behavior across all element modules.
 
-    These tests verify common behavior using application_component as the
-    reference implementation. All element modules should behave identically
+    These tests verify common behavior using ApplicationComponentDerivation as the
+    reference implementation. All element derivation classes should behave identically
     for these scenarios.
""" @@ -1813,11 +1813,12 @@ def test_returns_empty_when_no_candidates(self): """generate() should return 0 elements when no candidates found.""" from unittest.mock import MagicMock, Mock, patch - from deriva.modules.derivation.application_component import generate + from deriva.modules.derivation.application_component import ApplicationComponentDerivation + derivation = ApplicationComponentDerivation() with patch("deriva.modules.derivation.element_base.get_enrichments_from_neo4j", return_value={}): with patch("deriva.modules.derivation.element_base.query_candidates", return_value=[]): - result = generate( + result = derivation.generate( graph_manager=MagicMock(), archimate_manager=MagicMock(), engine=MagicMock(), @@ -1837,11 +1838,12 @@ def test_handles_query_exception(self): """generate() should handle query exceptions gracefully.""" from unittest.mock import MagicMock, Mock, patch - from deriva.modules.derivation.application_component import generate + from deriva.modules.derivation.application_component import ApplicationComponentDerivation + derivation = ApplicationComponentDerivation() with patch("deriva.modules.derivation.element_base.get_enrichments_from_neo4j", return_value={}): with patch("deriva.modules.derivation.element_base.query_candidates", side_effect=Exception("DB error")): - result = generate( + result = derivation.generate( graph_manager=MagicMock(), archimate_manager=MagicMock(), engine=MagicMock(), @@ -1861,12 +1863,13 @@ def test_returns_generation_result_type(self): """generate() should return GenerationResult type.""" from unittest.mock import MagicMock, Mock, patch - from deriva.modules.derivation.application_component import generate + from deriva.modules.derivation.application_component import ApplicationComponentDerivation from deriva.modules.derivation.base import GenerationResult + derivation = ApplicationComponentDerivation() with patch("deriva.modules.derivation.element_base.get_enrichments_from_neo4j", return_value={}): with patch("deriva.modules.derivation.element_base.query_candidates", return_value=[]): - result = generate( + result = derivation.generate( graph_manager=MagicMock(), archimate_manager=MagicMock(), engine=MagicMock(), diff --git a/tests/test_modules/derivation/test_element_base.py b/tests/test_modules/derivation/test_element_base.py new file mode 100644 index 0000000..db3d3a2 --- /dev/null +++ b/tests/test_modules/derivation/test_element_base.py @@ -0,0 +1,444 @@ +"""Tests for modules.derivation.element_base module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from deriva.modules.derivation.base import Candidate, GenerationResult, RelationshipRule +from deriva.modules.derivation.element_base import ( + ElementDerivationBase, + PatternBasedDerivation, +) + + +class ConcreteDerivation(ElementDerivationBase): + """Concrete implementation for testing abstract base class.""" + + ELEMENT_TYPE = "TestElement" + OUTBOUND_RULES: list[RelationshipRule] = [] + INBOUND_RULES: list[RelationshipRule] = [] + + def filter_candidates(self, candidates, enrichments, max_candidates, **kwargs): + """Simple filter that returns first N candidates.""" + return candidates[:max_candidates] + + +class ConcretePatternDerivation(PatternBasedDerivation): + """Concrete implementation for testing PatternBasedDerivation.""" + + ELEMENT_TYPE = "TestPatternElement" + OUTBOUND_RULES: list[RelationshipRule] = [] + INBOUND_RULES: list[RelationshipRule] = [] + + def filter_candidates( + self, + candidates, + enrichments, + 
max_candidates, + include_patterns=None, + exclude_patterns=None, + **kwargs, + ): + """Filter using pattern matching.""" + include_patterns = include_patterns or set() + exclude_patterns = exclude_patterns or set() + + filtered = [] + for c in candidates: + if self.matches_patterns(c.name, include_patterns, exclude_patterns): + filtered.append(c) + return filtered[:max_candidates] + + +class TestElementDerivationBase: + """Tests for ElementDerivationBase abstract class.""" + + def test_init_creates_logger(self): + """Should create a logger on initialization.""" + derivation = ConcreteDerivation() + assert derivation.logger is not None + + def test_get_filter_kwargs_returns_empty_dict(self): + """Default get_filter_kwargs should return empty dict.""" + derivation = ConcreteDerivation() + result = derivation.get_filter_kwargs(MagicMock()) + assert result == {} + + def test_generate_returns_result_for_empty_candidates(self): + """Should return success result when no candidates found.""" + derivation = ConcreteDerivation() + + mock_graph = MagicMock() + mock_graph.query.return_value = [] + + result = derivation.generate( + graph_manager=mock_graph, + archimate_manager=MagicMock(), + engine=MagicMock(), + llm_query_fn=MagicMock(), + query="MATCH (n) RETURN n", + instruction="Test", + example="{}", + max_candidates=10, + batch_size=5, + existing_elements=[], + ) + + assert isinstance(result, GenerationResult) + assert result.success is True + assert result.elements_created == 0 + + def test_generate_handles_query_exception(self): + """Should return error result when query fails.""" + derivation = ConcreteDerivation() + + mock_graph = MagicMock() + # First call returns empty enrichments, second raises exception + mock_graph.query.side_effect = [[], Exception("Query failed")] + + result = derivation.generate( + graph_manager=mock_graph, + archimate_manager=MagicMock(), + engine=MagicMock(), + llm_query_fn=MagicMock(), + query="MATCH (n) RETURN n", + instruction="Test", + example="{}", + max_candidates=10, + batch_size=5, + existing_elements=[], + ) + + assert result.success is False + assert len(result.errors) > 0 + assert "Query failed" in result.errors[0] + + def test_generate_returns_empty_when_no_candidates_pass_filter(self): + """Should return success when all candidates are filtered out.""" + + class FilterAllDerivation(ConcreteDerivation): + def filter_candidates(self, candidates, enrichments, max_candidates, **kwargs): + return [] # Filter out everything + + derivation = FilterAllDerivation() + + mock_graph = MagicMock() + mock_graph.query.return_value = [{"id": "1", "name": "test", "labels": ["Node"], "properties": {}}] + + # Patch the helper functions to control behavior + with ( + patch("deriva.modules.derivation.element_base.get_enrichments_from_neo4j") as mock_enrichments, + patch("deriva.modules.derivation.element_base.query_candidates") as mock_candidates, + ): + mock_enrichments.return_value = {} + mock_candidates.return_value = [ + Candidate( + node_id="1", + name="test", + labels=["Node"], + properties={}, + ) + ] + + result = derivation.generate( + graph_manager=mock_graph, + archimate_manager=MagicMock(), + engine=MagicMock(), + llm_query_fn=MagicMock(), + query="MATCH (n) RETURN n", + instruction="Test", + example="{}", + max_candidates=10, + batch_size=5, + existing_elements=[], + ) + + assert result.success is True + assert result.elements_created == 0 + + +class TestPatternBasedDerivation: + """Tests for PatternBasedDerivation mixin class.""" + + def 
test_matches_patterns_returns_true_for_include_match(self): + """Should return True when name matches include pattern.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name="UserService", + include_patterns={"service", "manager"}, + exclude_patterns=set(), + ) + + assert result is True + + def test_matches_patterns_returns_false_for_exclude_match(self): + """Should return False when name matches exclude pattern.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name="TestService", + include_patterns={"service"}, + exclude_patterns={"test"}, + ) + + assert result is False + + def test_matches_patterns_is_case_insensitive(self): + """Should match patterns case-insensitively.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name="USERSERVICE", + include_patterns={"service"}, + exclude_patterns=set(), + ) + + assert result is True + + def test_matches_patterns_returns_default_when_no_match(self): + """Should return PATTERN_MATCH_DEFAULT when no patterns match.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name="RandomName", + include_patterns={"service"}, + exclude_patterns=set(), + ) + + assert result is False # PATTERN_MATCH_DEFAULT is False + + def test_matches_patterns_returns_false_for_empty_name(self): + """Should return False for empty name.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name="", + include_patterns={"service"}, + exclude_patterns=set(), + ) + + assert result is False + + def test_matches_patterns_returns_false_for_none_name(self): + """Should return False for None name.""" + derivation = ConcretePatternDerivation() + + result = derivation.matches_patterns( + name=None, # type: ignore[arg-type] # Testing None handling + include_patterns={"service"}, + exclude_patterns=set(), + ) + + assert result is False + + def test_get_filter_kwargs_loads_patterns(self): + """Should load patterns from config.""" + derivation = ConcretePatternDerivation() + + with patch("deriva.services.config.get_derivation_patterns") as mock_get: + mock_get.return_value = { + "include": {"service", "manager"}, + "exclude": {"test", "mock"}, + } + + result = derivation.get_filter_kwargs(MagicMock()) + + assert "include_patterns" in result + assert "exclude_patterns" in result + assert "service" in result["include_patterns"] + assert "test" in result["exclude_patterns"] + + def test_get_filter_kwargs_handles_missing_patterns(self): + """Should return empty sets when no patterns configured.""" + derivation = ConcretePatternDerivation() + + with patch("deriva.services.config.get_derivation_patterns") as mock_get: + mock_get.side_effect = ValueError("Not found") + + result = derivation.get_filter_kwargs(MagicMock()) + + assert result["include_patterns"] == set() + assert result["exclude_patterns"] == set() + + +class TestPatternMatchDefault: + """Tests for PATTERN_MATCH_DEFAULT behavior.""" + + def test_custom_pattern_match_default(self): + """Should allow customizing PATTERN_MATCH_DEFAULT.""" + + class InclusivePatternDerivation(PatternBasedDerivation): + ELEMENT_TYPE = "Inclusive" + PATTERN_MATCH_DEFAULT = True # Include by default + OUTBOUND_RULES: list[RelationshipRule] = [] + INBOUND_RULES: list[RelationshipRule] = [] + + def filter_candidates(self, candidates, enrichments, max_candidates, **kwargs): + return candidates + + derivation = InclusivePatternDerivation() + + result = 
derivation.matches_patterns( + name="RandomName", + include_patterns=set(), + exclude_patterns=set(), + ) + + assert result is True # Custom default + + +class TestProcessBatch: + """Tests for _process_batch method.""" + + def test_process_batch_handles_llm_error(self): + """Should add error when LLM call fails.""" + derivation = ConcreteDerivation() + + result = GenerationResult(success=True) + batch = [ + Candidate( + node_id="1", + name="Test", + labels=["Node"], + properties={}, + pagerank=0.5, + louvain_community="1", + ) + ] + + mock_llm = MagicMock(side_effect=Exception("LLM error")) + + derivation._process_batch( + batch_num=1, + batch=batch, + instruction="Test", + example="{}", + llm_query_fn=mock_llm, + llm_kwargs={}, + archimate_manager=MagicMock(), + graph_manager=MagicMock(), + existing_elements=[], + temperature=None, + max_tokens=None, + defer_relationships=False, + result=result, + ) + + assert len(result.errors) > 0 + assert "LLM error" in result.errors[0] + + def test_process_batch_handles_parse_error(self): + """Should add error when response parsing fails.""" + derivation = ConcreteDerivation() + + result = GenerationResult(success=True) + batch = [ + Candidate( + node_id="1", + name="Test", + labels=["Node"], + properties={}, + pagerank=0.5, + louvain_community="1", + ) + ] + + # Mock LLM to return invalid response + mock_response = MagicMock() + mock_response.output = "not valid json" + mock_llm = MagicMock(return_value=mock_response) + + with patch("deriva.modules.derivation.element_base.extract_response_content") as mock_extract: + mock_extract.return_value = ("invalid json", None) + + derivation._process_batch( + batch_num=1, + batch=batch, + instruction="Test", + example="{}", + llm_query_fn=mock_llm, + llm_kwargs={}, + archimate_manager=MagicMock(), + graph_manager=MagicMock(), + existing_elements=[], + temperature=None, + max_tokens=None, + defer_relationships=False, + result=result, + ) + + # Should have parse errors + assert len(result.errors) >= 0 # May or may not have errors depending on parse + + +class TestDeriveRelationships: + """Tests for _derive_relationships method.""" + + def test_derive_relationships_creates_relationships(self): + """Should create relationships from derive_batch_relationships result.""" + derivation = ConcreteDerivation() + + result = GenerationResult(success=True) + batch_elements = [{"identifier": "elem-1", "name": "Element1"}] + existing_elements = [{"identifier": "elem-0", "name": "Element0"}] + + mock_archimate = MagicMock() + + with patch("deriva.modules.derivation.element_base.derive_batch_relationships") as mock_derive: + mock_derive.return_value = [ + { + "source": "elem-1", + "target": "elem-0", + "relationship_type": "Association", + "confidence": 0.8, + } + ] + + derivation._derive_relationships( + batch_elements=batch_elements, + existing_elements=existing_elements, + llm_query_fn=MagicMock(), + temperature=None, + max_tokens=None, + graph_manager=MagicMock(), + archimate_manager=mock_archimate, + result=result, + ) + + assert result.relationships_created == 1 + assert mock_archimate.add_relationship.called + + def test_derive_relationships_handles_creation_error(self): + """Should add error when relationship creation fails.""" + derivation = ConcreteDerivation() + + result = GenerationResult(success=True) + batch_elements = [{"identifier": "elem-1", "name": "Element1"}] + existing_elements = [{"identifier": "elem-0", "name": "Element0"}] + + mock_archimate = MagicMock() + mock_archimate.add_relationship.side_effect = 
Exception("Creation failed") + + with patch("deriva.modules.derivation.element_base.derive_batch_relationships") as mock_derive: + mock_derive.return_value = [ + { + "source": "elem-1", + "target": "elem-0", + "relationship_type": "Association", + } + ] + + derivation._derive_relationships( + batch_elements=batch_elements, + existing_elements=existing_elements, + llm_query_fn=MagicMock(), + temperature=None, + max_tokens=None, + graph_manager=MagicMock(), + archimate_manager=mock_archimate, + result=result, + ) + + assert len(result.errors) > 0 + assert "Failed to create" in result.errors[0] diff --git a/tests/test_modules/derivation/test_elements.py b/tests/test_modules/derivation/test_elements.py index c1034e3..a9f2b0c 100644 --- a/tests/test_modules/derivation/test_elements.py +++ b/tests/test_modules/derivation/test_elements.py @@ -6,13 +6,13 @@ from __future__ import annotations -import importlib from unittest.mock import MagicMock import pytest import deriva.services.config as config_module from deriva.modules.derivation.base import clear_enrichment_cache +from deriva.services.derivation import DERIVATION_REGISTRY @pytest.fixture(autouse=True) @@ -23,70 +23,57 @@ def reset_enrichment_cache(): clear_enrichment_cache() -# All derivation element module names -DERIVATION_MODULES = [ - "application_component", - "application_interface", - "application_service", - "business_actor", - "business_event", - "business_function", - "business_object", - "business_process", - "data_object", - "device", - "node", - "system_software", - "technology_service", -] - - -def get_module(module_name: str): - """Dynamically import a derivation module.""" - return importlib.import_module(f"deriva.modules.derivation.{module_name}") - - -class TestModuleExports: - """Tests that all derivation modules export required interface.""" - - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_exports_element_type(self, module_name): - """All derivation modules should export ELEMENT_TYPE constant.""" - module = get_module(module_name) - assert hasattr(module, "ELEMENT_TYPE") - assert isinstance(module.ELEMENT_TYPE, str) - assert len(module.ELEMENT_TYPE) > 0 - - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_exports_generate_function(self, module_name): - """All derivation modules should export generate() function.""" - module = get_module(module_name) - assert hasattr(module, "generate") - assert callable(module.generate) - - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_exports_filter_candidates(self, module_name): - """All derivation modules should export filter_candidates() function.""" - module = get_module(module_name) - assert hasattr(module, "filter_candidates") - assert callable(module.filter_candidates) - - -class TestGenerateFunction: - """Tests for generate() function across all modules.""" - - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_returns_generation_result(self, module_name): - """All derivation modules should return GenerationResult from generate().""" +# All derivation element types +DERIVATION_ELEMENT_TYPES = list(DERIVATION_REGISTRY.keys()) + + +def get_derivation(element_type: str): + """Get a derivation class instance for an element type.""" + cls = DERIVATION_REGISTRY.get(element_type) + return cls() if cls else None + + +class TestDerivationClasses: + """Tests that all derivation classes have required interface.""" + + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def 
test_has_element_type(self, element_type): + """All derivation classes should have ELEMENT_TYPE attribute.""" + derivation = get_derivation(element_type) + assert hasattr(derivation, "ELEMENT_TYPE") + assert isinstance(derivation.ELEMENT_TYPE, str) + assert derivation.ELEMENT_TYPE == element_type + + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_has_generate_method(self, element_type): + """All derivation classes should have generate() method.""" + derivation = get_derivation(element_type) + assert hasattr(derivation, "generate") + assert callable(derivation.generate) + + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_has_filter_candidates_method(self, element_type): + """All derivation classes should have filter_candidates() method.""" + derivation = get_derivation(element_type) + assert hasattr(derivation, "filter_candidates") + assert callable(derivation.filter_candidates) + + +class TestGenerateMethod: + """Tests for generate() method across all derivation classes.""" + + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_returns_generation_result(self, element_type): + """All derivation classes should return GenerationResult from generate().""" from deriva.modules.derivation.base import GenerationResult - module = get_module(module_name) + derivation = get_derivation(element_type) # Mock the graph manager to return empty results (simplest path) mock_manager = MagicMock() mock_manager.query.return_value = [] - result = module.generate( + result = derivation.generate( graph_manager=mock_manager, archimate_manager=MagicMock(), engine=MagicMock(), @@ -103,8 +90,8 @@ def test_returns_generation_result(self, module_name): assert result.success is True # Empty results should succeed def test_handles_query_exception_application_component(self): - """ApplicationComponent handles query exceptions (only module with try/except around query).""" - module = get_module("application_component") + """ApplicationComponent handles query exceptions.""" + derivation = get_derivation("ApplicationComponent") # First call (enrichments) succeeds, second call (candidates) fails failing_manager = MagicMock() @@ -113,7 +100,7 @@ def test_handles_query_exception_application_component(self): Exception("DB connection error"), # Candidate query fails ] - result = module.generate( + result = derivation.generate( graph_manager=failing_manager, archimate_manager=MagicMock(), engine=MagicMock(), @@ -130,11 +117,10 @@ def test_handles_query_exception_application_component(self): assert len(result.errors) > 0 assert any("error" in e.lower() or "failed" in e.lower() for e in result.errors) - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_creates_elements_with_valid_llm_response(self, module_name, monkeypatch): - """All derivation modules should create elements when LLM returns valid response.""" - module = get_module(module_name) - element_type = module.ELEMENT_TYPE + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_creates_elements_with_valid_llm_response(self, element_type, monkeypatch): + """All derivation classes should create elements when LLM returns valid response.""" + derivation = get_derivation(element_type) # Mock config.get_derivation_patterns to return patterns matching "TestElement" # This ensures PatternBasedDerivation modules don't filter out all candidates @@ -167,9 +153,7 @@ def mock_patterns(_engine, _element_type): }, ] # Order: stats (cache lookup) -> 
enrichments -> stats (cache store) -> candidates - mock_manager.query.side_effect = [ - stats_results, enrichment_results, stats_results, candidate_results - ] + mock_manager.query.side_effect = [stats_results, enrichment_results, stats_results, candidate_results] # Setup LLM response with valid element mock_llm = MagicMock() @@ -187,7 +171,7 @@ def mock_patterns(_engine, _element_type): mock_archimate = MagicMock() - result = module.generate( + result = derivation.generate( graph_manager=mock_manager, archimate_manager=mock_archimate, engine=MagicMock(), @@ -205,10 +189,10 @@ def mock_patterns(_engine, _element_type): # Should call add_element on archimate manager assert mock_archimate.add_element.called - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_handles_llm_exception(self, module_name, monkeypatch): - """All derivation modules should handle LLM exceptions gracefully.""" - module = get_module(module_name) + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_handles_llm_exception(self, element_type, monkeypatch): + """All derivation classes should handle LLM exceptions gracefully.""" + derivation = get_derivation(element_type) # Mock config.get_derivation_patterns to return patterns matching "Test" # This ensures PatternBasedDerivation modules don't filter out all candidates @@ -234,15 +218,13 @@ def mock_patterns(_engine, _element_type): ] candidate_results = [{"id": "n1", "name": "Test", "labels": [], "properties": {}}] # Order: stats (cache lookup) -> enrichments -> stats (cache store) -> candidates - mock_manager.query.side_effect = [ - stats_results, enrichment_results, stats_results, candidate_results - ] + mock_manager.query.side_effect = [stats_results, enrichment_results, stats_results, candidate_results] # LLM throws exception failing_llm = MagicMock() failing_llm.side_effect = Exception("LLM API error") - result = module.generate( + result = derivation.generate( graph_manager=mock_manager, archimate_manager=MagicMock(), engine=MagicMock(), @@ -264,10 +246,10 @@ def mock_patterns(_engine, _element_type): assert result.success is True assert result.elements_created == 0 - @pytest.mark.parametrize("module_name", DERIVATION_MODULES) - def test_handles_invalid_llm_json(self, module_name, monkeypatch): - """All derivation modules should handle invalid JSON from LLM.""" - module = get_module(module_name) + @pytest.mark.parametrize("element_type", DERIVATION_ELEMENT_TYPES) + def test_handles_invalid_llm_json(self, element_type, monkeypatch): + """All derivation classes should handle invalid JSON from LLM.""" + derivation = get_derivation(element_type) # Mock config.get_derivation_patterns to return patterns matching "Test" # This ensures PatternBasedDerivation modules don't filter out all candidates @@ -293,9 +275,7 @@ def mock_patterns(_engine, _element_type): ] candidate_results = [{"id": "n1", "name": "Test", "labels": [], "properties": {}}] # Order: stats (cache lookup) -> enrichments -> stats (cache store) -> candidates - mock_manager.query.side_effect = [ - stats_results, enrichment_results, stats_results, candidate_results - ] + mock_manager.query.side_effect = [stats_results, enrichment_results, stats_results, candidate_results] # LLM returns invalid JSON invalid_llm = MagicMock() @@ -303,7 +283,7 @@ def mock_patterns(_engine, _element_type): invalid_response.content = "this is not valid json" invalid_llm.return_value = invalid_response - result = module.generate( + result = derivation.generate( graph_manager=mock_manager, 
archimate_manager=MagicMock(), engine=MagicMock(), diff --git a/tests/test_modules/extraction/test_method.py b/tests/test_modules/extraction/test_method.py new file mode 100644 index 0000000..ea8ee72 --- /dev/null +++ b/tests/test_modules/extraction/test_method.py @@ -0,0 +1,530 @@ +"""Tests for modules.extraction.method module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from deriva.modules.extraction.method import ( + METHOD_SCHEMA, + build_extraction_prompt, + build_method_node, + extract_methods, + extract_methods_batch, + extract_methods_from_source, + parse_llm_response, +) + + +class TestMethodSchema: + """Tests for METHOD_SCHEMA constant.""" + + def test_schema_has_name(self): + """Should have a name field.""" + assert "name" in METHOD_SCHEMA + assert METHOD_SCHEMA["name"] == "methods_extraction" + + def test_schema_is_strict(self): + """Should be strict mode.""" + assert METHOD_SCHEMA["strict"] is True + + def test_schema_has_methods_array(self): + """Should define methods array in schema.""" + assert "properties" in METHOD_SCHEMA["schema"] # type: ignore[operator] + assert "methods" in METHOD_SCHEMA["schema"]["properties"] # type: ignore[operator] + + def test_methods_item_has_required_fields(self): + """Method items should have required fields.""" + method_schema = METHOD_SCHEMA["schema"]["properties"]["methods"]["items"] # type: ignore[index] + required = method_schema["required"] # type: ignore[index] + + assert "methodName" in required # type: ignore[operator] + assert "returnType" in required # type: ignore[operator] + assert "visibility" in required # type: ignore[operator] + assert "parameters" in required # type: ignore[operator] + + +class TestBuildExtractionPrompt: + """Tests for build_extraction_prompt function.""" + + def test_includes_type_info(self): + """Should include type information in prompt.""" + prompt = build_extraction_prompt( + code_snippet="class MyClass:\n pass", + type_name="MyClass", + type_category="class", + file_path="src/my_file.py", + instruction="Extract all methods", + example='{"methods": []}', + ) + + assert "MyClass" in prompt + assert "class" in prompt + assert "src/my_file.py" in prompt + + def test_includes_line_numbers(self): + """Should add line numbers to code snippet.""" + prompt = build_extraction_prompt( + code_snippet="line1\nline2\nline3", + type_name="Test", + type_category="class", + file_path="test.py", + instruction="", + example="", + ) + + assert "1 |" in prompt + assert "2 |" in prompt + assert "3 |" in prompt + + def test_includes_instruction_and_example(self): + """Should include instruction and example.""" + prompt = build_extraction_prompt( + code_snippet="code", + type_name="Test", + type_category="class", + file_path="test.py", + instruction="Custom instruction here", + example='{"methods": [{"name": "test"}]}', + ) + + assert "Custom instruction here" in prompt + assert '{"methods":' in prompt + + +class TestBuildMethodNode: + """Tests for build_method_node function.""" + + def test_builds_valid_node(self): + """Should build a valid method node.""" + result = build_method_node( + method_data={ + "methodName": "get_user", + "returnType": "User", + "visibility": "public", + "parameters": "self, user_id: int", + "description": "Gets a user", + "isStatic": False, + "isAsync": False, + "startLine": 10, + "endLine": 15, + "confidence": 0.9, + }, + type_name="UserService", + file_path="src/user_service.py", + repo_name="my_repo", + type_start_line=5, + ) + + assert result["success"] is 
True + assert result["data"]["label"] == "Method" + assert result["data"]["properties"]["methodName"] == "get_user" + assert result["data"]["properties"]["returnType"] == "User" + assert result["data"]["properties"]["visibility"] == "public" + + def test_generates_unique_node_id(self): + """Should generate unique node ID.""" + result = build_method_node( + method_data={ + "methodName": "my_method", + "returnType": "str", + "visibility": "private", + }, + type_name="MyClass", + file_path="src/module.py", + repo_name="test_repo", + ) + + node_id = result["data"]["node_id"] + assert "method_test_repo" in node_id + assert "MyClass" in node_id + assert "my_method" in node_id + + def test_returns_error_for_missing_fields(self): + """Should return error when required fields missing.""" + result = build_method_node( + method_data={"methodName": "test"}, # Missing returnType, visibility + type_name="Test", + file_path="test.py", + repo_name="repo", + ) + + assert result["success"] is False + assert len(result["errors"]) > 0 + assert "returnType" in result["errors"][0] or "visibility" in result["errors"][0] + + def test_normalizes_invalid_visibility(self): + """Should normalize invalid visibility to public.""" + result = build_method_node( + method_data={ + "methodName": "test", + "returnType": "void", + "visibility": "invalid", + }, + type_name="Test", + file_path="test.py", + repo_name="repo", + ) + + assert result["success"] is True + assert result["data"]["properties"]["visibility"] == "public" + + def test_handles_default_values(self): + """Should use default values for optional fields.""" + result = build_method_node( + method_data={ + "methodName": "test", + "returnType": "void", + "visibility": "public", + }, + type_name="Test", + file_path="test.py", + repo_name="repo", + ) + + props = result["data"]["properties"] + assert props["parameters"] == "" + assert props["description"] == "" + assert props["isStatic"] is False + assert props["isAsync"] is False + assert props["confidence"] == 0.8 + + +class TestParseLlmResponse: + """Tests for parse_llm_response function.""" + + def test_parses_valid_json(self): + """Should parse valid JSON response.""" + result = parse_llm_response('{"methods": [{"methodName": "test", "returnType": "void"}]}') + + assert result["success"] is True + assert len(result["data"]) == 1 + assert result["data"][0]["methodName"] == "test" + + def test_returns_error_for_invalid_json(self): + """Should return error for invalid JSON.""" + result = parse_llm_response("not valid json") + + assert result["success"] is False + assert len(result["errors"]) > 0 + + def test_returns_empty_for_empty_methods(self): + """Should return empty list for empty methods array.""" + result = parse_llm_response('{"methods": []}') + + assert result["success"] is True + assert result["data"] == [] + + +class TestExtractMethods: + """Tests for extract_methods function.""" + + def test_skips_empty_code_snippet(self): + """Should skip nodes with empty code snippet.""" + result = extract_methods( + type_node={"properties": {"codeSnippet": ""}}, + repo_name="test", + llm_query_fn=MagicMock(), + config={}, + ) + + assert result["success"] is True + assert result["stats"]["skipped"] == "no_code_snippet" + + def test_calls_llm_with_prompt(self): + """Should call LLM with built prompt.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"methods": []}' + mock_response.usage = {"prompt_tokens": 100, "completion_tokens": 50} + mock_response.response_type = 
"ResponseType.NEW" + mock_llm.return_value = mock_response + + extract_methods( + type_node={ + "node_id": "type_1", + "properties": { + "typeName": "MyClass", + "category": "class", + "filePath": "test.py", + "codeSnippet": "class MyClass:\n pass", + }, + }, + repo_name="test", + llm_query_fn=mock_llm, + config={"instruction": "Extract methods", "example": "{}"}, + ) + + mock_llm.assert_called_once() + prompt = mock_llm.call_args[0][0] + assert "MyClass" in prompt + + def test_handles_llm_error(self): + """Should handle LLM error response.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.error = "API rate limit exceeded" + mock_llm.return_value = mock_response + + result = extract_methods( + type_node={ + "node_id": "type_1", + "properties": { + "codeSnippet": "class Test: pass", + }, + }, + repo_name="test", + llm_query_fn=mock_llm, + config={}, + ) + + assert result["success"] is False + assert "LLM error" in result["errors"][0] + + def test_handles_exception(self): + """Should handle exception during extraction.""" + mock_llm = MagicMock(side_effect=Exception("Connection error")) + + result = extract_methods( + type_node={ + "node_id": "type_1", + "properties": {"codeSnippet": "class Test: pass"}, + }, + repo_name="test", + llm_query_fn=mock_llm, + config={}, + ) + + assert result["success"] is False + assert "Fatal error" in result["errors"][0] + + def test_creates_contains_edges(self): + """Should create CONTAINS edges from type to methods.""" + mock_llm = MagicMock() + mock_response = MagicMock(spec=["content", "usage", "response_type"]) + mock_response.content = """{"methods": [ + {"methodName": "test", "returnType": "void", "visibility": "public", + "parameters": "", "description": "", "isStatic": false, "isAsync": false, + "startLine": 1, "endLine": 2, "confidence": 0.9} + ]}""" + mock_response.usage = None + mock_response.response_type = "ResponseType.NEW" + mock_llm.return_value = mock_response + + result = extract_methods( + type_node={ + "node_id": "type_test", + "properties": { + "typeName": "Test", + "codeSnippet": "def test(): pass", + }, + }, + repo_name="repo", + llm_query_fn=mock_llm, + config={}, + ) + + assert len(result["data"]["edges"]) == 1 + edge = result["data"]["edges"][0] + assert edge["relationship_type"] == "CONTAINS" + assert edge["from_node_id"] == "type_test" + + +class TestExtractMethodsBatch: + """Tests for extract_methods_batch function.""" + + def test_processes_multiple_types(self): + """Should process multiple type nodes.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"methods": []}' + mock_response.usage = None + mock_llm.return_value = mock_response + + result = extract_methods_batch( + type_nodes=[ + {"node_id": "t1", "properties": {"typeName": "A", "codeSnippet": "code"}}, + {"node_id": "t2", "properties": {"typeName": "B", "codeSnippet": "code"}}, + ], + repo_name="test", + llm_query_fn=mock_llm, + config={}, + ) + + assert result["stats"]["types_processed"] == 2 + assert mock_llm.call_count == 2 + + def test_calls_progress_callback(self): + """Should call progress callback during processing.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"methods": []}' + mock_response.usage = None + mock_llm.return_value = mock_response + + progress_calls = [] + + def progress_cb(current, total, name): + progress_calls.append((current, total, name)) + + extract_methods_batch( + type_nodes=[ + {"node_id": "t1", "properties": {"typeName": "TypeA", 
"codeSnippet": "c"}}, + ], + repo_name="test", + llm_query_fn=mock_llm, + config={}, + progress_callback=progress_cb, + ) + + assert len(progress_calls) == 1 + assert progress_calls[0] == (1, 1, "TypeA") + + def test_aggregates_results(self): + """Should aggregate results from all types.""" + mock_llm = MagicMock() + mock_response = MagicMock(spec=["content", "usage", "response_type"]) + mock_response.content = """{"methods": [ + {"methodName": "m1", "returnType": "void", "visibility": "public", + "parameters": "", "description": "", "isStatic": false, "isAsync": false, + "startLine": 1, "endLine": 2, "confidence": 0.9} + ]}""" + mock_response.usage = None + mock_response.response_type = "ResponseType.NEW" + mock_llm.return_value = mock_response + + result = extract_methods_batch( + type_nodes=[ + {"node_id": "t1", "properties": {"typeName": "A", "codeSnippet": "code"}}, + {"node_id": "t2", "properties": {"typeName": "B", "codeSnippet": "code"}}, + ], + repo_name="test", + llm_query_fn=mock_llm, + config={}, + ) + + # Each type produces 1 method + assert result["stats"]["total_nodes"] == 2 + assert result["stats"]["types_with_methods"] == 2 + + def test_includes_type_results(self): + """Should include per-type results for logging.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"methods": []}' + mock_response.usage = None + mock_llm.return_value = mock_response + + result = extract_methods_batch( + type_nodes=[ + {"node_id": "t1", "properties": {"typeName": "A", "codeSnippet": "c"}}, + ], + repo_name="test", + llm_query_fn=mock_llm, + config={}, + ) + + assert "type_results" in result + assert len(result["type_results"]) == 1 + assert result["type_results"][0]["type_name"] == "A" + + +class TestExtractMethodsFromSource: + """Tests for extract_methods_from_source function.""" + + def test_extracts_python_methods(self): + """Should extract methods from Python source.""" + source = ''' +class UserService: + def get_user(self, user_id: int) -> User: + """Get a user by ID.""" + pass + + def _private_method(self): + pass +''' + result = extract_methods_from_source( + file_path="src/user_service.py", + file_content=source, + repo_name="test_repo", + ) + + assert result["success"] is True + assert result["stats"]["total_nodes"] >= 2 + method_names = [n["properties"]["methodName"] for n in result["data"]["nodes"]] + assert "get_user" in method_names + assert "_private_method" in method_names + + def test_handles_top_level_functions(self): + """Should extract top-level functions.""" + source = """ +def standalone_function(): + pass +""" + result = extract_methods_from_source( + file_path="utils.py", + file_content=source, + repo_name="test", + ) + + assert result["success"] is True + assert len(result["data"]["nodes"]) >= 1 + + def test_determines_visibility(self): + """Should determine visibility from naming conventions.""" + source = """ +class Test: + def public_method(self): pass + def _private_method(self): pass + def __protected_method(self): pass +""" + result = extract_methods_from_source( + file_path="test.py", + file_content=source, + repo_name="test", + ) + + nodes = result["data"]["nodes"] + visibilities = {n["properties"]["methodName"]: n["properties"]["visibility"] for n in nodes} + + assert visibilities.get("public_method") == "public" + assert visibilities.get("_private_method") == "private" + assert visibilities.get("__protected_method") == "protected" + + def test_handles_syntax_error(self): + """Should handle syntax errors gracefully.""" + 
result = extract_methods_from_source( + file_path="bad.py", + file_content="def broken(", + repo_name="test", + ) + + # May or may not succeed depending on tree-sitter tolerance + # Just ensure it doesn't crash + assert "success" in result + + def test_creates_contains_edges(self): + """Should create CONTAINS edges.""" + source = """ +class MyClass: + def my_method(self): + pass +""" + result = extract_methods_from_source( + file_path="test.py", + file_content=source, + repo_name="repo", + ) + + assert len(result["data"]["edges"]) >= 1 + edge = result["data"]["edges"][0] + assert edge["relationship_type"] == "CONTAINS" + + def test_extraction_method_is_treesitter(self): + """Should set extraction_method to treesitter.""" + result = extract_methods_from_source( + file_path="test.py", + file_content="def f(): pass", + repo_name="test", + ) + + assert result["stats"]["extraction_method"] == "treesitter" diff --git a/tests/test_services/test_config.py b/tests/test_services/test_config.py index 5c275cc..95ad1e8 100644 --- a/tests/test_services/test_config.py +++ b/tests/test_services/test_config.py @@ -868,3 +868,783 @@ def test_passes_phase_for_derivation(self): list_steps(engine, "derivation", phase="prep") mock_fn.assert_called_once() assert mock_fn.call_args[1].get("phase") == "prep" + + +class TestGetSettings: + """Tests for get_settings function.""" + + def test_returns_cached_settings(self): + """Should return cached DerivaSettings instance.""" + from deriva.services.config import get_settings + + # Clear cache for testing + get_settings.cache_clear() + + settings = get_settings() + assert settings is not None + # Should return same instance on subsequent calls + settings2 = get_settings() + assert settings is settings2 + + +class TestUpdateExtractionConfigTemperatureAndTokens: + """Tests for update_extraction_config with temperature and max_tokens.""" + + def test_updates_temperature_field(self): + """Should update temperature field.""" + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_extraction_config(engine, "BusinessConcept", temperature=0.7) + + assert result is True + call_args = engine.execute.call_args[0][0] + assert "temperature = ?" in call_args + + def test_updates_max_tokens_field(self): + """Should update max_tokens field.""" + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_extraction_config(engine, "BusinessConcept", max_tokens=4096) + + assert result is True + call_args = engine.execute.call_args[0][0] + assert "max_tokens = ?" in call_args + + +class TestUpdateDerivationConfigTemperatureAndTokens: + """Tests for update_derivation_config with temperature and max_tokens.""" + + def test_updates_input_model_query(self): + """Should update input_model_query field.""" + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_derivation_config( + engine, + "ApplicationComponent", + input_model_query="MATCH (n:Model) RETURN n", + ) + + assert result is True + call_args = engine.execute.call_args[0][0] + assert "input_model_query = ?" in call_args + + def test_updates_temperature_field(self): + """Should update temperature field.""" + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_derivation_config( + engine, + "ApplicationComponent", + temperature=0.5, + ) + + assert result is True + call_args = engine.execute.call_args[0][0] + assert "temperature = ?" 
in call_args + + def test_updates_max_tokens_field(self): + """Should update max_tokens field.""" + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_derivation_config( + engine, + "ApplicationComponent", + max_tokens=8192, + ) + + assert result is True + call_args = engine.execute.call_args[0][0] + assert "max_tokens = ?" in call_args + + +class TestCreateDerivationConfigVersion: + """Tests for create_derivation_config_version function.""" + + def test_creates_new_version(self): + """Should create new version with incremented version number.""" + from deriva.services.config import create_derivation_config_version + + engine = MagicMock() + # Current config: (id, version, phase, sequence, enabled, llm, graph_query, model_query, instruction, example, params, temperature, max_tokens) + engine.execute.return_value.fetchone.side_effect = [ + (1, 1, "generate", 1, True, True, "MATCH (n)", None, "instruction", "example", None, 0.7, 4096), # Current config + (2,), # Next ID + ] + + result = create_derivation_config_version( + engine, + "ApplicationComponent", + instruction="New instruction", + ) + + assert result["success"] is True + assert result["step_name"] == "ApplicationComponent" + assert result["old_version"] == 1 + assert result["new_version"] == 2 + + def test_returns_error_when_not_found(self): + """Should return error when config not found.""" + from deriva.services.config import create_derivation_config_version + + engine = MagicMock() + engine.execute.return_value.fetchone.return_value = None + + result = create_derivation_config_version(engine, "UnknownStep") + + assert result["success"] is False + assert "Config not found" in result["error"] + + def test_preserves_existing_values(self): + """Should preserve existing values when not specified.""" + from deriva.services.config import create_derivation_config_version + + engine = MagicMock() + engine.execute.return_value.fetchone.side_effect = [ + (1, 1, "generate", 5, True, True, "OLD_QUERY", "OLD_MODEL", "old_instruction", "old_example", '{"key": "value"}', 0.5, 2000), + (2,), + ] + + result = create_derivation_config_version( + engine, + "ApplicationComponent", + enabled=False, # Only change enabled + ) + + assert result["success"] is True + # Verify INSERT preserves old values + insert_call = [c for c in engine.execute.call_args_list if "INSERT INTO derivation_config" in str(c)] + assert len(insert_call) > 0 + + +class TestCreateExtractionConfigVersion: + """Tests for create_extraction_config_version function.""" + + def test_creates_new_version(self): + """Should create new version with incremented version number.""" + from deriva.services.config import create_extraction_config_version + + engine = MagicMock() + # Current config: (id, version, sequence, enabled, input_sources, instruction, example, temperature, max_tokens) + engine.execute.return_value.fetchone.side_effect = [ + (1, 1, 1, True, '{"files": []}', "instruction", "example", 0.7, 4096), + (2,), # Next ID + ] + + result = create_extraction_config_version( + engine, + "BusinessConcept", + instruction="New instruction", + ) + + assert result["success"] is True + assert result["node_type"] == "BusinessConcept" + assert result["old_version"] == 1 + assert result["new_version"] == 2 + + def test_returns_error_when_not_found(self): + """Should return error when config not found.""" + from deriva.services.config import create_extraction_config_version + + engine = MagicMock() + engine.execute.return_value.fetchone.return_value = None + + 
result = create_extraction_config_version(engine, "UnknownType") + + assert result["success"] is False + assert "Config not found" in result["error"] + + +class TestGetActiveConfigVersions: + """Tests for get_active_config_versions function.""" + + def test_returns_all_active_versions(self): + """Should return all active config versions.""" + from deriva.services.config import get_active_config_versions + + engine = MagicMock() + # Mock two separate fetchall calls + engine.execute.return_value.fetchall.side_effect = [ + [("BusinessConcept", 3), ("TypeDefinition", 1)], # Extraction + [("ApplicationComponent", 2), ("PageRank", 1)], # Derivation + ] + + result = get_active_config_versions(engine) + + assert result["extraction"]["BusinessConcept"] == 3 + assert result["extraction"]["TypeDefinition"] == 1 + assert result["derivation"]["ApplicationComponent"] == 2 + assert result["derivation"]["PageRank"] == 1 + + +class TestLogConsistencyRun: + """Tests for log_consistency_run function.""" + + def test_logs_run_and_returns_id(self): + """Should log run and return new ID.""" + from deriva.services.config import log_consistency_run + + engine = MagicMock() + engine.execute.return_value.fetchone.return_value = (1,) # Next ID + + result = log_consistency_run( + engine, + repo_name="test_repo", + num_runs=5, + results={ + "name_consistency": 0.95, + "identifier_consistency": 0.92, + "count_variance": 0.05, + "stable_count": 10, + "total_unique": 12, + }, + config_versions={"extraction": {"BusinessConcept": 1}}, + ) + + assert result == 1 + # Verify table creation + create_calls = [c for c in engine.execute.call_args_list if "CREATE TABLE" in str(c)] + assert len(create_calls) > 0 + + +class TestGetConsistencyHistory: + """Tests for get_consistency_history function.""" + + def test_returns_history_records(self): + """Should return history records.""" + from deriva.services.config import get_consistency_history + + engine = MagicMock() + # First call checks table existence, second gets data + engine.execute.return_value.fetchall.side_effect = [ + [("consistency_runs",)], # Table exists + [ + (1, "test_repo", 5, 0.95, 0.92, 0.05, 10, 12, '{"extraction": {}}', "2024-01-01 00:00:00"), + ], + ] + + result = get_consistency_history(engine, repo_name="test_repo") + + assert len(result) == 1 + assert result[0]["repo_name"] == "test_repo" + assert result[0]["name_consistency"] == 0.95 + + def test_returns_empty_when_table_not_exists(self): + """Should return empty list when table doesn't exist.""" + from deriva.services.config import get_consistency_history + + engine = MagicMock() + engine.execute.return_value.fetchall.return_value = [] + + result = get_consistency_history(engine) + + assert result == [] + + +class TestGetDerivationPatterns: + """Tests for get_derivation_patterns function.""" + + def test_returns_patterns(self): + """Should return patterns grouped by type.""" + from deriva.services.config import get_derivation_patterns + + engine = MagicMock() + engine.execute.return_value.fetchall.return_value = [ + ("include", '["get", "post", "put"]'), + ("exclude", '["_", "private"]'), + ] + + result = get_derivation_patterns(engine, "ApplicationService") + + assert "get" in result["include"] + assert "post" in result["include"] + assert "_" in result["exclude"] + + def test_raises_when_not_found(self): + """Should raise ValueError when no patterns found.""" + import pytest + + from deriva.services.config import get_derivation_patterns + + engine = MagicMock() + 
engine.execute.return_value.fetchall.return_value = [] + + with pytest.raises(ValueError, match="No patterns found"): + get_derivation_patterns(engine, "UnknownStep") + + +class TestGetIncludePatterns: + """Tests for get_include_patterns function.""" + + def test_returns_include_patterns_only(self): + """Should return only include patterns.""" + from unittest.mock import patch + + from deriva.services.config import get_include_patterns + + engine = MagicMock() + + with patch( + "deriva.services.config.get_derivation_patterns", + return_value={"include": {"get", "post"}, "exclude": set()}, + ): + result = get_include_patterns(engine, "ApplicationService") + + assert "get" in result + assert "post" in result + + +class TestGetExcludePatterns: + """Tests for get_exclude_patterns function.""" + + def test_returns_exclude_patterns_only(self): + """Should return only exclude patterns.""" + from unittest.mock import patch + + from deriva.services.config import get_exclude_patterns + + engine = MagicMock() + + with patch( + "deriva.services.config.get_derivation_patterns", + return_value={"include": set(), "exclude": {"_", "private"}}, + ): + result = get_exclude_patterns(engine, "ApplicationService") + + assert "_" in result + assert "private" in result + + +class TestUpdateDerivationPatterns: + """Tests for update_derivation_patterns function.""" + + def test_updates_existing_patterns(self): + """Should update existing patterns.""" + from deriva.services.config import update_derivation_patterns + + engine = MagicMock() + engine.execute.return_value.rowcount = 1 + + result = update_derivation_patterns( + engine, + "ApplicationService", + "include", + "http_methods", + ["get", "post", "put", "delete"], + ) + + assert result is True + + def test_inserts_when_not_exists(self): + """Should insert new patterns when not exists.""" + from deriva.services.config import update_derivation_patterns + + engine = MagicMock() + engine.execute.return_value.rowcount = 0 + engine.execute.return_value.fetchone.return_value = (1,) + + result = update_derivation_patterns( + engine, + "NewStep", + "include", + "new_category", + ["pattern1", "pattern2"], + ) + + assert result is True + + +class TestThresholdHelpers: + """Tests for threshold helper functions.""" + + def test_get_confidence_threshold_from_settings(self): + """Should get threshold from system_settings.""" + from unittest.mock import patch + + from deriva.services.config import get_confidence_threshold + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value="0.75", + ): + result = get_confidence_threshold(engine, "min_relationship") + + assert result == 0.75 + + def test_get_confidence_threshold_uses_default(self): + """Should use default when setting not found.""" + from unittest.mock import patch + + from deriva.services.config import get_confidence_threshold + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value=None, + ): + result = get_confidence_threshold(engine, "min_relationship") + + assert result == 0.6 # Default value + + def test_get_confidence_threshold_handles_invalid_value(self): + """Should handle invalid value and use default.""" + from unittest.mock import patch + + from deriva.services.config import get_confidence_threshold + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value="not_a_number", + ): + result = get_confidence_threshold(engine, "min_relationship") + + assert result == 0.6 # Default value + + def 
test_get_derivation_limit_from_settings(self): + """Should get limit from system_settings.""" + from unittest.mock import patch + + from deriva.services.config import get_derivation_limit + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value="50", + ): + result = get_derivation_limit(engine, "default_batch_size") + + assert result == 50 + + def test_get_derivation_limit_uses_default(self): + """Should use default when setting not found.""" + from unittest.mock import patch + + from deriva.services.config import get_derivation_limit + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value=None, + ): + result = get_derivation_limit(engine, "default_batch_size") + + assert result == 10 # Default value + + +class TestSpecificThresholdHelpers: + """Tests for specific threshold helper functions.""" + + def test_get_min_relationship_confidence(self): + """Should get min relationship confidence.""" + from unittest.mock import patch + + from deriva.services.config import get_min_relationship_confidence + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.65, + ): + result = get_min_relationship_confidence(engine) + + assert result == 0.65 + + def test_get_community_rel_confidence(self): + """Should get community relationship confidence.""" + from unittest.mock import patch + + from deriva.services.config import get_community_rel_confidence + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.95, + ): + result = get_community_rel_confidence(engine) + + assert result == 0.95 + + def test_get_name_match_confidence(self): + """Should get name match confidence.""" + from unittest.mock import patch + + from deriva.services.config import get_name_match_confidence + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.95, + ): + result = get_name_match_confidence(engine) + + assert result == 0.95 + + def test_get_file_match_confidence(self): + """Should get file match confidence.""" + from unittest.mock import patch + + from deriva.services.config import get_file_match_confidence + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.85, + ): + result = get_file_match_confidence(engine) + + assert result == 0.85 + + def test_get_fuzzy_match_threshold(self): + """Should get fuzzy match threshold.""" + from unittest.mock import patch + + from deriva.services.config import get_fuzzy_match_threshold + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.85, + ): + result = get_fuzzy_match_threshold(engine) + + assert result == 0.85 + + def test_get_semantic_confidence(self): + """Should get semantic confidence.""" + from unittest.mock import patch + + from deriva.services.config import get_semantic_confidence + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.95, + ): + result = get_semantic_confidence(engine) + + assert result == 0.95 + + def test_get_pagerank_min(self): + """Should get pagerank minimum threshold.""" + from unittest.mock import patch + + from deriva.services.config import get_pagerank_min + + engine = MagicMock() + + with patch( + "deriva.services.config.get_confidence_threshold", + return_value=0.001, + ): + result = get_pagerank_min(engine) + + assert result == 0.001 + + 
+class TestSpecificLimitHelpers: + """Tests for specific limit helper functions.""" + + def test_get_max_batch_size(self): + """Should get max batch size.""" + from unittest.mock import patch + + from deriva.services.config import get_max_batch_size + + engine = MagicMock() + + with patch( + "deriva.services.config.get_derivation_limit", + return_value=10, + ): + result = get_max_batch_size(engine) + + assert result == 10 + + def test_get_max_candidates(self): + """Should get max candidates.""" + from unittest.mock import patch + + from deriva.services.config import get_max_candidates + + engine = MagicMock() + + with patch( + "deriva.services.config.get_derivation_limit", + return_value=30, + ): + result = get_max_candidates(engine) + + assert result == 30 + + def test_get_max_relationships_per_derivation(self): + """Should get max relationships per derivation.""" + from unittest.mock import patch + + from deriva.services.config import get_max_relationships_per_derivation + + engine = MagicMock() + + with patch( + "deriva.services.config.get_derivation_limit", + return_value=500, + ): + result = get_max_relationships_per_derivation(engine) + + assert result == 500 + + +class TestAlgorithmSettingsHelpers: + """Tests for algorithm settings helper functions.""" + + def test_get_algorithm_setting_from_settings(self): + """Should get algorithm setting from system_settings.""" + from unittest.mock import patch + + from deriva.services.config import get_algorithm_setting + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value="0.9", + ): + result = get_algorithm_setting(engine, "algorithm_pagerank_damping") + + assert result == "0.9" + + def test_get_algorithm_setting_uses_default(self): + """Should use default when setting not found.""" + from unittest.mock import patch + + from deriva.services.config import get_algorithm_setting + + engine = MagicMock() + + with patch( + "deriva.services.config.get_setting", + return_value=None, + ): + result = get_algorithm_setting(engine, "algorithm_pagerank_damping") + + assert result == "0.85" # Default value + + def test_get_algorithm_setting_float(self): + """Should get algorithm setting as float.""" + from unittest.mock import patch + + from deriva.services.config import get_algorithm_setting_float + + engine = MagicMock() + + with patch( + "deriva.services.config.get_algorithm_setting", + return_value="0.85", + ): + result = get_algorithm_setting_float(engine, "algorithm_pagerank_damping") + + assert result == 0.85 + + def test_get_algorithm_setting_float_handles_invalid(self): + """Should handle invalid float value.""" + from unittest.mock import patch + + from deriva.services.config import get_algorithm_setting_float + + engine = MagicMock() + + with patch( + "deriva.services.config.get_algorithm_setting", + return_value="not_a_number", + ): + result = get_algorithm_setting_float(engine, "key", default=0.5) + + assert result == 0.5 + + def test_get_algorithm_setting_int(self): + """Should get algorithm setting as int.""" + from unittest.mock import patch + + from deriva.services.config import get_algorithm_setting_int + + engine = MagicMock() + + with patch( + "deriva.services.config.get_algorithm_setting", + return_value="100", + ): + result = get_algorithm_setting_int(engine, "algorithm_pagerank_max_iter") + + assert result == 100 + + def test_get_algorithm_setting_int_handles_invalid(self): + """Should handle invalid int value.""" + from unittest.mock import patch + + from deriva.services.config import 
get_algorithm_setting_int + + engine = MagicMock() + + with patch( + "deriva.services.config.get_algorithm_setting", + return_value="not_a_number", + ): + result = get_algorithm_setting_int(engine, "key", default=50) + + assert result == 50 + + def test_get_pagerank_config(self): + """Should get pagerank configuration.""" + from unittest.mock import patch + + from deriva.services.config import get_pagerank_config + + engine = MagicMock() + + with ( + patch( + "deriva.services.config.get_algorithm_setting_float", + side_effect=[0.85, 1e-6], + ), + patch( + "deriva.services.config.get_algorithm_setting_int", + return_value=100, + ), + ): + result = get_pagerank_config(engine) + + assert result["damping"] == 0.85 + assert result["max_iter"] == 100 + assert result["tol"] == 1e-6 + + def test_get_louvain_config(self): + """Should get louvain configuration.""" + from unittest.mock import patch + + from deriva.services.config import get_louvain_config + + engine = MagicMock() + + with patch( + "deriva.services.config.get_algorithm_setting_float", + return_value=1.0, + ): + result = get_louvain_config(engine) + + assert result["resolution"] == 1.0 diff --git a/tests/test_services/test_config_models.py b/tests/test_services/test_config_models.py new file mode 100644 index 0000000..cdddb45 --- /dev/null +++ b/tests/test_services/test_config_models.py @@ -0,0 +1,238 @@ +"""Tests for config_models module (pydantic-settings integration).""" + +from typing import Any + +import pytest +from pydantic import ValidationError + +from deriva.services.config_models import ( + BenchmarkModelConfigModel, + ConfidenceThresholds, + DerivaSettings, + DerivationConfigModel, + DerivationLimits, + ExtractionConfigModel, + FileTypeModel, + LLMSettings, + LouvainConfig, + Neo4jSettings, + PageRankConfig, +) + +# Type alias to help with BaseSettings._env_file parameter which isn't in the type signature +_Neo4jSettings: Any = Neo4jSettings +_LLMSettings: Any = LLMSettings +_DerivaSettings: Any = DerivaSettings + + +class TestNeo4jSettings: + """Tests for Neo4jSettings.""" + + def test_default_values(self): + """Should have sensible defaults.""" + settings = _Neo4jSettings(_env_file=None) + assert settings.uri == "bolt://localhost:7687" + assert settings.database == "neo4j" + assert settings.encrypted is False + assert settings.max_connection_pool_size == 50 + + def test_loads_from_env(self, monkeypatch): + """Should load values from environment.""" + monkeypatch.setenv("NEO4J_URI", "bolt://custom:7687") + monkeypatch.setenv("NEO4J_DATABASE", "test_db") + settings = _Neo4jSettings(_env_file=None) + assert settings.uri == "bolt://custom:7687" + assert settings.database == "test_db" + + +class TestLLMSettings: + """Tests for LLMSettings.""" + + def test_default_values(self): + """Should have sensible defaults.""" + settings = _LLMSettings(_env_file=None) + assert settings.temperature == 0.6 + assert settings.max_retries == 3 + assert settings.timeout == 60 + assert settings.nocache is False + + def test_temperature_validation(self, monkeypatch): + """Should validate temperature range.""" + monkeypatch.setenv("LLM_TEMPERATURE", "1.5") + settings = _LLMSettings(_env_file=None) + assert settings.temperature == 1.5 + + def test_loads_from_env(self, monkeypatch): + """Should load values from environment.""" + monkeypatch.setenv("LLM_DEFAULT_MODEL", "test-model") + monkeypatch.setenv("LLM_TEMPERATURE", "0.3") + settings = _LLMSettings(_env_file=None) + assert settings.default_model == "test-model" + assert settings.temperature == 
0.3 + + +class TestDerivaSettings: + """Tests for DerivaSettings master class.""" + + def test_default_values(self): + """Should have sensible defaults.""" + settings = _DerivaSettings(_env_file=None) + assert settings.repository_workspace_dir == "workspace/repositories" + + def test_nested_settings(self): + """Should provide access to nested settings.""" + settings = _DerivaSettings(_env_file=None) + assert settings.neo4j.uri == "bolt://localhost:7687" + assert settings.llm.temperature == 0.6 + assert settings.graph.namespace == "Graph" + + +class TestExtractionConfigModel: + """Tests for ExtractionConfigModel.""" + + def test_valid_config(self): + """Should create valid config.""" + config = ExtractionConfigModel(node_type="BusinessConcept") + assert config.node_type == "BusinessConcept" + assert config.enabled is True + assert config.extraction_method == "llm" + + def test_requires_node_type(self): + """Should require node_type.""" + with pytest.raises(ValidationError): + ExtractionConfigModel() + + def test_validates_temperature_range(self): + """Should validate temperature range.""" + with pytest.raises(ValidationError): + ExtractionConfigModel(node_type="Test", temperature=3.0) + + def test_validates_extraction_method(self): + """Should validate extraction_method.""" + config = ExtractionConfigModel(node_type="Test", extraction_method="ast") + assert config.extraction_method == "ast" + + with pytest.raises(ValidationError): + ExtractionConfigModel(node_type="Test", extraction_method="invalid") # type: ignore[arg-type] + + +class TestDerivationConfigModel: + """Tests for DerivationConfigModel.""" + + def test_valid_config(self): + """Should create valid config.""" + config = DerivationConfigModel(step_name="ApplicationService", phase="generate") + assert config.step_name == "ApplicationService" + assert config.phase == "generate" + assert config.enabled is True + + def test_element_type_backward_compat(self): + """Should provide element_type alias.""" + config = DerivationConfigModel(step_name="AppComp", phase="generate") + assert config.element_type == "AppComp" + + def test_validates_phase(self): + """Should validate phase values.""" + with pytest.raises(ValidationError): + DerivationConfigModel(step_name="Test", phase="invalid_phase") # type: ignore[arg-type] + + +class TestFileTypeModel: + """Tests for FileTypeModel.""" + + def test_valid_config(self): + """Should create valid config.""" + ft = FileTypeModel(extension=".py", file_type="code", subtype="python") + assert ft.extension == ".py" + assert ft.chunk_overlap == 0 + + def test_with_chunking(self): + """Should accept chunking config.""" + ft = FileTypeModel( + extension=".md", + file_type="doc", + subtype="markdown", + chunk_delimiter="\n\n", + chunk_max_tokens=1000, + chunk_overlap=50, + ) + assert ft.chunk_max_tokens == 1000 + assert ft.chunk_overlap == 50 + + +class TestConfidenceThresholds: + """Tests for ConfidenceThresholds.""" + + def test_default_values(self): + """Should have sensible defaults.""" + thresholds = ConfidenceThresholds() + assert thresholds.min_relationship == 0.6 + assert thresholds.community_rel == 0.95 + assert thresholds.name_match == 0.95 + + def test_validates_range(self): + """Should validate threshold range.""" + with pytest.raises(ValidationError): + ConfidenceThresholds(min_relationship=1.5) + + +class TestDerivationLimits: + """Tests for DerivationLimits.""" + + def test_default_values(self): + """Should have sensible defaults.""" + limits = DerivationLimits() + assert 
limits.max_relationships_per_derivation == 500 + assert limits.default_batch_size == 10 + + +class TestPageRankConfig: + """Tests for PageRankConfig.""" + + def test_default_values(self): + """Should have sensible defaults.""" + config = PageRankConfig() + assert config.damping == 0.85 + assert config.max_iter == 100 + assert config.tol == 1e-6 + + +class TestLouvainConfig: + """Tests for LouvainConfig.""" + + def test_default_values(self): + """Should have sensible defaults.""" + config = LouvainConfig() + assert config.resolution == 1.0 + + +class TestBenchmarkModelConfigModel: + """Tests for BenchmarkModelConfigModel.""" + + def test_valid_config(self): + """Should create valid config.""" + config = BenchmarkModelConfigModel( + name="test-model", + provider="openai", + model="gpt-4", + ) + assert config.name == "test-model" + assert config.provider == "openai" + + def test_validates_provider(self): + """Should validate provider.""" + with pytest.raises(ValidationError): + BenchmarkModelConfigModel( + name="test", + provider="invalid_provider", # type: ignore[arg-type] + model="model", + ) + + def test_normalizes_provider_case(self): + """Should normalize provider to lowercase.""" + config = BenchmarkModelConfigModel( + name="test", + provider="OpenAI", # type: ignore[arg-type] # Tests case normalization + model="gpt-4", + ) + assert config.provider == "openai" diff --git a/tests/test_services/test_derivation.py b/tests/test_services/test_derivation.py index 0dc688d..c664354 100644 --- a/tests/test_services/test_derivation.py +++ b/tests/test_services/test_derivation.py @@ -4,38 +4,70 @@ from unittest.mock import MagicMock, patch +import pytest + from deriva.modules.derivation.prep import EnrichmentResult from deriva.services import derivation - -class TestLoadElementModule: - """Tests for _load_element_module function.""" - - def test_loads_application_component_module(self): - """Should load ApplicationComponent module.""" - module = derivation._load_element_module("ApplicationComponent") - assert module is not None - assert hasattr(module, "generate") - - def test_loads_business_object_module(self): - """Should load BusinessObject module.""" - module = derivation._load_element_module("BusinessObject") - assert module is not None - - def test_caches_loaded_modules(self): - """Should cache loaded modules.""" - # Clear cache first - derivation._ELEMENT_MODULES.clear() - - module1 = derivation._load_element_module("Node") - module2 = derivation._load_element_module("Node") - - assert module1 is module2 +# ============================================================================= +# FIXTURES +# ============================================================================= + + +@pytest.fixture +def clean_instance_cache(): + """Fixture that clears and restores the instance cache for isolated testing.""" + original_cache = derivation._DERIVATION_INSTANCES.copy() + derivation._DERIVATION_INSTANCES.clear() + yield + derivation._DERIVATION_INSTANCES.clear() + derivation._DERIVATION_INSTANCES.update(original_cache) + + +# ============================================================================= +# ELEMENT MODULE LOADING TESTS +# ============================================================================= + +# All known element types that should have generation modules +ELEMENT_TYPES_WITH_MODULES = [ + "ApplicationComponent", + "ApplicationInterface", + "ApplicationService", + "BusinessActor", + "BusinessEvent", + "BusinessFunction", + "BusinessObject", + "BusinessProcess", + "DataObject", 
+ "Device", + "Node", + "SystemSoftware", + "TechnologyService", +] + + +class TestDerivationRegistry: + """Tests for derivation registry and _get_derivation function.""" + + @pytest.mark.parametrize("element_type", ELEMENT_TYPES_WITH_MODULES) + def test_gets_derivation_instance(self, element_type): + """Should get derivation instance with generate method.""" + instance = derivation._get_derivation(element_type) + assert instance is not None, f"Derivation for {element_type} should exist" + assert hasattr(instance, "generate"), f"Derivation for {element_type} should have generate method" + + def test_caches_derivation_instances(self, clean_instance_cache): + """Should cache derivation instances.""" + instance1 = derivation._get_derivation("Node") + instance2 = derivation._get_derivation("Node") + + assert instance1 is instance2 + assert "Node" in derivation._DERIVATION_INSTANCES def test_returns_none_for_unknown_type(self): """Should return None for unknown element type.""" - module = derivation._load_element_module("UnknownType") - assert module is None + instance = derivation._get_derivation("UnknownType") + assert instance is None class TestGetGraphEdges: @@ -191,9 +223,9 @@ def test_returns_generation_result_as_dict(self): archimate_manager = MagicMock() engine = MagicMock() - with patch.object(derivation, "_load_element_module") as mock_load: - mock_module = MagicMock() - mock_module.generate.return_value = GenerationResult( + with patch.object(derivation, "_get_derivation") as mock_get: + mock_derivation = MagicMock() + mock_derivation.generate.return_value = GenerationResult( success=True, elements_created=3, relationships_created=5, @@ -201,7 +233,7 @@ def test_returns_generation_result_as_dict(self): created_relationships=[{"id": "r1"}], errors=[], ) - mock_load.return_value = mock_module + mock_get.return_value = mock_derivation result = derivation.generate_element( graph_manager=graph_manager, @@ -223,7 +255,7 @@ def test_returns_generation_result_as_dict(self): assert result["relationships_created"] == 5 def test_returns_error_for_unknown_element_type(self): - """Should return error when module not found.""" + """Should return error when derivation not found.""" result = derivation.generate_element( graph_manager=MagicMock(), archimate_manager=MagicMock(), @@ -238,14 +270,14 @@ def test_returns_error_for_unknown_element_type(self): ) assert result["success"] is False - assert "No generation module" in result["errors"][0] + assert "No derivation class" in result["errors"][0] def test_handles_generation_exception(self): """Should handle exception during generation.""" - with patch.object(derivation, "_load_element_module") as mock_load: - mock_module = MagicMock() - mock_module.generate.side_effect = Exception("LLM failed") - mock_load.return_value = mock_module + with patch.object(derivation, "_get_derivation") as mock_get: + mock_derivation = MagicMock() + mock_derivation.generate.side_effect = Exception("LLM failed") + mock_get.return_value = mock_derivation result = derivation.generate_element( graph_manager=MagicMock(), @@ -274,61 +306,6 @@ def test_contains_expected_algorithms(self): assert "k_core_filter" in derivation.ENRICHMENT_ALGORITHMS -class TestLoadAllElementModules: - """Tests for loading all element type modules.""" - - def test_loads_business_process(self): - """Should load BusinessProcess module.""" - module = derivation._load_element_module("BusinessProcess") - assert module is not None - assert hasattr(module, "generate") - - def 
test_loads_business_actor(self): - """Should load BusinessActor module.""" - module = derivation._load_element_module("BusinessActor") - assert module is not None - - def test_loads_business_event(self): - """Should load BusinessEvent module.""" - module = derivation._load_element_module("BusinessEvent") - assert module is not None - - def test_loads_business_function(self): - """Should load BusinessFunction module.""" - module = derivation._load_element_module("BusinessFunction") - assert module is not None - - def test_loads_application_service(self): - """Should load ApplicationService module.""" - module = derivation._load_element_module("ApplicationService") - assert module is not None - - def test_loads_application_interface(self): - """Should load ApplicationInterface module.""" - module = derivation._load_element_module("ApplicationInterface") - assert module is not None - - def test_loads_data_object(self): - """Should load DataObject module.""" - module = derivation._load_element_module("DataObject") - assert module is not None - - def test_loads_technology_service(self): - """Should load TechnologyService module.""" - module = derivation._load_element_module("TechnologyService") - assert module is not None - - def test_loads_device(self): - """Should load Device module.""" - module = derivation._load_element_module("Device") - assert module is not None - - def test_loads_system_software(self): - """Should load SystemSoftware module.""" - module = derivation._load_element_module("SystemSoftware") - assert module is not None - - class TestRunDerivationWithConfigs: """Tests for run_derivation with actual mock configs.""" diff --git a/tests/test_services/test_session.py b/tests/test_services/test_session.py index 6401f0d..6c47df3 100644 --- a/tests/test_services/test_session.py +++ b/tests/test_services/test_session.py @@ -594,3 +594,502 @@ def test_handles_query_error(self): logger = session._get_run_logger() assert logger is None + + +class TestPipelineSessionNeo4jControl: + """Tests for Neo4j container control methods.""" + + def test_start_neo4j(self): + """Should start Neo4j container.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection") as mock_neo4j, + ): + session = PipelineSession(auto_connect=True) + mock_neo4j.return_value.start_container.return_value = {"success": True} + + result = session.start_neo4j() + + assert result["success"] is True + + def test_stop_neo4j(self): + """Should stop Neo4j container.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection") as mock_neo4j, + ): + session = PipelineSession(auto_connect=True) + mock_neo4j.return_value.stop_container.return_value = {"success": True} + + result = session.stop_neo4j() + + assert result["success"] is True + + +class TestPipelineSessionDerivationNoLLM: + """Tests for derivation with no LLM configured.""" + + @pytest.fixture + def connected_session(self): + """Create connected session with mocked services.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + 
patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + yield session + + def test_run_derivation_returns_error_when_no_llm(self, connected_session): + """Should return error when LLM not configured.""" + with patch.object(connected_session, "_get_llm_query_fn", return_value=None): + result = connected_session.run_derivation() + + assert result["success"] is False + assert "LLM not configured" in result["errors"][0] + + def test_run_derivation_iter_yields_error_when_no_llm(self, connected_session): + """Should yield error update when LLM not configured.""" + with patch.object(connected_session, "_get_llm_query_fn", return_value=None): + updates = list(connected_session.run_derivation_iter()) + + assert len(updates) == 1 + assert updates[0].status == "error" + assert "LLM not configured" in updates[0].message + + +class TestPipelineSessionExport: + """Tests for export methods.""" + + @pytest.fixture + def connected_session(self): + """Create connected session with mocked managers.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager") as mock_archimate, + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + session._mock_archimate = mock_archimate.return_value + yield session + + def test_export_model_success(self, connected_session, tmp_path): + """Should export model successfully.""" + mock_element = MagicMock() + mock_element.identifier = "elem1" + mock_rel = MagicMock() + mock_rel.source = "elem1" + mock_rel.target = "elem1" + + connected_session._mock_archimate.get_elements.return_value = [mock_element] + connected_session._mock_archimate.get_relationships.return_value = [mock_rel] + + output_path = str(tmp_path / "model.xml") + with patch("deriva.services.session.ArchiMateXMLExporter"): + result = connected_session.export_model(output_path=output_path) + + assert result["success"] is True + assert result["elements_exported"] == 1 + + def test_export_model_no_elements(self, connected_session, tmp_path): + """Should return error when no elements found.""" + connected_session._mock_archimate.get_elements.return_value = [] + + result = connected_session.export_model(output_path=str(tmp_path / "model.xml")) + + assert result["success"] is False + assert "No ArchiMate elements" in result["error"] + + +class TestPipelineSessionConfigMethods: + """Tests for config passthrough methods.""" + + @pytest.fixture + def connected_session(self): + """Create connected session with mocked engine.""" + with ( + patch("deriva.services.session.get_connection") as mock_db, + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + patch("deriva.services.session.config") as mock_config, + ): + session = PipelineSession(auto_connect=True) + session._mock_engine = mock_db.return_value # type: ignore[attr-defined] + session._mock_config = mock_config # type: ignore[attr-defined] + yield session + + def test_list_steps_delegates_to_config(self, connected_session): + """Should delegate list_steps to config service.""" + connected_session._mock_config.list_steps.return_value = [{"name": "step1"}] + + result = connected_session.list_steps("extraction") + + assert 
len(result) == 1 + assert result[0]["name"] == "step1" + + def test_enable_step_delegates_to_config(self, connected_session): + """Should delegate enable_step to config service.""" + connected_session._mock_config.enable_step.return_value = True + + result = connected_session.enable_step("extraction", "BusinessConcept") + + assert result is True + + def test_disable_step_delegates_to_config(self, connected_session): + """Should delegate disable_step to config service.""" + connected_session._mock_config.disable_step.return_value = True + + result = connected_session.disable_step("extraction", "BusinessConcept") + + assert result is True + + def test_get_file_types_returns_list(self, connected_session): + """Should return file types as dicts.""" + + # Use a simple object with __dict__ instead of MagicMock + class MockFileType: + def __init__(self): + self.extension = ".py" + self.file_type = "source" + + connected_session._mock_config.get_file_types.return_value = [MockFileType()] + + result = connected_session.get_file_types() + + assert len(result) == 1 + + def test_get_extraction_configs_returns_list(self, connected_session): + """Should return extraction configs as dicts.""" + mock_config = MagicMock() + mock_config.node_type = "BusinessConcept" + mock_config.sequence = 1 + mock_config.enabled = True + mock_config.input_sources = None + mock_config.instruction = "Test" + mock_config.example = "{}" + connected_session._mock_config.get_extraction_configs.return_value = [mock_config] + + result = connected_session.get_extraction_configs() + + assert len(result) == 1 + assert result[0]["node_type"] == "BusinessConcept" + + def test_update_extraction_config(self, connected_session): + """Should delegate to config service.""" + connected_session._mock_config.update_extraction_config.return_value = True + + result = connected_session.update_extraction_config("BusinessConcept", enabled=True) + + assert result is True + + def test_save_extraction_config(self, connected_session): + """Should delegate to config service for versioned update.""" + connected_session._mock_config.create_extraction_config_version.return_value = { + "success": True, + "new_version": 2, + } + + result = connected_session.save_extraction_config("BusinessConcept", enabled=True) + + assert result["success"] is True + + def test_get_derivation_configs(self, connected_session): + """Should return derivation configs as dicts.""" + mock_config = MagicMock() + mock_config.element_type = "ApplicationComponent" + mock_config.sequence = 1 + mock_config.enabled = True + mock_config.input_graph_query = "MATCH (n)" + mock_config.instruction = "Test" + mock_config.example = "{}" + connected_session._mock_config.get_derivation_configs.return_value = [mock_config] + + result = connected_session.get_derivation_configs() + + assert len(result) == 1 + assert result[0]["element_type"] == "ApplicationComponent" + + def test_update_derivation_config(self, connected_session): + """Should delegate to config service.""" + connected_session._mock_config.update_derivation_config.return_value = True + + result = connected_session.update_derivation_config("ApplicationComponent", enabled=True) + + assert result is True + + def test_save_derivation_config(self, connected_session): + """Should delegate to config service for versioned update.""" + connected_session._mock_config.create_derivation_config_version.return_value = { + "success": True, + "new_version": 2, + } + + result = connected_session.save_derivation_config("ApplicationComponent", 
enabled=True) + + assert result["success"] is True + + def test_get_config_versions(self, connected_session): + """Should return config versions.""" + connected_session._mock_config.get_active_config_versions.return_value = { + "extraction": {"BusinessConcept": 1}, + "derivation": {"ApplicationComponent": 2}, + } + + result = connected_session.get_config_versions() + + assert result["extraction"]["BusinessConcept"] == 1 + assert result["derivation"]["ApplicationComponent"] == 2 + + def test_add_file_type(self, connected_session): + """Should delegate to config service.""" + connected_session._mock_config.add_file_type.return_value = True + + result = connected_session.add_file_type(".rs", "source", "rust") + + assert result is True + + def test_update_file_type(self, connected_session): + """Should delegate to config service.""" + connected_session._mock_config.update_file_type.return_value = True + + result = connected_session.update_file_type(".py", "source", "python3") + + assert result is True + + def test_delete_file_type(self, connected_session): + """Should delegate to config service.""" + connected_session._mock_config.delete_file_type.return_value = True + + result = connected_session.delete_file_type(".xyz") + + assert result is True + + +class TestPipelineSessionFileTypeStats: + """Tests for file type stats method.""" + + def test_get_file_type_stats(self): + """Should return file type statistics.""" + with ( + patch("deriva.services.session.get_connection") as mock_db, + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + mock_db.return_value.execute.return_value.fetchone.side_effect = [ + (5,), # types + (10,), # subtypes + (20,), # total + ] + + result = session.get_file_type_stats() + + assert result["types"] == 5 + assert result["subtypes"] == 10 + assert result["total"] == 20 + + +class TestPipelineSessionLLMManagement: + """Tests for LLM management methods.""" + + def test_toggle_llm_cache_enabled(self): + """Should enable LLM cache.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + session._llm_manager = MagicMock() + + result = session.toggle_llm_cache(True) + + assert result["success"] is True + assert result["cache_enabled"] is True + + def test_toggle_llm_cache_no_manager(self): + """Should return error when no LLM manager.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + # Force _llm_manager to None + with patch.object(session, "_get_llm_query_fn", return_value=None): + session._llm_manager = None + + result = session.toggle_llm_cache(True) + + assert result["success"] is False + + def test_list_benchmark_models(self): + """Should return benchmark models dict.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + 
patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + patch("deriva.services.session.benchmarking"), + ): + session = PipelineSession(auto_connect=True) + + with patch("deriva.adapters.llm.manager.load_benchmark_models", return_value={"model1": {}}): + result = session.list_benchmark_models() + + assert "model1" in result + + +class TestPipelineSessionMiscMethods: + """Tests for miscellaneous session methods.""" + + @pytest.fixture + def connected_session(self): + """Create connected session.""" + with ( + patch("deriva.services.session.get_connection") as mock_db, + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager"), + patch("deriva.services.session.Neo4jConnection"), + patch("deriva.services.session.pipeline") as mock_pipeline, + ): + session = PipelineSession(auto_connect=True) + session._mock_engine = mock_db.return_value + session._mock_pipeline = mock_pipeline + yield session + + def test_run_classification(self, connected_session): + """Should delegate to pipeline service.""" + connected_session._mock_pipeline.run_classification.return_value = {"success": True} + + result = connected_session.run_classification(repo_name="test-repo") + + assert result["success"] is True + + def test_get_database_path(self, connected_session): + """Should return database path.""" + with patch("deriva.adapters.database.DB_PATH", "test/path/db.duckdb"): + result = connected_session.get_database_path() + + assert "db" in result.lower() or "path" in result.lower() + + def test_execute_sql_with_params(self, connected_session): + """Should execute SQL with parameters.""" + connected_session._mock_engine.execute.return_value.fetchall.return_value = [(1, "test")] + + result = connected_session.execute_sql("SELECT * FROM test WHERE id = ?", [1]) + + assert len(result) == 1 + assert result[0] == (1, "test") + + def test_execute_sql_without_params(self, connected_session): + """Should execute SQL without parameters.""" + connected_session._mock_engine.execute.return_value.fetchall.return_value = [(1,), (2,)] + + result = connected_session.execute_sql("SELECT COUNT(*) FROM test") + + assert len(result) == 2 + + def test_workspace_dir_property(self, connected_session): + """Should return workspace directory.""" + result = connected_session.workspace_dir + + assert result is not None + + +class TestPipelineSessionRepositoryInfo: + """Tests for repository info method.""" + + @pytest.fixture + def connected_session(self): + """Create connected session with mocked repo manager.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager") as mock_repo, + patch("deriva.services.session.Neo4jConnection"), + ): + session = PipelineSession(auto_connect=True) + session._mock_repo = mock_repo.return_value + yield session + + def test_get_repository_info_found(self, connected_session): + """Should return repository info when found.""" + mock_info = MagicMock() + mock_info.to_dict.return_value = {"name": "test-repo", "path": "/path/to/repo"} + connected_session._mock_repo.get_repository_info.return_value = mock_info + + result = connected_session.get_repository_info("test-repo") + + assert result is not None + assert result["name"] == "test-repo" + + def test_get_repository_info_not_found(self, connected_session): + """Should return None 
when not found.""" + connected_session._mock_repo.get_repository_info.return_value = None + + result = connected_session.get_repository_info("unknown-repo") + + assert result is None + + def test_get_repository_info_error(self, connected_session): + """Should return None on error.""" + connected_session._mock_repo.get_repository_info.side_effect = Exception("Error") + + result = connected_session.get_repository_info("test-repo") + + assert result is None + + def test_delete_repository_error(self, connected_session): + """Should return error on delete failure.""" + connected_session._mock_repo.delete_repository.side_effect = Exception("Delete failed") + + result = connected_session.delete_repository("test-repo") + + assert result["success"] is False + assert "Delete failed" in result["error"] + + +class TestPipelineSessionExtractionStepCount: + """Tests for get_extraction_step_count method.""" + + def test_get_extraction_step_count(self): + """Should return extraction step count.""" + with ( + patch("deriva.services.session.get_connection"), + patch("deriva.services.session.GraphManager"), + patch("deriva.services.session.ArchimateManager"), + patch("deriva.services.session.RepoManager") as mock_repo, + patch("deriva.services.session.Neo4jConnection"), + patch("deriva.services.session.config") as mock_config, + ): + session = PipelineSession(auto_connect=True) + + # 3 configs * 2 repos = 6 steps + mock_config.get_extraction_configs.return_value = [MagicMock(), MagicMock(), MagicMock()] + mock_repo.return_value.list_repositories.return_value = [MagicMock(name="repo1"), MagicMock(name="repo2")] + + count = session.get_extraction_step_count() + + assert count == 6 diff --git a/uv.lock b/uv.lock index 7e93bf5..dbeced7 100644 --- a/uv.lock +++ b/uv.lock @@ -155,6 +155,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005 }, ] +[[package]] +name = "backoff" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148 }, +] + [[package]] name = "beartype" version = "0.22.9" @@ -179,30 +188,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.28" +version = "1.42.29" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/83/aa/a44ea8c8ee8239f3f7c32cce966512c846297df5fe48b56db6882f3b7ca0/boto3-1.42.28.tar.gz", hash = "sha256:7d56c298b8d98f5e9b04cf5d6627f68e7792e25614533aef17f815681b5e1096", size = 112846 } +sdist = { url = "https://files.pythonhosted.org/packages/5c/24/1dd85b64004103c2e60476d0fa8d78435f5fed9db1129cd2cd332784037a/boto3-1.42.29.tar.gz", hash = "sha256:247e54f24116ad6792cfc14b274288383af3ec3433b0547da8a14a8bd6e81950", size = 112810 } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/69/35/5d95169ed145f0c49ebfeb6a5228ab63d54e95a2c7a43f0f0eb893540660/boto3-1.42.28-py3-none-any.whl", hash = "sha256:7994bc2a094c1894f6a4221a1696c5d18af6c9c888191051866f1d05c4fba431", size = 140575 }, + { url = "https://files.pythonhosted.org/packages/51/30/2c25d7be8418e7f137ffece6097c68199dbd6996da645ec9b5a5a9647123/boto3-1.42.29-py3-none-any.whl", hash = "sha256:6c9c4dece67bf72d82ba7dff48e33a56a87cdf9b16c8887f88ca7789a95d3317", size = 140574 }, ] [[package]] name = "botocore" -version = "1.42.28" +version = "1.42.29" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/8d/e0828726aa568e5ab0ec477c7a47a82aa37f00951858d9ad892b6b1d5e32/botocore-1.42.28.tar.gz", hash = "sha256:0c15e78d1accf97df691083331f682e97b1bef73ef12dcdaadcf652abf9c182c", size = 14886029 } +sdist = { url = "https://files.pythonhosted.org/packages/70/08/8a8e0255949845f764c5126f97b1bc09a6484077f124c2177b979ecfbbff/botocore-1.42.29.tar.gz", hash = "sha256:0fe869227a1dfe818f691a31b8c1693e39be8056a6dff5d6d4b3fc5b3a5e7d42", size = 14890916 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/ff/72470b92ba96868be1936b8b3c7a70f902b60d36268bdeddb732317bef7a/botocore-1.42.28-py3-none-any.whl", hash = "sha256:d26c7a0851489ce1a18279f9802fe434bd736ea861d4888cc2c7d83fb1f6af8f", size = 14559264 }, + { url = "https://files.pythonhosted.org/packages/94/76/cfa6a934ee5a8a87f626b38275193a046da894d2f9021e001587fc2e8c7d/botocore-1.42.29-py3-none-any.whl", hash = "sha256:b45f8dfc1de5106a9d040c5612f267582e68b2b2c5237477dff85c707c1c5d11", size = 14563947 }, ] [[package]] @@ -459,12 +468,15 @@ name = "deriva" version = "0.6.8" source = { editable = "." 
} dependencies = [ + { name = "backoff" }, + { name = "diskcache" }, { name = "duckdb" }, { name = "lxml" }, { name = "marimo" }, { name = "neo4j" }, { name = "pydantic" }, { name = "pydantic-ai" }, + { name = "pydantic-settings" }, { name = "pypdf" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -472,11 +484,13 @@ dependencies = [ { name = "rich" }, { name = "solvor" }, { name = "sqlglot" }, + { name = "structlog" }, { name = "tree-sitter" }, { name = "tree-sitter-c-sharp" }, { name = "tree-sitter-java" }, { name = "tree-sitter-javascript" }, { name = "tree-sitter-python" }, + { name = "typer" }, ] [package.optional-dependencies] @@ -491,13 +505,16 @@ dev = [ [package.metadata] requires-dist = [ + { name = "backoff", specifier = ">=2.2.1" }, + { name = "diskcache", specifier = ">=5.6.3" }, { name = "duckdb", specifier = ">=1.4.3" }, { name = "lxml", specifier = ">=6.0.2" }, - { name = "marimo", specifier = ">=0.19.2" }, + { name = "marimo", specifier = ">=0.19.3" }, { name = "neo4j", specifier = ">=6.1.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.5.1" }, { name = "pydantic", specifier = ">=2.12.5" }, - { name = "pydantic-ai", specifier = ">=0.1.0" }, + { name = "pydantic-ai", specifier = ">=1.42.0" }, + { name = "pydantic-settings", specifier = ">=2.0" }, { name = "pypdf", specifier = ">=6.6.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=7.0.0" }, @@ -505,15 +522,17 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "rich", specifier = ">=14.2.0" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.11" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.13" }, { name = "solvor", specifier = ">=0.5.3" }, { name = "sqlglot", specifier = ">=28.6.0" }, + { name = "structlog", specifier = ">=24.0" }, { name = "tree-sitter", specifier = ">=0.24.6" }, { name = "tree-sitter-c-sharp", specifier = ">=0.23.0" }, { name = "tree-sitter-java", specifier = ">=0.23.0" }, { name = "tree-sitter-javascript", specifier = ">=0.23.0" }, { name = "tree-sitter-python", specifier = ">=0.23.0" }, { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.12" }, + { name = "typer", specifier = ">=0.21.1" }, { name = "types-lxml", marker = "extra == 'dev'", specifier = ">=2025.3.30" }, ] @@ -781,7 +800,7 @@ requests = [ [[package]] name = "google-genai" -version = "1.58.0" +version = "1.59.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -795,9 +814,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b4/49/0c2dd11c50db7ee2c299c63f1795256c96543be8b40e6f139c1a680f92e8/google_genai-1.58.0.tar.gz", hash = "sha256:bbec3abf253c17ad57b68e7f8d87d5cda34d5909c67b7ba726207a2bd10aa9fd", size = 486640 } +sdist = { url = "https://files.pythonhosted.org/packages/40/34/c03bcbc759d67ac3d96077838cdc1eac85417de6ea3b65b313fe53043eee/google_genai-1.59.0.tar.gz", hash = "sha256:0b7a2dc24582850ae57294209d8dfc2c4f5fcfde0a3f11d81dc5aca75fb619e2", size = 487374 } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/61/098a414cc41600036fe3eb24415221f5412f446e1be1ce9595bb32f2ae92/google_genai-1.58.0-py3-none-any.whl", hash = "sha256:2df3fceb92a519d51e266babde26f7c298fb12f84f469605c4ce80c2ec43f796", size = 718352 }, + { url = 
"https://files.pythonhosted.org/packages/aa/53/6d00692fe50d73409b3406ae90c71bc4499c8ae7fac377ba16e283da917c/google_genai-1.59.0-py3-none-any.whl", hash = "sha256:59fc01a225d074fe9d1e626c3433da292f33249dadce4deb34edea698305a6df", size = 719099 }, ] [[package]] @@ -1261,7 +1280,7 @@ wheels = [ [[package]] name = "marimo" -version = "0.19.2" +version = "0.19.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -1281,9 +1300,9 @@ dependencies = [ { name = "uvicorn" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5a/67/43c96d82f6ad1f4e6e53772e61c8a3ac1bd284aad6ed5a70e8895b3b3ae2/marimo-0.19.2.tar.gz", hash = "sha256:e734b9f6f49943b052be27260befdf48b0f6c627d22fa81259ca48b303df4de6", size = 39373490 } +sdist = { url = "https://files.pythonhosted.org/packages/38/8a/765f509097a7a215436f63ac2211784091730befeb2db2f02d94ab35345d/marimo-0.19.3.tar.gz", hash = "sha256:7fffd553e636064e0cbe8f35264f5a3baa0f2c741c30faf1f6be84ba71643712", size = 39397654 } wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/53/55d149a5244d8f982d493ce7263e014224435df6f4ff37d89f9a92673942/marimo-0.19.2-py3-none-any.whl", hash = "sha256:75b8083bb34fbb4c9908daac5208ad5186d8138a1692c0e937b258c7f6e6cd8a", size = 39912515 }, + { url = "https://files.pythonhosted.org/packages/e5/87/db0965dd4ab1fdef879d7712efe34f327d17e1331dec56d68db19719db69/marimo-0.19.3-py3-none-any.whl", hash = "sha256:3a947a8efa1a479db90be97a1af236348c880b552d7dd7ec5d2af83781d10820", size = 39932636 }, ] [[package]] @@ -2456,28 +2475,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d4/77/9a7fe084d268f8855d493e5031ea03fa0af8cc05887f638bf1c4e3363eb8/ruff-0.14.11.tar.gz", hash = "sha256:f6dc463bfa5c07a59b1ff2c3b9767373e541346ea105503b4c0369c520a66958", size = 5993417 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/a6/a4c40a5aaa7e331f245d2dc1ac8ece306681f52b636b40ef87c88b9f7afd/ruff-0.14.11-py3-none-linux_armv6l.whl", hash = "sha256:f6ff2d95cbd335841a7217bdfd9c1d2e44eac2c584197ab1385579d55ff8830e", size = 12951208 }, - { url = "https://files.pythonhosted.org/packages/5c/5c/360a35cb7204b328b685d3129c08aca24765ff92b5a7efedbdd6c150d555/ruff-0.14.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f6eb5c1c8033680f4172ea9c8d3706c156223010b8b97b05e82c59bdc774ee6", size = 13330075 }, - { url = "https://files.pythonhosted.org/packages/1b/9e/0cc2f1be7a7d33cae541824cf3f95b4ff40d03557b575912b5b70273c9ec/ruff-0.14.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2fc34cc896f90080fca01259f96c566f74069a04b25b6205d55379d12a6855e", size = 12257809 }, - { url = "https://files.pythonhosted.org/packages/a7/e5/5faab97c15bb75228d9f74637e775d26ac703cc2b4898564c01ab3637c02/ruff-0.14.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53386375001773ae812b43205d6064dae49ff0968774e6befe16a994fc233caa", size = 12678447 }, - { url = "https://files.pythonhosted.org/packages/1b/33/e9767f60a2bef779fb5855cab0af76c488e0ce90f7bb7b8a45c8a2ba4178/ruff-0.14.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a697737dce1ca97a0a55b5ff0434ee7205943d4874d638fe3ae66166ff46edbe", size = 12758560 }, - { url = "https://files.pythonhosted.org/packages/eb/84/4c6cf627a21462bb5102f7be2a320b084228ff26e105510cd2255ea868e5/ruff-0.14.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:6845ca1da8ab81ab1dce755a32ad13f1db72e7fba27c486d5d90d65e04d17b8f", size = 13599296 }, - { url = "https://files.pythonhosted.org/packages/88/e1/92b5ed7ea66d849f6157e695dc23d5d6d982bd6aa8d077895652c38a7cae/ruff-0.14.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e36ce2fd31b54065ec6f76cb08d60159e1b32bdf08507862e32f47e6dde8bcbf", size = 15048981 }, - { url = "https://files.pythonhosted.org/packages/61/df/c1bd30992615ac17c2fb64b8a7376ca22c04a70555b5d05b8f717163cf9f/ruff-0.14.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:590bcc0e2097ecf74e62a5c10a6b71f008ad82eb97b0a0079e85defe19fe74d9", size = 14633183 }, - { url = "https://files.pythonhosted.org/packages/04/e9/fe552902f25013dd28a5428a42347d9ad20c4b534834a325a28305747d64/ruff-0.14.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:53fe71125fc158210d57fe4da26e622c9c294022988d08d9347ec1cf782adafe", size = 14050453 }, - { url = "https://files.pythonhosted.org/packages/ae/93/f36d89fa021543187f98991609ce6e47e24f35f008dfe1af01379d248a41/ruff-0.14.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a35c9da08562f1598ded8470fcfef2afb5cf881996e6c0a502ceb61f4bc9c8a3", size = 13757889 }, - { url = "https://files.pythonhosted.org/packages/b7/9f/c7fb6ecf554f28709a6a1f2a7f74750d400979e8cd47ed29feeaa1bd4db8/ruff-0.14.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:0f3727189a52179393ecf92ec7057c2210203e6af2676f08d92140d3e1ee72c1", size = 13955832 }, - { url = "https://files.pythonhosted.org/packages/db/a0/153315310f250f76900a98278cf878c64dfb6d044e184491dd3289796734/ruff-0.14.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:eb09f849bd37147a789b85995ff734a6c4a095bed5fd1608c4f56afc3634cde2", size = 12586522 }, - { url = "https://files.pythonhosted.org/packages/2f/2b/a73a2b6e6d2df1d74bf2b78098be1572191e54bec0e59e29382d13c3adc5/ruff-0.14.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c61782543c1231bf71041461c1f28c64b961d457d0f238ac388e2ab173d7ecb7", size = 12724637 }, - { url = "https://files.pythonhosted.org/packages/f0/41/09100590320394401cd3c48fc718a8ba71c7ddb1ffd07e0ad6576b3a3df2/ruff-0.14.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:82ff352ea68fb6766140381748e1f67f83c39860b6446966cff48a315c3e2491", size = 13145837 }, - { url = "https://files.pythonhosted.org/packages/3b/d8/e035db859d1d3edf909381eb8ff3e89a672d6572e9454093538fe6f164b0/ruff-0.14.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:728e56879df4ca5b62a9dde2dd0eb0edda2a55160c0ea28c4025f18c03f86984", size = 13850469 }, - { url = "https://files.pythonhosted.org/packages/4e/02/bb3ff8b6e6d02ce9e3740f4c17dfbbfb55f34c789c139e9cd91985f356c7/ruff-0.14.11-py3-none-win32.whl", hash = "sha256:337c5dd11f16ee52ae217757d9b82a26400be7efac883e9e852646f1557ed841", size = 12851094 }, - { url = "https://files.pythonhosted.org/packages/58/f1/90ddc533918d3a2ad628bc3044cdfc094949e6d4b929220c3f0eb8a1c998/ruff-0.14.11-py3-none-win_amd64.whl", hash = "sha256:f981cea63d08456b2c070e64b79cb62f951aa1305282974d4d5216e6e0178ae6", size = 14001379 }, - { url = "https://files.pythonhosted.org/packages/c4/1c/1dbe51782c0e1e9cfce1d1004752672d2d4629ea46945d19d731ad772b3b/ruff-0.14.11-py3-none-win_arm64.whl", hash = "sha256:649fb6c9edd7f751db276ef42df1f3df41c38d67d199570ae2a7bd6cbc3590f0", size = 12938644 }, +version = "0.14.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/50/0a/1914efb7903174b381ee2ffeebb4253e729de57f114e63595114c8ca451f/ruff-0.14.13.tar.gz", hash = "sha256:83cd6c0763190784b99650a20fec7633c59f6ebe41c5cc9d45ee42749563ad47", size = 6059504 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/ae/0deefbc65ca74b0ab1fd3917f94dc3b398233346a74b8bbb0a916a1a6bf6/ruff-0.14.13-py3-none-linux_armv6l.whl", hash = "sha256:76f62c62cd37c276cb03a275b198c7c15bd1d60c989f944db08a8c1c2dbec18b", size = 13062418 }, + { url = "https://files.pythonhosted.org/packages/47/df/5916604faa530a97a3c154c62a81cb6b735c0cb05d1e26d5ad0f0c8ac48a/ruff-0.14.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:914a8023ece0528d5cc33f5a684f5f38199bbb566a04815c2c211d8f40b5d0ed", size = 13442344 }, + { url = "https://files.pythonhosted.org/packages/4c/f3/e0e694dd69163c3a1671e102aa574a50357536f18a33375050334d5cd517/ruff-0.14.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d24899478c35ebfa730597a4a775d430ad0d5631b8647a3ab368c29b7e7bd063", size = 12354720 }, + { url = "https://files.pythonhosted.org/packages/c3/e8/67f5fcbbaee25e8fc3b56cc33e9892eca7ffe09f773c8e5907757a7e3bdb/ruff-0.14.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9aaf3870f14d925bbaf18b8a2347ee0ae7d95a2e490e4d4aea6813ed15ebc80e", size = 12774493 }, + { url = "https://files.pythonhosted.org/packages/6b/ce/d2e9cb510870b52a9565d885c0d7668cc050e30fa2c8ac3fb1fda15c083d/ruff-0.14.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac5b7f63dd3b27cc811850f5ffd8fff845b00ad70e60b043aabf8d6ecc304e09", size = 12815174 }, + { url = "https://files.pythonhosted.org/packages/88/00/c38e5da58beebcf4fa32d0ddd993b63dfacefd02ab7922614231330845bf/ruff-0.14.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d2b1097750d90ba82ce4ba676e85230a0ed694178ca5e61aa9b459970b3eb9", size = 13680909 }, + { url = "https://files.pythonhosted.org/packages/61/61/cd37c9dd5bd0a3099ba79b2a5899ad417d8f3b04038810b0501a80814fd7/ruff-0.14.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7d0bf87705acbbcb8d4c24b2d77fbb73d40210a95c3903b443cd9e30824a5032", size = 15144215 }, + { url = "https://files.pythonhosted.org/packages/56/8a/85502d7edbf98c2df7b8876f316c0157359165e16cdf98507c65c8d07d3d/ruff-0.14.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3eb5da8e2c9e9f13431032fdcbe7681de9ceda5835efee3269417c13f1fed5c", size = 14706067 }, + { url = "https://files.pythonhosted.org/packages/7e/2f/de0df127feb2ee8c1e54354dc1179b4a23798f0866019528c938ba439aca/ruff-0.14.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:642442b42957093811cd8d2140dfadd19c7417030a7a68cf8d51fcdd5f217427", size = 14133916 }, + { url = "https://files.pythonhosted.org/packages/0d/77/9b99686bb9fe07a757c82f6f95e555c7a47801a9305576a9c67e0a31d280/ruff-0.14.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4acdf009f32b46f6e8864af19cbf6841eaaed8638e65c8dac845aea0d703c841", size = 13859207 }, + { url = "https://files.pythonhosted.org/packages/7d/46/2bdcb34a87a179a4d23022d818c1c236cb40e477faf0d7c9afb6813e5876/ruff-0.14.13-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:591a7f68860ea4e003917d19b5c4f5ac39ff558f162dc753a2c5de897fd5502c", size = 14043686 }, + { url = "https://files.pythonhosted.org/packages/1a/a9/5c6a4f56a0512c691cf143371bcf60505ed0f0860f24a85da8bd123b2bf1/ruff-0.14.13-py3-none-musllinux_1_2_aarch64.whl", hash = 
"sha256:774c77e841cc6e046fc3e91623ce0903d1cd07e3a36b1a9fe79b81dab3de506b", size = 12663837 }, + { url = "https://files.pythonhosted.org/packages/fe/bb/b920016ece7651fa7fcd335d9d199306665486694d4361547ccb19394c44/ruff-0.14.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:61f4e40077a1248436772bb6512db5fc4457fe4c49e7a94ea7c5088655dd21ae", size = 12805867 }, + { url = "https://files.pythonhosted.org/packages/7d/b3/0bd909851e5696cd21e32a8fc25727e5f58f1934b3596975503e6e85415c/ruff-0.14.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6d02f1428357fae9e98ac7aa94b7e966fd24151088510d32cf6f902d6c09235e", size = 13208528 }, + { url = "https://files.pythonhosted.org/packages/3b/3b/e2d94cb613f6bbd5155a75cbe072813756363eba46a3f2177a1fcd0cd670/ruff-0.14.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e399341472ce15237be0c0ae5fbceca4b04cd9bebab1a2b2c979e015455d8f0c", size = 13929242 }, + { url = "https://files.pythonhosted.org/packages/6a/c5/abd840d4132fd51a12f594934af5eba1d5d27298a6f5b5d6c3be45301caf/ruff-0.14.13-py3-none-win32.whl", hash = "sha256:ef720f529aec113968b45dfdb838ac8934e519711da53a0456038a0efecbd680", size = 12919024 }, + { url = "https://files.pythonhosted.org/packages/c2/55/6384b0b8ce731b6e2ade2b5449bf07c0e4c31e8a2e68ea65b3bafadcecc5/ruff-0.14.13-py3-none-win_amd64.whl", hash = "sha256:6070bd026e409734b9257e03e3ef18c6e1a216f0435c6751d7a8ec69cb59abef", size = 14097887 }, + { url = "https://files.pythonhosted.org/packages/4d/e1/7348090988095e4e39560cfc2f7555b1b2a7357deba19167b600fdf5215d/ruff-0.14.13-py3-none-win_arm64.whl", hash = "sha256:7ab819e14f1ad9fe39f246cfcc435880ef7a9390d81a2b6ac7e01039083dd247", size = 13080224 }, ] [[package]] @@ -2593,6 +2612,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033 }, ] +[[package]] +name = "structlog" +version = "25.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/52/9ba0f43b686e7f3ddfeaa78ac3af750292662284b3661e91ad5494f21dbc/structlog-25.5.0.tar.gz", hash = "sha256:098522a3bebed9153d4570c6d0288abf80a031dfdb2048d59a49e9dc2190fc98", size = 1460830 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/45/a132b9074aa18e799b891b91ad72133c98d8042c70f6240e4c5f9dabee2f/structlog-25.5.0-py3-none-any.whl", hash = "sha256:a8453e9b9e636ec59bd9e79bbd4a72f025981b3ba0f5837aebf48f02f37a7f9f", size = 72510 }, +] + [[package]] name = "temporalio" version = "1.20.0"