diff --git a/argus/agents/agent_models.py b/argus/agents/agent_models.py index 7b04b1c..6b5f5d8 100644 --- a/argus/agents/agent_models.py +++ b/argus/agents/agent_models.py @@ -141,7 +141,7 @@ class ActionType(str, Enum): ) # Import from validation_models -# Validation error models; Validation schemas; Validation utilities; +# Validation error models; Validation schemas; Validation utilities; # Custom validators; Validation decorators from .validation_models import ( CodeAnalysisValidationSchema, diff --git a/argus/agents/enhanced_adapter.py b/argus/agents/enhanced_adapter.py index 8a7db5d..cab3d98 100644 --- a/argus/agents/enhanced_adapter.py +++ b/argus/agents/enhanced_adapter.py @@ -15,7 +15,7 @@ """Enhanced Agent Adapters and Migration Helpers.""" import logging -from typing import Any, cast +from typing import Any from ..llm.config import LLMConfig from .enhanced_specialized import ( diff --git a/argus/agents/legacy_adapter.py b/argus/agents/legacy_adapter.py index cf15c07..8fcb051 100644 --- a/argus/agents/legacy_adapter.py +++ b/argus/agents/legacy_adapter.py @@ -24,13 +24,13 @@ import logging from typing import Any -from ..llm.config import LLMConfig -from ..triage_agent import TriageAgent from ..analysis_agent import AnalysisAgent +from ..llm.config import LLMConfig from ..remediation_agent import RemediationAgent -from .enhanced_triage_agent import EnhancedTriageAgent +from ..triage_agent import TriageAgent from .enhanced_analysis_agent import EnhancedAnalysisAgent from .enhanced_remediation_agent import EnhancedRemediationAgent +from .enhanced_triage_agent import EnhancedTriageAgent logger = logging.getLogger(__name__) diff --git a/argus/agents/response_models.py b/argus/agents/response_models.py index df9b1f6..48f402b 100644 --- a/argus/agents/response_models.py +++ b/argus/agents/response_models.py @@ -29,7 +29,6 @@ from uuid import uuid4 from pydantic import BaseModel, Field, field_validator -from argus.core.types.agent import AnalysisResponse, RemediationResponse # ============================================================================ # Common Enums and Base Models diff --git a/argus/config/ingestion_config.py b/argus/config/ingestion_config.py index 3c10274..56b7a4e 100644 --- a/argus/config/ingestion_config.py +++ b/argus/config/ingestion_config.py @@ -21,12 +21,12 @@ pluggable log ingestion architecture. """ -import json -import logging from dataclasses import dataclass, field from enum import Enum +import json +import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import yaml @@ -75,15 +75,15 @@ class SourceConfig: retry_delay: float = 1.0 timeout: float = 30.0 circuit_breaker_enabled: bool = True - rate_limit_per_second: Optional[int] = None - config: Dict[str, Any] = field(default_factory=dict) + rate_limit_per_second: int | None = None + config: dict[str, Any] = field(default_factory=dict) @dataclass class GCPPubSubConfig(SourceConfig): """Configuration for GCP Pub/Sub source.""" - credentials_path: Optional[str] = None + credentials_path: str | None = None max_messages: int = 100 ack_deadline_seconds: int = 60 flow_control_max_messages: int = 1000 @@ -109,7 +109,7 @@ class GCPLoggingConfig(SourceConfig): """Configuration for GCP Logging source.""" log_filter: str = "severity>=ERROR" - credentials_path: Optional[str] = None + credentials_path: str | None = None poll_interval: int = 30 max_results: int = 1000 project_id: str = "" @@ -153,9 +153,9 @@ def __post_init__(self) -> None: class AWSCloudWatchConfig(SourceConfig): """Configuration for AWS CloudWatch source.""" - log_stream_name: Optional[str] = None + log_stream_name: str | None = None region: str = "us-east-1" - credentials_profile: Optional[str] = None + credentials_profile: str | None = None poll_interval: int = 30 max_events: int = 1000 log_group_name: str = "" @@ -176,10 +176,10 @@ def __post_init__(self) -> None: class KubernetesConfig(SourceConfig): """Configuration for Kubernetes source.""" - namespace: Optional[str] = None - label_selector: Optional[str] = None - container_name: Optional[str] = None - kubeconfig_path: Optional[str] = None + namespace: str | None = None + label_selector: str | None = None + container_name: str | None = None + kubeconfig_path: str | None = None poll_interval: int = 30 max_logs: int = 1000 max_pods: int = 100 @@ -243,11 +243,11 @@ class GlobalConfig: class IngestionConfig: """Complete configuration for the log ingestion system.""" - sources: List[SourceConfig] = field(default_factory=list) + sources: list[SourceConfig] = field(default_factory=list) global_config: GlobalConfig = field(default_factory=GlobalConfig) schema_version: str = "1.0.0" - def validate(self) -> List[str]: + def validate(self) -> list[str]: """Validate the configuration and return any errors.""" errors = [] @@ -303,14 +303,14 @@ def validate(self) -> List[str]: return errors - def get_source_by_name(self, name: str) -> Optional[SourceConfig]: + def get_source_by_name(self, name: str) -> SourceConfig | None: """Get a source configuration by name.""" for source in self.sources: if source.name == name: return source return None - def get_enabled_sources(self) -> List[SourceConfig]: + def get_enabled_sources(self) -> list[SourceConfig]: """Get all enabled sources sorted by priority.""" enabled = [source for source in self.sources if source.enabled] return sorted(enabled, key=lambda x: x.priority) @@ -319,13 +319,13 @@ def get_enabled_sources(self) -> List[SourceConfig]: class IngestionConfigManager: """Manager for ingestion configuration loading and validation.""" - def __init__(self, config_path: Optional[Union[str, Path]] = None) -> None: + def __init__(self, config_path: str | Path | None = None) -> None: """Initialize the config manager.""" self.config_path = Path(config_path) if config_path else None - self._config: Optional[IngestionConfig] = None + self._config: IngestionConfig | None = None def load_config( - self, config_path: Optional[Union[str, Path]] = None + self, config_path: str | Path | None = None ) -> IngestionConfig: """Load configuration from file.""" if config_path: @@ -335,7 +335,7 @@ def load_config( raise ConfigError(f"Configuration file not found: {self.config_path}") try: - with open(self.config_path, "r") as f: + with open(self.config_path) as f: if self.config_path.suffix.lower() in [".yaml", ".yml"]: data = yaml.safe_load(f) elif self.config_path.suffix.lower() == ".json": @@ -351,7 +351,7 @@ def load_config( except Exception as e: raise ConfigError(f"Failed to load configuration: {e}") from e - def _parse_config(self, data: Dict[str, Any]) -> IngestionConfig: + def _parse_config(self, data: dict[str, Any]) -> IngestionConfig: """Parse configuration data into IngestionConfig object.""" # Parse global config global_data = data.get("global_config", {}) @@ -513,14 +513,14 @@ def _parse_config(self, data: Dict[str, Any]) -> IngestionConfig: schema_version=data.get("schema_version", "1.0.0"), ) - def validate_config(self) -> List[str]: + def validate_config(self) -> list[str]: """Validate the current configuration.""" if not self._config: return ["No configuration loaded"] return self._config.validate() def save_config( - self, config: IngestionConfig, output_path: Union[str, Path] + self, config: IngestionConfig, output_path: str | Path ) -> None: """Save configuration to file.""" output_path = Path(output_path) @@ -571,6 +571,6 @@ def save_config( else: raise ConfigError(f"Unsupported output format: {output_path.suffix}") - def get_config(self) -> Optional[IngestionConfig]: + def get_config(self) -> IngestionConfig | None: """Get the current configuration.""" return self._config diff --git a/argus/core/exceptions/__init__.py b/argus/core/exceptions/__init__.py index cdaa496..d3bf544 100644 --- a/argus/core/exceptions/__init__.py +++ b/argus/core/exceptions/__init__.py @@ -35,8 +35,8 @@ from .agent import AgentError as AgentSpecificError from .base import ( AgentError, - ConfigurationError, ArgusAgentError, + ConfigurationError, LLMError, MonitoringError, ProcessingError, diff --git a/argus/core/interfaces/base.py b/argus/core/interfaces/base.py index 0301251..a1db5dd 100644 --- a/argus/core/interfaces/base.py +++ b/argus/core/interfaces/base.py @@ -21,9 +21,9 @@ for all major components in the system. """ -import time from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Optional, TypeVar +import time +from typing import Any, Generic, TypeVar from ..types import ( ConfigDict, @@ -97,7 +97,7 @@ def shutdown(self) -> None: pass @abstractmethod - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get the component's health status. @@ -130,7 +130,7 @@ class ConfigurableComponent(BaseComponent): """ def __init__( - self, component_id: str, name: str, config: Optional[ConfigDict] = None + self, component_id: str, name: str, config: ConfigDict | None = None ) -> None: """ Initialize the configurable component. @@ -195,7 +195,7 @@ class StatefulComponent(ConfigurableComponent): """ def __init__( - self, component_id: str, name: str, config: Optional[ConfigDict] = None + self, component_id: str, name: str, config: ConfigDict | None = None ) -> None: """ Initialize the stateful component. @@ -210,7 +210,7 @@ def __init__( self._state_history = [] @property - def state(self) -> Dict[str, Any]: + def state(self) -> dict[str, Any]: """Get the current state.""" return self._state.copy() @@ -226,7 +226,7 @@ def set_state(self, key: str, value: Any) -> None: pass @abstractmethod - def get_state(self, key: str, default: Optional[Any] = None) -> Any: + def get_state(self, key: str, default: Any | None = None) -> Any: """ Get a state value. @@ -240,7 +240,7 @@ def get_state(self, key: str, default: Optional[Any] = None) -> Any: pass @abstractmethod - def clear_state(self, key: Optional[str] = None) -> None: + def clear_state(self, key: str | None = None) -> None: """ Clear state values. @@ -249,7 +249,7 @@ def clear_state(self, key: Optional[str] = None) -> None: """ pass - def get_state_snapshot(self) -> Dict[str, Any]: + def get_state_snapshot(self) -> dict[str, Any]: """ Get a complete state snapshot. @@ -274,7 +274,7 @@ class ProcessableComponent(StatefulComponent, Generic[T, R]): """ def __init__( - self, component_id: str, name: str, config: Optional[ConfigDict] = None + self, component_id: str, name: str, config: ConfigDict | None = None ) -> None: """ Initialize the processable component. @@ -286,7 +286,7 @@ def __init__( """ super().__init__(component_id, name, config) self._processing_count = 0 - self._last_processed_at: Optional[Timestamp] = None + self._last_processed_at: Timestamp | None = None @property def processing_count(self) -> int: @@ -294,7 +294,7 @@ def processing_count(self) -> int: return self._processing_count @property - def last_processed_at(self) -> Optional[Timestamp]: + def last_processed_at(self) -> Timestamp | None: """Get the timestamp of the last processed item.""" return self._last_processed_at @@ -343,7 +343,7 @@ class MonitorableComponent(ProcessableComponent[T, R]): """ def __init__( - self, component_id: str, name: str, config: Optional[ConfigDict] = None + self, component_id: str, name: str, config: ConfigDict | None = None ) -> None: """ Initialize the monitorable component. @@ -358,17 +358,17 @@ def __init__( self._alerts = [] @property - def metrics(self) -> Dict[str, Any]: + def metrics(self) -> dict[str, Any]: """Get current metrics.""" return self._metrics.copy() @property - def alerts(self) -> List[Dict[str, Any]]: + def alerts(self) -> list[dict[str, Any]]: """Get current alerts.""" return self._alerts.copy() @abstractmethod - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """ Collect component metrics. @@ -415,7 +415,7 @@ def clear_alerts(self) -> None: """Clear all alerts.""" self._alerts.clear() - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get comprehensive health status. diff --git a/argus/core/interfaces/protocols.py b/argus/core/interfaces/protocols.py index 2747747..9eac221 100644 --- a/argus/core/interfaces/protocols.py +++ b/argus/core/interfaces/protocols.py @@ -1,4 +1,5 @@ -from typing import Optional, Any, Dict, List, Union, Callable, Awaitable, TypeVar, Tuple +from typing import Any, TypeVar + # Copyright 2026 Divyansh Rawat # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,7 +24,8 @@ rather than inheritance hierarchy. """ -from typing import Any, AsyncIterator, Dict, List, Protocol, TypeVar +from collections.abc import AsyncIterator +from typing import Protocol from ..types import ( AgentContext, @@ -114,7 +116,7 @@ class Stateful(Protocol): """Protocol for objects that maintain state.""" @property - def state(self) -> Dict[str, Any]: + def state(self) -> dict[str, Any]: """Get current state.""" ... @@ -122,7 +124,7 @@ def set_state(self, key: str, value: Any) -> None: """Set new state value.""" ... - def get_state(self, key: str, default: Optional[Any] = None) -> Any: + def get_state(self, key: str, default: Any | None = None) -> Any: """Get state value by key.""" ... @@ -142,7 +144,7 @@ def validate(self) -> bool: """Validate the object.""" ... - def get_validation_errors(self) -> List[str]: + def get_validation_errors(self) -> list[str]: """Get validation errors if any.""" ... @@ -154,7 +156,7 @@ def is_healthy(self) -> bool: """Check if the object is healthy.""" ... - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """Get detailed health status.""" ... @@ -162,11 +164,11 @@ def get_health_status(self) -> Dict[str, Any]: class MetricsCollector(Protocol): """Protocol for objects that collect metrics.""" - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """Collect current metrics.""" ... - def get_metrics_summary(self) -> Dict[str, Any]: + def get_metrics_summary(self) -> dict[str, Any]: """Get metrics summary.""" ... @@ -174,7 +176,7 @@ def get_metrics_summary(self) -> Dict[str, Any]: class Alertable(Protocol): """Protocol for objects that can generate alerts.""" - def check_alerts(self) -> List[Dict[str, Any]]: + def check_alerts(self) -> list[dict[str, Any]]: """Check for alerts.""" ... @@ -240,7 +242,7 @@ def is_rate_limited(self) -> bool: """Check if currently rate limited.""" ... - def get_rate_limit_info(self) -> Dict[str, Any]: + def get_rate_limit_info(self) -> dict[str, Any]: """Get rate limit information.""" ... @@ -300,7 +302,7 @@ def provider_name(self) -> str: """Get provider name.""" ... - def get_models(self) -> List[Any]: + def get_models(self) -> list[Any]: """Get available models.""" ... @@ -364,7 +366,7 @@ def step_name(self) -> str: """Get step name.""" ... - def execute(self, context: Dict[str, Any]) -> Dict[str, Any]: + def execute(self, context: dict[str, Any]) -> dict[str, Any]: """Execute the workflow step.""" ... @@ -376,7 +378,7 @@ def add_step(self, step: WorkflowStep) -> None: """Add a workflow step.""" ... - def execute_workflow(self, initial_context: Dict[str, Any]) -> Dict[str, Any]: + def execute_workflow(self, initial_context: dict[str, Any]) -> dict[str, Any]: """Execute the complete workflow.""" ... @@ -405,7 +407,7 @@ class ResourceManager(Protocol): """Protocol for resource managers.""" def allocate_resource( - self, resource_type: str, requirements: Dict[str, Any] + self, resource_type: str, requirements: dict[str, Any] ) -> str: """Allocate a resource.""" ... @@ -414,7 +416,7 @@ def deallocate_resource(self, resource_id: str) -> None: """Deallocate a resource.""" ... - def get_resource_status(self, resource_id: str) -> Dict[str, Any]: + def get_resource_status(self, resource_id: str) -> dict[str, Any]: """Get resource status.""" ... @@ -422,7 +424,7 @@ def get_resource_status(self, resource_id: str) -> Dict[str, Any]: class LoadBalancer(Protocol): """Protocol for load balancers.""" - def select_target(self, targets: List[Any], context: Dict[str, Any]) -> Any: + def select_target(self, targets: list[Any], context: dict[str, Any]) -> Any: """Select a target for load balancing.""" ... @@ -466,7 +468,7 @@ def should_fallback(self, error: Exception) -> bool: class BatchProcessor(Protocol[T, R]): """Protocol for batch processors.""" - def process_batch(self, items: List[T]) -> List[R]: + def process_batch(self, items: list[T]) -> list[R]: """Process a batch of items.""" ... @@ -478,7 +480,7 @@ def get_batch_size(self) -> int: class AsyncBatchProcessor(Protocol[T, R]): """Protocol for async batch processors.""" - async def process_batch_async(self, items: List[T]) -> List[R]: + async def process_batch_async(self, items: list[T]) -> list[R]: """Process a batch of items asynchronously.""" ... @@ -510,7 +512,7 @@ def can_transform(self, input_type: type) -> bool: class Filter(Protocol[T]): """Protocol for data filters.""" - def filter(self, items: List[T]) -> List[T]: + def filter(self, items: list[T]) -> list[T]: """Filter items.""" ... @@ -522,7 +524,7 @@ def should_include(self, item: T) -> bool: class Aggregator(Protocol[T, R]): """Protocol for data aggregators.""" - def aggregate(self, items: List[T]) -> R: + def aggregate(self, items: list[T]) -> R: """Aggregate items.""" ... @@ -599,7 +601,7 @@ def implements_protocol(obj: Any, protocol: type) -> bool: return False -def get_protocol_methods(protocol: type) -> List[str]: +def get_protocol_methods(protocol: type) -> list[str]: """ Get method names from a protocol. @@ -618,7 +620,7 @@ def get_protocol_methods(protocol: type) -> List[str]: return methods -def validate_protocol_implementation(obj: Any, protocol: type) -> List[str]: +def validate_protocol_implementation(obj: Any, protocol: type) -> list[str]: """ Validate that an object implements a protocol. diff --git a/argus/core/logging/manager.py b/argus/core/logging/manager.py index e6d1194..0ce2c3e 100644 --- a/argus/core/logging/manager.py +++ b/argus/core/logging/manager.py @@ -84,8 +84,8 @@ def _create_handler(self, handler_config) -> logging.Handler | None: if handler_config.destination == OutputDestination.CONSOLE: return create_console_handler( formatter_type=( - handler_config.format.value - if hasattr(handler_config.format, "value") + handler_config.format.value + if hasattr(handler_config.format, "value") else "structured" ), colorize=handler_config.colorize @@ -95,8 +95,8 @@ def _create_handler(self, handler_config) -> logging.Handler | None: return create_file_handler( filename=handler_config.file_path or "app.log", formatter_type=( - handler_config.format.value - if hasattr(handler_config.format, "value") + handler_config.format.value + if hasattr(handler_config.format, "value") else "json" ), max_bytes=handler_config.max_file_size_mb * 1024 * 1024, @@ -109,8 +109,8 @@ def _create_handler(self, handler_config) -> logging.Handler | None: facility=handler_config.syslog_facility, address=handler_config.syslog_address, formatter_type=( - handler_config.format.value - if hasattr(handler_config.format, "value") + handler_config.format.value + if hasattr(handler_config.format, "value") else "json" ) ) @@ -122,8 +122,8 @@ def _create_handler(self, handler_config) -> logging.Handler | None: headers=handler_config.remote_headers, timeout=handler_config.remote_timeout, formatter_type=( - handler_config.format.value - if hasattr(handler_config.format, "value") + handler_config.format.value + if hasattr(handler_config.format, "value") else "json" ) ) diff --git a/argus/core/logging/metrics.py b/argus/core/logging/metrics.py index c293617..71b6f76 100644 --- a/argus/core/logging/metrics.py +++ b/argus/core/logging/metrics.py @@ -156,8 +156,8 @@ def to_dict(self) -> dict[str, Any]: "average_processing_time": self.average_processing_time, "max_processing_time": self.max_processing_time, "min_processing_time": ( - self.min_processing_time - if self.min_processing_time != float("inf") + self.min_processing_time + if self.min_processing_time != float("inf") else 0.0 ), "formatting_errors": self.formatting_errors, diff --git a/argus/core/performance/profiler.py b/argus/core/performance/profiler.py index 6e0129c..29610e5 100644 --- a/argus/core/performance/profiler.py +++ b/argus/core/performance/profiler.py @@ -376,7 +376,7 @@ def profile_operation( start_time = time.time() memory_before = ( - self._async_profiler._get_memory_usage() + self._async_profiler._get_memory_usage() if self._config.enable_memory_profiling else 0 ) @@ -387,7 +387,7 @@ def profile_operation( end_time = time.time() duration = end_time - start_time memory_after = ( - self._async_profiler._get_memory_usage() + self._async_profiler._get_memory_usage() if self._config.enable_memory_profiling else 0 ) memory_delta = memory_after - memory_before diff --git a/argus/core/quality/cli.py b/argus/core/quality/cli.py index 9dd95be..6e8f3ea 100644 --- a/argus/core/quality/cli.py +++ b/argus/core/quality/cli.py @@ -58,8 +58,8 @@ def cli(ctx, verbose: bool, config: str | None): @click.option("--coverage/--no-coverage", default=True, help="Enable/disable coverage checks") @click.option("--security/--no-security", default=True, help="Enable/disable security checks") @click.option( - "--performance/--no-performance", - default=True, + "--performance/--no-performance", + default=True, help="Enable/disable performance checks" ) @click.option("--docs/--no-docs", default=True, help="Enable/disable documentation checks") @@ -71,10 +71,10 @@ def cli(ctx, verbose: bool, config: str | None): @click.option("--timeout", default=300, help="Timeout in seconds") @click.option("--output", "-o", type=click.Path(), help="Output file for report") @click.option( - "--format", - "output_format", + "--format", + "output_format", type=click.Choice(["json", "html", "markdown", "console"]), - default="console", + default="console", help="Output format" ) @click.option("--gates", help="Comma-separated list of gates to run (default: all)") @@ -165,7 +165,7 @@ async def _run_quality_gates( if gates: gate_names = [name.strip() for name in gates.split(",")] manager.gates = { - name: manager.gates[name] + name: manager.gates[name] for name in gate_names if name in manager.gates } @@ -268,10 +268,10 @@ def init_config(output: str | None): @cli.command() @click.option( - "--format", - "output_format", + "--format", + "output_format", type=click.Choice(["json", "html", "markdown", "console"]), - default="console", + default="console", help="Output format" ) def list_gates(output_format: str): diff --git a/argus/core/quality/exceptions.py b/argus/core/quality/exceptions.py index aa66d06..ab7751c 100644 --- a/argus/core/quality/exceptions.py +++ b/argus/core/quality/exceptions.py @@ -85,10 +85,10 @@ class ToolExecutionError(QualityGateError): """Exception raised when a quality tool execution fails.""" def __init__( - self, - tool_name: str, - message: str, - exit_code: int | None = None, + self, + tool_name: str, + message: str, + exit_code: int | None = None, details: dict | None = None ): """Initialize the tool execution error. diff --git a/argus/core/quality/reports.py b/argus/core/quality/reports.py index 829006a..c400ff2 100644 --- a/argus/core/quality/reports.py +++ b/argus/core/quality/reports.py @@ -76,8 +76,8 @@ def __init__(self): self.logger = get_logger(__name__) def generate_report( - self, - results: list[QualityGateResult], + self, + results: list[QualityGateResult], duration: float = 0.0 ) -> QualityReport: """Generate a comprehensive quality report. @@ -398,9 +398,9 @@ def _format_console(self, report: QualityReport) -> str: return "\n".join(output) def save_report( - self, - report: QualityReport, - file_path: Path, + self, + report: QualityReport, + file_path: Path, format_type: ReportFormat ) -> None: """Save a quality report to a file. diff --git a/argus/core/quality/validators.py b/argus/core/quality/validators.py index 9e210e0..a50430d 100644 --- a/argus/core/quality/validators.py +++ b/argus/core/quality/validators.py @@ -103,11 +103,11 @@ async def run_pyright(self, config: QualityGateConfig) -> ValidationResult: "errors": errors }, errors=[ - e.get("message", "") + e.get("message", "") for e in errors if e.get("severity") == "error" ], warnings=[ - e.get("message", "") + e.get("message", "") for e in errors if e.get("severity") == "warning" ] ) diff --git a/argus/core/resilience/bulkhead_isolator.py b/argus/core/resilience/bulkhead_isolator.py index e60c9bb..7353d89 100644 --- a/argus/core/resilience/bulkhead_isolator.py +++ b/argus/core/resilience/bulkhead_isolator.py @@ -170,7 +170,7 @@ def get_stats(self) -> dict[str, Any]: failed_operations = total_operations - successful_operations success_rate = ( - (successful_operations / total_operations * 100) + (successful_operations / total_operations * 100) if total_operations > 0 else 0.0 ) @@ -181,7 +181,7 @@ def get_stats(self) -> dict[str, Any]: # Calculate current utilization current_usage = len(self._active_operations) utilization_rate = ( - (current_usage / self._config.max_concurrency * 100) + (current_usage / self._config.max_concurrency * 100) if self._config.max_concurrency > 0 else 0.0 ) diff --git a/argus/core/resilience/health_checker.py b/argus/core/resilience/health_checker.py index 8d6f2fb..ff05e34 100644 --- a/argus/core/resilience/health_checker.py +++ b/argus/core/resilience/health_checker.py @@ -363,7 +363,7 @@ def _get_overall_status(self) -> HealthStatus: return HealthStatus.UNHEALTHY unhealthy_count = sum( - 1 for status in self._health_status.values() + 1 for status in self._health_status.values() if status == HealthStatus.UNHEALTHY ) if unhealthy_count > 0: diff --git a/argus/core/resilience/rate_limiter.py b/argus/core/resilience/rate_limiter.py index 4533037..b98adae 100644 --- a/argus/core/resilience/rate_limiter.py +++ b/argus/core/resilience/rate_limiter.py @@ -190,13 +190,13 @@ def get_stats(self) -> dict[str, Any]: failed_requests = sum(1 for req in self._request_history if req["success"] is False) success_rate = ( - (successful_requests / total_requests * 100) + (successful_requests / total_requests * 100) if total_requests > 0 else 0.0 ) # Calculate requests per second requests_per_second = ( - current_requests / self._config.window_seconds + current_requests / self._config.window_seconds if self._config.window_seconds > 0 else 0.0 ) @@ -211,7 +211,7 @@ def get_stats(self) -> dict[str, Any]: "failed_requests": failed_requests, "success_rate": success_rate, "utilization_rate": ( - (current_requests / self._config.limit * 100) + (current_requests / self._config.limit * 100) if self._config.limit > 0 else 0.0 ), "config": { diff --git a/argus/core/resilience/retry_handler.py b/argus/core/resilience/retry_handler.py index 8208840..81bf482 100644 --- a/argus/core/resilience/retry_handler.py +++ b/argus/core/resilience/retry_handler.py @@ -192,8 +192,8 @@ def _record_attempt(self, success: bool, error: str | None) -> None: "success": success, "error": error, "delay": ( - self._calculate_delay() - if not success and self._attempt_count < self._config.max_attempts + self._calculate_delay() + if not success and self._attempt_count < self._config.max_attempts else 0.0 ) } @@ -216,7 +216,7 @@ def get_stats(self) -> dict[str, Any]: failed_attempts = total_attempts - successful_attempts success_rate = ( - (successful_attempts / total_attempts * 100) + (successful_attempts / total_attempts * 100) if total_attempts > 0 else 0.0 ) diff --git a/argus/core/resilience/timeout_manager.py b/argus/core/resilience/timeout_manager.py index 1e59b32..ea1da40 100644 --- a/argus/core/resilience/timeout_manager.py +++ b/argus/core/resilience/timeout_manager.py @@ -221,7 +221,7 @@ def get_stats(self) -> dict[str, Any]: failed_operations = total_operations - successful_operations success_rate = ( - (successful_operations / total_operations * 100) + (successful_operations / total_operations * 100) if total_operations > 0 else 0.0 ) @@ -231,11 +231,11 @@ def get_stats(self) -> dict[str, Any]: # Calculate timeout rate timeout_operations = sum( - 1 for op in self._timeout_history + 1 for op in self._timeout_history if not op["success"] and "timeout" in (op["error"] or "").lower() ) timeout_rate = ( - (timeout_operations / total_operations * 100) + (timeout_operations / total_operations * 100) if total_operations > 0 else 0.0 ) diff --git a/argus/ingestion/interfaces/resilience.py b/argus/ingestion/interfaces/resilience.py index 8871a12..216be16 100644 --- a/argus/ingestion/interfaces/resilience.py +++ b/argus/ingestion/interfaces/resilience.py @@ -277,12 +277,11 @@ async def execute(self, operation: Callable[[], Awaitable[T]]) -> T: try: # Apply rate limiting and bulkhead - async with self.rate_limiter: - async with self.bulkhead: - # Simple timeout implementation - result = await asyncio.wait_for( - operation(), timeout=self.timeout.timeout - ) + async with self.rate_limiter, self.bulkhead: + # Simple timeout implementation + result = await asyncio.wait_for( + operation(), timeout=self.timeout.timeout + ) self._stats["successful_operations"] += 1 return result diff --git a/argus/ingestion/manager/log_manager.py b/argus/ingestion/manager/log_manager.py index 4df257b..0abe123 100644 --- a/argus/ingestion/manager/log_manager.py +++ b/argus/ingestion/manager/log_manager.py @@ -180,7 +180,7 @@ async def _process_source_logs(self, source_name: str) -> None: # For file system sources, wait a bit before checking again try: source_config = source.get_config() - if (hasattr(source_config, "source_type") and + if (hasattr(source_config, "source_type") and source_config.source_type.value == "file_system"): await asyncio.sleep(1) except Exception: diff --git a/argus/llm/base.py b/argus/llm/base.py index 3d9007e..ea637cb 100644 --- a/argus/llm/base.py +++ b/argus/llm/base.py @@ -277,4 +277,3 @@ def provider_name(self) -> str: """Get the provider name.""" return self.provider_type -from .common.enums import ProviderType diff --git a/argus/llm/capabilities/discovery.py b/argus/llm/capabilities/discovery.py index 25bae64..d652937 100644 --- a/argus/llm/capabilities/discovery.py +++ b/argus/llm/capabilities/discovery.py @@ -16,8 +16,7 @@ import logging import time -from functools import lru_cache -from typing import Any, Dict, List, Optional, Set +from typing import Any from argus.llm.base import LLMProvider from argus.llm.capabilities.config import get_capability_config @@ -31,7 +30,7 @@ class CapabilityDiscovery: Discovers and catalogs capabilities of LLM models across various providers. """ - def __init__(self, providers: Dict[str, LLMProvider], cache_ttl: int = 3600) -> None: + def __init__(self, providers: dict[str, LLMProvider], cache_ttl: int = 3600) -> None: """ Initialize the CapabilityDiscovery. @@ -40,14 +39,14 @@ def __init__(self, providers: Dict[str, LLMProvider], cache_ttl: int = 3600) -> cache_ttl: Cache time-to-live in seconds (default: 1 hour). """ self.providers = providers - self.model_capabilities: Dict[str, ModelCapabilities] = {} + self.model_capabilities: dict[str, ModelCapabilities] = {} self.cache_ttl = cache_ttl - self._cache_timestamps: Dict[str, float] = {} - self._discovery_lock: Set[str] = set() # Prevent concurrent discovery of same model - + self._cache_timestamps: dict[str, float] = {} + self._discovery_lock: set[str] = set() # Prevent concurrent discovery of same model + # Load capability configuration self.capability_config = get_capability_config() - + # Metrics tracking self._metrics = { "discovery_attempts": 0, @@ -77,7 +76,7 @@ def clear_cache(self) -> None: async def discover_capabilities( self, force_refresh: bool = False - ) -> Dict[str, ModelCapabilities]: + ) -> dict[str, ModelCapabilities]: """ Discover capabilities for all configured models across all providers. @@ -89,7 +88,7 @@ async def discover_capabilities( """ start_time = time.time() self._metrics["discovery_attempts"] += 1 - + try: for provider_name, provider_instance in self.providers.items(): logger.info(f"Discovering capabilities for provider: {provider_name}") @@ -98,7 +97,7 @@ async def discover_capabilities( available_models = provider_instance.get_available_models() # Handle both dict and list return types - model_items: List[tuple] = [] + model_items: list[tuple] = [] if isinstance(available_models, dict): model_items = list(available_models.items()) elif isinstance(available_models, list): @@ -113,7 +112,7 @@ async def discover_capabilities( logger.debug(f"Using cached capabilities for {model_id}") self._metrics["cache_hits"] += 1 continue - + self._metrics["cache_misses"] += 1 # Prevent concurrent discovery of the same model @@ -124,12 +123,12 @@ async def discover_capabilities( self._discovery_lock.add(model_id) try: capabilities = [] - + # Safely check provider capabilities try: # Check if the provider has these methods using getattr supports_streaming_method = getattr( - provider_instance, 'supports_streaming', None + provider_instance, "supports_streaming", None ) if supports_streaming_method: supports_streaming = await self._safe_check_capability( @@ -137,9 +136,9 @@ async def discover_capabilities( ) else: supports_streaming = False - + supports_tools_method = getattr( - provider_instance, 'supports_tools', None + provider_instance, "supports_tools", None ) if supports_tools_method: supports_tools = await self._safe_check_capability( @@ -174,7 +173,7 @@ async def discover_capabilities( provider_capabilities = self.capability_config.get_provider_capabilities( provider_name_lower, model_name ) - + # Add configured capabilities for cap_name in provider_capabilities: cap_def = self.capability_config.get_capability(cap_name) @@ -192,13 +191,13 @@ async def discover_capabilities( ), ) capabilities.append(capability) - + # Add dynamic capabilities based on provider features if supports_streaming: streaming_cap = self.capability_config.get_capability("streaming") if streaming_cap: capabilities.append(streaming_cap) - + if supports_tools: tool_cap = self.capability_config.get_capability("tool_calling") if tool_cap: @@ -217,24 +216,24 @@ async def discover_capabilities( discovery_time = end_time - start_time self._metrics["last_discovery_time"] = end_time self._metrics["discovery_successes"] += 1 - + # Update average discovery time if self._metrics["discovery_attempts"] == 1: self._metrics["average_discovery_time"] = discovery_time else: # Running average self._metrics["average_discovery_time"] = ( - (self._metrics["average_discovery_time"] * + (self._metrics["average_discovery_time"] * (self._metrics["discovery_attempts"] - 1) + discovery_time) / self._metrics["discovery_attempts"] ) - + logger.info( f"Discovered capabilities for {len(self.model_capabilities)} models " f"in {discovery_time:.2f}s." ) return self.model_capabilities - + except Exception as e: self._metrics["discovery_failures"] += 1 logger.error(f"Capability discovery failed: {e}") @@ -252,10 +251,10 @@ async def _safe_check_capability(self, capability_method) -> bool: """ try: # Try calling as async first - if hasattr(capability_method, '__call__'): + if hasattr(capability_method, "__call__"): result = capability_method() # If it returns a coroutine, await it - if hasattr(result, '__await__'): + if hasattr(result, "__await__"): return await result # Otherwise it's synchronous return bool(result) @@ -266,7 +265,7 @@ async def _safe_check_capability(self, capability_method) -> bool: def get_model_capabilities( self, model_id: str, auto_refresh: bool = True - ) -> Optional[ModelCapabilities]: + ) -> ModelCapabilities | None: """ Retrieve capabilities for a specific model. @@ -284,7 +283,7 @@ def get_model_capabilities( return None return self.model_capabilities.get(model_id) - def find_models_by_capability(self, capability_name: str) -> List[str]: + def find_models_by_capability(self, capability_name: str) -> list[str]: """ Find models that support a specific capability. @@ -303,8 +302,8 @@ def find_models_by_capability(self, capability_name: str) -> List[str]: return matching_models def find_models_by_capabilities( - self, capability_names: List[str], require_all: bool = False - ) -> List[str]: + self, capability_names: list[str], require_all: bool = False + ) -> list[str]: """ Find models that support multiple capabilities. @@ -318,17 +317,17 @@ def find_models_by_capabilities( matching_models = [] for model_id, model_caps in self.model_capabilities.items(): model_capability_names = {cap.name for cap in model_caps.capabilities} - + if require_all: if all(cap_name in model_capability_names for cap_name in capability_names): matching_models.append(model_id) else: if any(cap_name in model_capability_names for cap_name in capability_names): matching_models.append(model_id) - + return matching_models - def get_capability_summary(self) -> Dict[str, int]: + def get_capability_summary(self) -> dict[str, int]: """ Get a summary of how many models support each capability. @@ -341,7 +340,7 @@ def get_capability_summary(self) -> Dict[str, int]: capability_counts[cap.name] = capability_counts.get(cap.name, 0) + 1 return capability_counts - def get_metrics(self) -> Dict[str, Any]: + def get_metrics(self) -> dict[str, Any]: """ Get current metrics for the capability discovery system. @@ -350,7 +349,7 @@ def get_metrics(self) -> Dict[str, Any]: """ return self._metrics.copy() - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get health status of the capability discovery system. @@ -359,21 +358,21 @@ def get_health_status(self) -> Dict[str, Any]: """ total_attempts = self._metrics["discovery_attempts"] success_rate = ( - self._metrics["discovery_successes"] / total_attempts + self._metrics["discovery_successes"] / total_attempts if total_attempts > 0 else 0.0 ) - + cache_hit_rate = ( - self._metrics["cache_hits"] / + self._metrics["cache_hits"] / (self._metrics["cache_hits"] + self._metrics["cache_misses"]) - if (self._metrics["cache_hits"] + self._metrics["cache_misses"]) > 0 + if (self._metrics["cache_hits"] + self._metrics["cache_misses"]) > 0 else 0.0 ) - + return { "status": ( - "healthy" if success_rate > 0.8 - else "degraded" if success_rate > 0.5 + "healthy" if success_rate > 0.8 + else "degraded" if success_rate > 0.5 else "unhealthy" ), "success_rate": success_rate, @@ -396,7 +395,7 @@ def reset_metrics(self) -> None: } logger.info("Capability discovery metrics reset") - def validate_task_requirements(self, task_type: str, model_id: str) -> Dict[str, Any]: + def validate_task_requirements(self, task_type: str, model_id: str) -> dict[str, Any]: """ Validate if a model meets the requirements for a specific task type. @@ -413,7 +412,7 @@ def validate_task_requirements(self, task_type: str, model_id: str) -> Dict[str, "meets_requirements": False, "error": f"Model {model_id} not found in capabilities database" } - + available_capabilities = [cap.name for cap in model_caps.capabilities] return self.capability_config.validate_capability_requirements( task_type, available_capabilities @@ -421,7 +420,7 @@ def validate_task_requirements(self, task_type: str, model_id: str) -> Dict[str, def find_models_for_task( self, task_type: str, min_coverage: float = 0.8 - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Find models that meet the requirements for a specific task type. @@ -433,10 +432,10 @@ def find_models_for_task( List of dictionaries with model information and validation results. """ suitable_models = [] - + for model_id, model_caps in self.model_capabilities.items(): validation_result = self.validate_task_requirements(task_type, model_id) - + if validation_result.get("meets_requirements", False): coverage_score = validation_result.get("coverage_score", 0.0) if coverage_score >= min_coverage: @@ -446,7 +445,7 @@ def find_models_for_task( "validation_result": validation_result, "coverage_score": coverage_score }) - + # Sort by coverage score (descending) suitable_models.sort(key=lambda x: x["coverage_score"], reverse=True) return suitable_models diff --git a/argus/llm/capabilities/testing.py b/argus/llm/capabilities/testing.py index 585907f..28aa7c0 100644 --- a/argus/llm/capabilities/testing.py +++ b/argus/llm/capabilities/testing.py @@ -14,9 +14,8 @@ # argus/llm/capabilities/testing.py -import logging from abc import ABC, abstractmethod -from typing import Dict, List +import logging from argus.llm.base import LLMProvider from argus.llm.capabilities.models import ModelCapability @@ -79,7 +78,7 @@ async def run_test(self, provider: LLMProvider, model_name: str) -> bool: temperature=0.1 ) response = await provider._generate(request) - response_content = response.content if hasattr(response, 'content') else str(response) + response_content = response.content if hasattr(response, "content") else str(response) return expected_substring.lower() in response_content.lower() except Exception as e: logger.error(f"Text generation test failed for {model_name}: {e}") @@ -118,7 +117,7 @@ async def run_test(self, provider: LLMProvider, model_name: str) -> bool: temperature=0.1 ) response = await provider._generate(request) - response_content = response.content if hasattr(response, 'content') else str(response) + response_content = response.content if hasattr(response, "content") else str(response) return expected_substring.lower() in response_content.lower() except Exception as e: logger.error(f"Code generation test failed for {model_name}: {e}") @@ -130,11 +129,11 @@ class CapabilityTester: Runs a suite of capability tests against LLM models. """ - def __init__(self, providers: Dict[str, LLMProvider], tests: List[CapabilityTest]) -> None: + def __init__(self, providers: dict[str, LLMProvider], tests: list[CapabilityTest]) -> None: self.providers = providers self.tests = tests - async def run_all_tests(self) -> Dict[str, Dict[str, bool]]: + async def run_all_tests(self) -> dict[str, dict[str, bool]]: """ Run all configured tests against all models. @@ -145,13 +144,13 @@ async def run_all_tests(self) -> Dict[str, Dict[str, bool]]: for provider_name, provider_instance in self.providers.items(): available_models = provider_instance.get_available_models() # Handle both dict and list return types - model_items: List[tuple] = [] + model_items: list[tuple] = [] if isinstance(available_models, dict): model_items = list(available_models.items()) elif isinstance(available_models, list): # If it's a list, create tuples with model names model_items = [("default", model_name) for model_name in available_models] # type: ignore - + for _model_type, model_name in model_items: model_id = f"{provider_name}/{model_name}" results[model_id] = {} diff --git a/argus/llm/config.py b/argus/llm/config.py index 43852d6..e38b975 100644 --- a/argus/llm/config.py +++ b/argus/llm/config.py @@ -102,7 +102,7 @@ class LLMProviderConfig(BaseModel): models: dict[str, ModelConfig] = Field( default_factory=dict, description="Available models" ) - # ModelType to model name mappings - allows users to configure which models + # ModelType to model name mappings - allows users to configure which models # are used for each semantic type model_type_mappings: dict[ModelType, str] = Field( default_factory=dict, diff --git a/argus/llm/config_loaders.py b/argus/llm/config_loaders.py index bbf0f07..c8d5033 100644 --- a/argus/llm/config_loaders.py +++ b/argus/llm/config_loaders.py @@ -21,12 +21,13 @@ including environment variables, files, and programmatic sources. """ +from collections.abc import Callable +from dataclasses import dataclass import json import logging import os -from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any import yaml @@ -37,10 +38,10 @@ class LoaderResult: """Result from a configuration loader.""" - data: Dict[str, Any] + data: dict[str, Any] source: str - metadata: Dict[str, Any] - errors: List[str] + metadata: dict[str, Any] + errors: list[str] class BaseConfigLoader: @@ -56,13 +57,13 @@ def __init__(self, source: str, priority: int = 0) -> None: """ self.source = source self.priority = priority - self._validators: List[Callable[[Dict[str, Any]], List[str]]] = [] + self._validators: list[Callable[[dict[str, Any]], list[str]]] = [] - def add_validator(self, validator: Callable[[Dict[str, Any]], List[str]]) -> None: + def add_validator(self, validator: Callable[[dict[str, Any]], list[str]]) -> None: """Add a validation function.""" self._validators.append(validator) - def validate(self, data: Dict[str, Any]) -> List[str]: + def validate(self, data: dict[str, Any]) -> list[str]: """Validate configuration data.""" errors = [] for validator in self._validators: @@ -145,7 +146,7 @@ def load(self) -> LoaderResult: ) def _set_nested_value( - self, data: Dict[str, Any], key_path: str, value: Any + self, data: dict[str, Any], key_path: str, value: Any ) -> None: """Set a nested value in the data dictionary.""" keys = key_path.split(".") @@ -185,7 +186,7 @@ def _set_nested_value( else: current[keys[-1]] = value - def _load_provider_env_vars(self) -> Dict[str, Any]: + def _load_provider_env_vars(self) -> dict[str, Any]: """Load provider-specific environment variables.""" providers = {} @@ -233,7 +234,7 @@ def _load_provider_env_vars(self) -> Dict[str, Any]: return providers - def _load_agent_env_vars(self) -> Dict[str, Any]: + def _load_agent_env_vars(self) -> dict[str, Any]: """Load agent-specific environment variables.""" agents = {} @@ -264,7 +265,7 @@ def _load_agent_env_vars(self) -> Dict[str, Any]: class FileConfigLoader(BaseConfigLoader): """Loader for file-based configuration (YAML/JSON).""" - def __init__(self, file_path: Union[str, Path], priority: int = 1) -> None: + def __init__(self, file_path: str | Path, priority: int = 1) -> None: """ Initialize file loader. @@ -290,7 +291,7 @@ def load(self) -> LoaderResult: errors=errors, ) - with open(self.file_path, "r") as f: + with open(self.file_path) as f: if self.file_path.suffix.lower() in [".yaml", ".yml"]: data = yaml.safe_load(f) or {} elif self.file_path.suffix.lower() == ".json": @@ -337,7 +338,7 @@ def load(self) -> LoaderResult: class ProgrammaticConfigLoader(BaseConfigLoader): """Loader for programmatically provided configuration.""" - def __init__(self, config_data: Dict[str, Any], priority: int = 3) -> None: + def __init__(self, config_data: dict[str, Any], priority: int = 3) -> None: """ Initialize programmatic loader. @@ -374,14 +375,14 @@ class ConfigLoaderManager: def __init__(self) -> None: """Initialize the loader manager.""" - self.loaders: List[BaseConfigLoader] = [] - self._results: List[LoaderResult] = [] + self.loaders: list[BaseConfigLoader] = [] + self._results: list[LoaderResult] = [] def add_loader(self, loader: BaseConfigLoader) -> None: """Add a configuration loader.""" self.loaders.append(loader) - def load_all(self) -> Dict[str, Any]: + def load_all(self) -> dict[str, Any]: """Load configuration from all loaders and merge results.""" self._results = [] merged_data = {} @@ -413,8 +414,8 @@ def load_all(self) -> Dict[str, Any]: return merged_data def _merge_config_data( - self, base: Dict[str, Any], updates: Dict[str, Any] - ) -> Dict[str, Any]: + self, base: dict[str, Any], updates: dict[str, Any] + ) -> dict[str, Any]: """Deep merge configuration data.""" result = base.copy() @@ -430,18 +431,18 @@ def _merge_config_data( return result - def get_loader_results(self) -> List[LoaderResult]: + def get_loader_results(self) -> list[LoaderResult]: """Get results from all loaders.""" return self._results.copy() - def get_all_errors(self) -> List[str]: + def get_all_errors(self) -> list[str]: """Get all errors from all loaders.""" errors = [] for result in self._results: errors.extend(result.errors) return errors - def get_loader_summary(self) -> Dict[str, Any]: + def get_loader_summary(self) -> dict[str, Any]: """Get a summary of loader results.""" return { "total_loaders": len(self.loaders), diff --git a/argus/llm/config_manager.py b/argus/llm/config_manager.py index 07010d7..2a7239a 100644 --- a/argus/llm/config_manager.py +++ b/argus/llm/config_manager.py @@ -21,12 +21,12 @@ multiple LLM providers, models, resilience patterns, and cost management. """ +from dataclasses import dataclass, field import json import logging import os -from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import yaml @@ -45,9 +45,9 @@ class ConfigSource: """Represents a configuration source with metadata.""" source_type: str # 'env', 'file', 'programmatic' - path: Optional[str] = None + path: str | None = None priority: int = 0 # Higher number = higher priority - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) class ConfigManager: @@ -58,7 +58,7 @@ class ConfigManager: from multiple sources with proper precedence rules. """ - def __init__(self, config_path: Optional[Union[str, Path]] = None) -> None: + def __init__(self, config_path: str | Path | None = None) -> None: """ Initialize the configuration manager. @@ -66,10 +66,10 @@ def __init__(self, config_path: Optional[Union[str, Path]] = None) -> None: config_path: Optional path to configuration file """ self.config_path = Path(config_path) if config_path else None - self._config: Optional[LLMConfig] = None - self._sources: List[ConfigSource] = [] - self._watchers: List[Any] = [] # File watchers for hot-reload - self._callbacks: List[Any] = [] + self._config: LLMConfig | None = None + self._sources: list[ConfigSource] = [] + self._watchers: list[Any] = [] # File watchers for hot-reload + self._callbacks: list[Any] = [] # Load initial configuration self._load_configuration() @@ -120,10 +120,10 @@ def _load_configuration(self) -> None: agents={}, ) - def _load_from_file(self, path: Path) -> Dict[str, Any]: + def _load_from_file(self, path: Path) -> dict[str, Any]: """Load configuration from a file (YAML or JSON).""" try: - with open(path, "r") as f: + with open(path) as f: if path.suffix.lower() in [".yaml", ".yml"]: return yaml.safe_load(f) or {} elif path.suffix.lower() == ".json": @@ -135,7 +135,7 @@ def _load_from_file(self, path: Path) -> Dict[str, Any]: logger.error(f"Failed to load configuration from {path}: {e}") return {} - def _load_from_environment(self) -> Dict[str, Any]: + def _load_from_environment(self) -> dict[str, Any]: """Load configuration from environment variables.""" config_data = {} @@ -184,17 +184,17 @@ def get_config(self) -> LLMConfig: assert self._config is not None return self._config - def get_provider_config(self, provider_name: str) -> Optional[LLMProviderConfig]: + def get_provider_config(self, provider_name: str) -> LLMProviderConfig | None: """Get configuration for a specific provider.""" config = self.get_config() return config.providers.get(provider_name) - def get_agent_config(self, agent_name: str) -> Optional[AgentLLMConfig]: + def get_agent_config(self, agent_name: str) -> AgentLLMConfig | None: """Get configuration for a specific agent.""" config = self.get_config() return config.agents.get(agent_name) - def update_config(self, updates: Dict[str, Any]) -> None: + def update_config(self, updates: dict[str, Any]) -> None: """Update configuration programmatically.""" try: current_config = self.get_config() @@ -216,7 +216,7 @@ def update_config(self, updates: Dict[str, Any]) -> None: logger.error(f"Failed to update configuration: {e}") raise - def _deep_merge(self, base: Dict[str, Any], updates: Dict[str, Any]) -> None: + def _deep_merge(self, base: dict[str, Any], updates: dict[str, Any]) -> None: """Deep merge updates into base dictionary.""" for key, value in updates.items(): if key in base and isinstance(base[key], dict) and isinstance(value, dict): @@ -241,7 +241,7 @@ def _notify_callbacks(self) -> None: except Exception as e: logger.error(f"Error in configuration change callback: {e}") - def validate_config(self) -> List[str]: + def validate_config(self) -> list[str]: """Validate the current configuration and return any errors.""" errors = [] @@ -292,7 +292,7 @@ def validate_config(self) -> List[str]: return errors - def get_config_summary(self) -> Dict[str, Any]: + def get_config_summary(self) -> dict[str, Any]: """Get a summary of the current configuration.""" config = self.get_config() @@ -319,7 +319,7 @@ def get_config_summary(self) -> Dict[str, Any]: ], } - def export_config(self, path: Union[str, Path], format: str = "yaml") -> None: + def export_config(self, path: str | Path, format: str = "yaml") -> None: """Export current configuration to a file.""" config = self.get_config() config_dict = config.model_dump() @@ -343,7 +343,7 @@ def export_config(self, path: Union[str, Path], format: str = "yaml") -> None: # Global configuration manager instance -_config_manager: Optional[ConfigManager] = None +_config_manager: ConfigManager | None = None def get_config_manager() -> ConfigManager: @@ -354,7 +354,7 @@ def get_config_manager() -> ConfigManager: return _config_manager -def initialize_config(config_path: Optional[Union[str, Path]] = None) -> ConfigManager: +def initialize_config(config_path: str | Path | None = None) -> ConfigManager: """Initialize the global configuration manager.""" global _config_manager _config_manager = ConfigManager(config_path) diff --git a/argus/llm/enhanced_service.py b/argus/llm/enhanced_service.py index 79b518d..e74d9bf 100644 --- a/argus/llm/enhanced_service.py +++ b/argus/llm/enhanced_service.py @@ -277,8 +277,8 @@ async def generate_structured( elif response_model.__name__ == "RemediationPlan": # Truncate the prompt to avoid overwhelming the model truncated_prompt = prompt[:500] + "..." if len(prompt) > 500 else prompt - structured_prompt = f"""Return only this JSON, no other text: -{{ + structured_prompt = """Return only this JSON, no other text: +{ "agent_id": "remediation-agent-1", "agent_type": "remediation", "status": "success", @@ -288,7 +288,7 @@ async def generate_structured( "estimated_total_duration": "1 hour", "estimated_total_effort": "medium", "steps": [ - {{ + { "step_id": "step-1", "order": 1, "title": "Fix Error", @@ -305,7 +305,7 @@ async def generate_structured( "affected_systems": ["main system"], "requires_approval": false, "automated": true - }} + } ], "success_criteria": ["Error fixed", "Service running"], "risk_assessment": "Low risk", @@ -315,7 +315,7 @@ async def generate_structured( "approval_required": false, "automated_steps": 1, "manual_steps": 0 -}}""" +}""" else: # Truncate the prompt to avoid overwhelming the model truncated_prompt = prompt[:2000] + "..." if len(prompt) > 2000 else prompt @@ -380,7 +380,7 @@ async def generate_structured( # Extract JSON from response (in case there's extra text) self.logger.debug(f"Raw LLM response: {response.content}") - + # Try multiple patterns to extract JSON json_str = None patterns = [ @@ -388,14 +388,14 @@ async def generate_structured( r"```\s*(\{.*?\})\s*```", # Generic code blocks r"(\{.*\})", # Plain JSON ] - + for pattern in patterns: json_match = re.search(pattern, response.content, re.DOTALL) if json_match: json_str = json_match.group(1) self.logger.debug(f"Extracted JSON with pattern {pattern}: {json_str}") break - + if json_str: try: parsed_data = json.loads(json_str) diff --git a/argus/llm/factory.py b/argus/llm/factory.py index 080cf80..6fcfe45 100644 --- a/argus/llm/factory.py +++ b/argus/llm/factory.py @@ -20,7 +20,7 @@ """ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any from .base import LLMProvider from .config import LLMConfig, LLMProviderConfig @@ -43,7 +43,7 @@ class LLMProviderFactory: cached instances to ensure efficient resource reuse. Supports both class-level logic and instance-level logic for backward compatibility. """ - + _providers_registry = { "gemini": GeminiProvider, "openai": OpenAIProvider, @@ -78,10 +78,10 @@ def create_provider(self, config: LLMProviderConfig, force_recreate: bool = Fals provider_type = config.provider if provider_type not in self._provider_types: raise ValueError(f"Unsupported provider type: {provider_type}") - + if not force_recreate and provider_type in self._providers: return self._providers[provider_type] - + provider_class = self._provider_types[provider_type] try: provider = provider_class(config) @@ -101,7 +101,7 @@ def create_provider(self, config: LLMProviderConfig, force_recreate: bool = Fals except Exception as e: raise RuntimeError(f"Provider creation failed: {e}") - def get_provider(self, name: Union[str, LLMProviderConfig]) -> Any: + def get_provider(self, name: str | LLMProviderConfig) -> Any: """Hybrid get_provider: supports both string name and config object.""" if isinstance(name, str): return self._providers.get(name) diff --git a/argus/llm/mirascope_response.py b/argus/llm/mirascope_response.py index 2ca834a..4856f15 100644 --- a/argus/llm/mirascope_response.py +++ b/argus/llm/mirascope_response.py @@ -21,14 +21,14 @@ transformation capabilities for Mirascope provider responses. """ -import json -import logging -import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, TypeVar +import json +import logging +import re +from typing import Any, TypeVar from pydantic import BaseModel, ValidationError @@ -65,9 +65,9 @@ class ResponseMetadata: validation_passed: bool = True quality_score: float = 0.0 quality_assessment: ResponseQuality = ResponseQuality.UNKNOWN - warnings: List[str] = field(default_factory=list) - errors: List[str] = field(default_factory=list) - transformations_applied: List[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + transformations_applied: list[str] = field(default_factory=list) confidence_score: float = 0.0 @@ -79,15 +79,15 @@ class ProcessedResponse: processed_content: str status: ResponseStatus metadata: ResponseMetadata - structured_data: Optional[Any] = None - validation_results: Optional[Dict[str, Any]] = None + structured_data: Any | None = None + validation_results: dict[str, Any] | None = None class ResponseValidator(ABC): """Abstract base class for response validators.""" @abstractmethod - def validate(self, response: ClientResponse) -> Dict[str, Any]: + def validate(self, response: ClientResponse) -> dict[str, Any]: """Validate a response and return validation results.""" pass @@ -99,7 +99,7 @@ def __init__(self, min_length: int = 1, max_length: int = 10000) -> None: self.min_length = min_length self.max_length = max_length - def validate(self, response: ClientResponse) -> Dict[str, Any]: + def validate(self, response: ClientResponse) -> dict[str, Any]: """Validate content length.""" content_length = len(response.content) @@ -119,10 +119,10 @@ def validate(self, response: ClientResponse) -> Dict[str, Any]: class JSONStructureValidator(ResponseValidator): """Validates JSON structure in responses.""" - def __init__(self, required_fields: Optional[List[str]] = None) -> None: + def __init__(self, required_fields: list[str] | None = None) -> None: self.required_fields = required_fields or [] - def validate(self, response: ClientResponse) -> Dict[str, Any]: + def validate(self, response: ClientResponse) -> dict[str, Any]: """Validate JSON structure.""" try: data = json.loads(response.content) @@ -153,7 +153,7 @@ def validate(self, response: ClientResponse) -> Dict[str, Any]: except json.JSONDecodeError as e: return { "valid": False, - "error": f"Invalid JSON: {str(e)}", + "error": f"Invalid JSON: {e!s}", "parsed_data": None, } @@ -161,12 +161,12 @@ def validate(self, response: ClientResponse) -> Dict[str, Any]: class RegexPatternValidator(ResponseValidator): """Validates response against regex patterns.""" - def __init__(self, patterns: Dict[str, str]) -> None: + def __init__(self, patterns: dict[str, str]) -> None: self.patterns = { name: re.compile(pattern) for name, pattern in patterns.items() } - def validate(self, response: ClientResponse) -> Dict[str, Any]: + def validate(self, response: ClientResponse) -> dict[str, Any]: """Validate against regex patterns.""" results = {} all_valid = True @@ -383,8 +383,8 @@ class ResponseProcessor: """Main response processor with validation, transformation, and quality assessment.""" def __init__(self) -> None: - self.validators: List[ResponseValidator] = [] - self.transformers: List[ResponseTransformer] = [] + self.validators: list[ResponseValidator] = [] + self.transformers: list[ResponseTransformer] = [] self.quality_assessor = QualityAssessor() self.logger = logging.getLogger(__name__) @@ -425,7 +425,7 @@ def process_response( status = ResponseStatus.VALIDATION_ERROR except Exception as e: error_msg = ( - f"Validator {validator.__class__.__name__} failed: {str(e)}" + f"Validator {validator.__class__.__name__} failed: {e!s}" ) errors.append(error_msg) self.logger.warning(error_msg) @@ -451,7 +451,7 @@ def process_response( transformations_applied.append(transformer.__class__.__name__) except Exception as e: error_msg = ( - f"Transformer {transformer.__class__.__name__} failed: {str(e)}" + f"Transformer {transformer.__class__.__name__} failed: {e!s}" ) errors.append(error_msg) status = ResponseStatus.TRANSFORMATION_ERROR @@ -464,7 +464,7 @@ def process_response( try: quality, quality_score = self.quality_assessor.assess_quality(response) except Exception as e: - warnings.append(f"Quality assessment failed: {str(e)}") + warnings.append(f"Quality assessment failed: {e!s}") self.logger.warning(f"Quality assessment failed: {e}") # Calculate processing time @@ -491,8 +491,8 @@ def process_response( ) def process_structured_response( - self, response: ClientResponse, response_model: Type[T], validate: bool = True - ) -> tuple[ProcessedResponse, Optional[T]]: + self, response: ClientResponse, response_model: type[T], validate: bool = True + ) -> tuple[ProcessedResponse, T | None]: """Process a response and attempt to parse it into a structured model.""" processed = self.process_response(response, validate=validate, transform=False) @@ -504,14 +504,14 @@ def process_structured_response( structured_data = response_model(**data) except (json.JSONDecodeError, ValidationError) as e: processed.metadata.errors.append( - f"Failed to parse structured data: {str(e)}" + f"Failed to parse structured data: {e!s}" ) processed.status = ResponseStatus.PARSING_ERROR self.logger.warning(f"Failed to parse structured response: {e}") return processed, structured_data - def get_processing_stats(self) -> Dict[str, Any]: + def get_processing_stats(self) -> dict[str, Any]: """Get statistics about response processing.""" return { "validators_count": len(self.validators), @@ -563,8 +563,8 @@ def create_text_processor() -> ResponseProcessor: @staticmethod def create_custom_processor( - validators: Optional[List[ResponseValidator]] = None, - transformers: Optional[List[ResponseTransformer]] = None, + validators: list[ResponseValidator] | None = None, + transformers: list[ResponseTransformer] | None = None, ) -> ResponseProcessor: """Create a processor with custom validators and transformers.""" processor = ResponseProcessor() @@ -581,7 +581,7 @@ def create_custom_processor( # Global response processor instance -_response_processor: Optional[ResponseProcessor] = None +_response_processor: ResponseProcessor | None = None def get_response_processor() -> ResponseProcessor: diff --git a/argus/llm/provider_framework/base_template.py b/argus/llm/provider_framework/base_template.py index 6b3b2d3..3d4d330 100644 --- a/argus/llm/provider_framework/base_template.py +++ b/argus/llm/provider_framework/base_template.py @@ -21,10 +21,10 @@ making it easy to implement new providers with minimal code. """ +from abc import abstractmethod import asyncio import logging -from abc import abstractmethod -from typing import Any, Dict, List +from typing import Any from ..base import LLMProvider, LLMRequest, LLMResponse, ModelType from ..config import LLMProviderConfig @@ -66,17 +66,17 @@ def _initialize_provider(self) -> None: # Abstract methods that must be implemented by subclasses @abstractmethod - async def _make_api_request(self, request: LLMRequest) -> Dict[str, Any]: + async def _make_api_request(self, request: LLMRequest) -> dict[str, Any]: """Make the actual API request to the provider. Must be implemented.""" pass @abstractmethod - def _parse_response(self, response_data: Dict[str, Any]) -> LLMResponse: + def _parse_response(self, response_data: dict[str, Any]) -> LLMResponse: """Parse the API response into LLMResponse format. Must be implemented.""" pass @abstractmethod - def _get_model_mapping(self) -> Dict[ModelType, str]: + def _get_model_mapping(self) -> dict[ModelType, str]: """Get the mapping of semantic types to actual model names. Must be implemented.""" pass @@ -138,11 +138,11 @@ def supports_tools(self) -> bool: """Check if provider supports tool calling. Override in subclasses.""" return False - def get_available_models(self) -> Dict[ModelType, str]: + def get_available_models(self) -> dict[ModelType, str]: """Get available models mapped to semantic types.""" return self._get_model_mapping() - async def embeddings(self, text: str) -> List[float]: + async def embeddings(self, text: str) -> list[float]: """Generate embeddings for the given text. Override if supported.""" raise NotImplementedError(f"Embeddings not supported by {self.provider_name}") @@ -168,7 +168,7 @@ def validate_config(cls, config: LLMProviderConfig) -> None: raise ValueError("Base URL must start with http:// or https://") # Helper methods for common functionality - def _get_headers(self) -> Dict[str, str]: + def _get_headers(self) -> dict[str, str]: """Get common HTTP headers for API requests.""" return { "Authorization": f"Bearer {self.api_key}", @@ -176,7 +176,7 @@ def _get_headers(self) -> Dict[str, str]: "User-Agent": f"argus/{self.provider_name}", } - def _get_request_payload(self, request: LLMRequest) -> Dict[str, Any]: + def _get_request_payload(self, request: LLMRequest) -> dict[str, Any]: """Convert LLMRequest to provider-specific payload format.""" # Default OpenAI-compatible format return { @@ -187,7 +187,7 @@ def _get_request_payload(self, request: LLMRequest) -> Dict[str, Any]: "stream": False, } - def _parse_openai_response(self, response_data: Dict[str, Any]) -> LLMResponse: + def _parse_openai_response(self, response_data: dict[str, Any]) -> LLMResponse: """Parse OpenAI-compatible response format.""" choices = response_data.get("choices", []) if not choices: diff --git a/argus/llm/providers/groq_provider.py b/argus/llm/providers/groq_provider.py index 95a1be0..de713b6 100644 --- a/argus/llm/providers/groq_provider.py +++ b/argus/llm/providers/groq_provider.py @@ -19,9 +19,9 @@ interface for models hosted on the Groq platform (e.g., Llama 3, Mixtral). """ -import json +from collections.abc import AsyncGenerator import logging -from typing import Any, AsyncGenerator +from typing import Any import httpx diff --git a/argus/llm/providers/litellm_provider.py b/argus/llm/providers/litellm_provider.py index f47e6d7..760902b 100644 --- a/argus/llm/providers/litellm_provider.py +++ b/argus/llm/providers/litellm_provider.py @@ -19,8 +19,9 @@ LLM providers (Azure, Mistral, Anthropic, etc.) via the LiteLLM library. """ +from collections.abc import AsyncGenerator import logging -from typing import Any, AsyncGenerator +from typing import Any import litellm diff --git a/argus/llm/providers/ollama_provider.py b/argus/llm/providers/ollama_provider.py index 69ab81f..d93ce56 100644 --- a/argus/llm/providers/ollama_provider.py +++ b/argus/llm/providers/ollama_provider.py @@ -28,8 +28,8 @@ import ollama from ..base import LLMProvider, LLMRequest, LLMResponse, ModelType -from ..config import LLMProviderConfig from ..capabilities.models import ModelCapability +from ..config import LLMProviderConfig logger = logging.getLogger(__name__) diff --git a/argus/llm/service_base.py b/argus/llm/service_base.py index ad6bb3f..3dff47c 100644 --- a/argus/llm/service_base.py +++ b/argus/llm/service_base.py @@ -321,7 +321,6 @@ def health_check(self) -> dict[str, Any]: # Additional classes needed for service management from enum import Enum -from typing import Optional class ServiceStatus(Enum): @@ -348,8 +347,8 @@ class ServiceHealth: status: ServiceStatus score: float message: str - last_check: Optional[float] = None - details: Optional[dict] = None + last_check: float | None = None + details: dict | None = None # Type alias for backward compatibility diff --git a/argus/llm/service_manager.py b/argus/llm/service_manager.py index 9e90ac9..8425f2c 100644 --- a/argus/llm/service_manager.py +++ b/argus/llm/service_manager.py @@ -30,6 +30,7 @@ from ..core.exceptions import ServiceError from .service_base import BaseService, ServiceConfig, ServiceHealth, ServiceStatus + # from .service_implementations import ( # CacheService, # ContextService, diff --git a/argus/llm/strategy_manager.py b/argus/llm/strategy_manager.py index c5319f8..fc609e4 100644 --- a/argus/llm/strategy_manager.py +++ b/argus/llm/strategy_manager.py @@ -61,8 +61,8 @@ ModelScore, ModelSelectionStrategy, OptimizationGoal, - StrategyContext, ScoringWeights, + StrategyContext, StrategyResult, ) from .strategy_implementations import ( diff --git a/argus/metrics/alerting.py b/argus/metrics/alerting.py index 28ae292..4c18321 100644 --- a/argus/metrics/alerting.py +++ b/argus/metrics/alerting.py @@ -15,7 +15,7 @@ # argus/metrics/alerting.py import asyncio -from typing import Any, Dict, List +from typing import Any from .metrics_manager import MetricsManager from .models import Alert @@ -26,7 +26,7 @@ class AlertManager: Manages alerts based on metrics and thresholds. """ - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """ Initialize the AlertManager. @@ -35,9 +35,9 @@ def __init__(self, config: Dict[str, Any]) -> None: """ self.thresholds = config.get("alert_thresholds", {}) self.notification_channels = config.get("notification_channels", []) - self.alert_history: List[Alert] = [] + self.alert_history: list[Alert] = [] - def check_metrics(self, metrics_manager: MetricsManager) -> List[Alert]: + def check_metrics(self, metrics_manager: MetricsManager) -> list[Alert]: """ Check metrics for threshold violations and generate alerts. @@ -68,7 +68,7 @@ def check_metrics(self, metrics_manager: MetricsManager) -> List[Alert]: self.alert_history.extend(alerts) return alerts - async def send_alerts(self, alerts: List[Alert]) -> None: + async def send_alerts(self, alerts: list[Alert]) -> None: """ Send alerts to configured notification channels. diff --git a/argus/ml/code_analysis_models.py b/argus/ml/code_analysis_models.py index 7cb868e..dcb59e1 100644 --- a/argus/ml/code_analysis_models.py +++ b/argus/ml/code_analysis_models.py @@ -16,10 +16,10 @@ Data models representing configuration and metrics for codebase analysis. """ -import os -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, Optional +import os +from typing import Any @dataclass @@ -50,7 +50,7 @@ class CodeChange: timestamp: datetime author: str message: str - files_changed: List[str] + files_changed: list[str] lines_added: int lines_deleted: int is_rollback: bool = False @@ -86,9 +86,9 @@ class StaticAnalysisResult: """Findings from static analysis tools.""" tool_name: str - findings: List[Dict[str, Any]] + findings: list[dict[str, Any]] scan_duration_seconds: float - files_analyzed: List[str] + files_analyzed: list[str] error_count: int = 0 warning_count: int = 0 info_count: int = 0 @@ -103,7 +103,7 @@ def has_errors(self) -> bool: """Returns True if error count is greater than zero.""" return self.error_count > 0 - def get_findings_by_severity(self, severity: str) -> List[Dict[str, Any]]: + def get_findings_by_severity(self, severity: str) -> list[dict[str, Any]]: """Filters findings case-insensitively by severity value.""" return [ f @@ -163,8 +163,8 @@ class DependencyVulnerability: vulnerability_id: str severity: str description: str - fixed_version: Optional[str] = None - cve_id: Optional[str] = None + fixed_version: str | None = None + cve_id: str | None = None def __post_init__(self) -> None: if self.severity.upper() not in ("CRITICAL", "HIGH", "MEDIUM", "LOW"): diff --git a/argus/ml/code_context_extractor.py b/argus/ml/code_context_extractor.py index 6b8f2ac..9628c3c 100644 --- a/argus/ml/code_context_extractor.py +++ b/argus/ml/code_context_extractor.py @@ -17,12 +17,13 @@ """ import asyncio -import re from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, List, Optional +import re +from typing import Any + +from argus.pattern_detector.models import TimeWindow -from argus.pattern_detector.models import LogEntry, TimeWindow from .code_analysis_models import CodeAnalysisConfig, CodeChange @@ -35,7 +36,7 @@ def __init__(self, config: CodeAnalysisConfig) -> None: if not self.repo_path.exists(): raise ValueError("Repository path does not exist") - def _parse_git_log_output(self, output: str) -> List[CodeChange]: + def _parse_git_log_output(self, output: str) -> list[CodeChange]: """Parses raw git log output into a list of CodeChange objects.""" commits = [] for line in output.strip().split("\n"): @@ -65,7 +66,7 @@ def _parse_git_log_output(self, output: str) -> List[CodeChange]: return commits def _generate_code_changes_summary( - self, commits: List[CodeChange], time_window: TimeWindow + self, commits: list[CodeChange], time_window: TimeWindow ) -> str: """Generates a summary string for commits within the incident window.""" # Calculate window end time based on start_time and duration_minutes @@ -83,7 +84,7 @@ def _generate_code_changes_summary( f"{rollbacks} rollback/revert commits detected." ) - async def _extract_git_context(self, time_window: TimeWindow) -> Dict[str, Any]: + async def _extract_git_context(self, time_window: TimeWindow) -> dict[str, Any]: """Runs git log command and extracts recent commits and change summaries.""" try: process = await asyncio.create_subprocess_exec( @@ -98,8 +99,8 @@ async def _extract_git_context(self, time_window: TimeWindow) -> Dict[str, Any]: ) # Bypass Python 3.11/3.12 AsyncMock limitation with raw coroutine side_effects - from unittest.mock import Mock import inspect + from unittest.mock import Mock communicate_callable = process.communicate if isinstance(communicate_callable, Mock) and getattr(communicate_callable, "side_effect", None): se = communicate_callable.side_effect @@ -136,10 +137,10 @@ async def _extract_git_context(self, time_window: TimeWindow) -> Dict[str, Any]: except Exception as e: return { "recent_commits": [], - "code_changes_summary": f"Git analysis failed: {str(e)}", + "code_changes_summary": f"Git analysis failed: {e!s}", } - async def _extract_error_related_files(self, time_window: TimeWindow) -> List[str]: + async def _extract_error_related_files(self, time_window: TimeWindow) -> list[str]: """Scans logs for file paths and line numbers linked to error context.""" related = [] for log in time_window.logs: @@ -154,19 +155,19 @@ async def _extract_error_related_files(self, time_window: TimeWindow) -> List[st related.append(m) return related - async def _empty_static_analysis(self) -> Dict[str, Any]: + async def _empty_static_analysis(self) -> dict[str, Any]: """Returns default disabled static analysis results.""" return {"enabled": False} - async def _empty_complexity_analysis(self) -> Dict[str, Any]: + async def _empty_complexity_analysis(self) -> dict[str, Any]: """Returns default disabled complexity metrics.""" return {"enabled": False} - async def _empty_dependency_scan(self) -> List[Any]: + async def _empty_dependency_scan(self) -> list[Any]: """Returns default empty dependency vulnerabilities list.""" return [] - def _empty_context(self) -> Dict[str, Any]: + def _empty_context(self) -> dict[str, Any]: """Returns a default empty codebase context dictionary.""" return { "changes_summary": "Code context extraction failed", @@ -178,8 +179,8 @@ def _empty_context(self) -> Dict[str, Any]: } async def extract_code_context( - self, time_window: TimeWindow, services: List[str] - ) -> Dict[str, Any]: + self, time_window: TimeWindow, services: list[str] + ) -> dict[str, Any]: """Aggregates all analysis tasks with a timeout limit.""" try: git_task = self._extract_git_context(time_window) @@ -199,7 +200,7 @@ async def extract_code_context( ), timeout=self.config.analysis_timeout_seconds, ) - except (asyncio.TimeoutError, Exception): + except (TimeoutError, Exception): return self._empty_context() git_res = results[0] diff --git a/argus/ml/drift_detector.py b/argus/ml/drift_detector.py index b188804..8ad2b07 100644 --- a/argus/ml/drift_detector.py +++ b/argus/ml/drift_detector.py @@ -17,7 +17,7 @@ """ from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from .performance_config import DriftAlert, PerformanceConfig, PerformanceMetrics @@ -26,7 +26,7 @@ class MetricsCalculator: """Helper class for calculations on rolling performance history windows.""" @staticmethod - def calculate_recent_metrics(history: List[float], window: int) -> float: + def calculate_recent_metrics(history: list[float], window: int) -> float: """Returns the mean value of the last window items in history.""" if not history: return 0.0 @@ -35,8 +35,8 @@ def calculate_recent_metrics(history: List[float], window: int) -> float: @staticmethod def calculate_pattern_accuracy( - accuracies: List[float], recent_samples: int - ) -> Dict[str, Any]: + accuracies: list[float], recent_samples: int + ) -> dict[str, Any]: """Calculates accuracy metric summary for specific pattern histories.""" total_samples = len(accuracies) if not accuracies: @@ -50,12 +50,12 @@ def calculate_pattern_accuracy( } @staticmethod - def trim_history(history: List[Any], max_size: int) -> List[Any]: + def trim_history(history: list[Any], max_size: int) -> list[Any]: """Trims history list to keep only the last max_size items.""" return history[-max_size:] @staticmethod - def analyze_drift_alerts(alerts: List[DriftAlert]) -> Dict[str, Any]: + def analyze_drift_alerts(alerts: list[DriftAlert]) -> dict[str, Any]: """Summarizes severity counts and presence of drift alerts.""" if not alerts: return { @@ -95,7 +95,7 @@ def _determine_drift_severity(self, drift_amount: float, high_threshold: float) return PerformanceMetrics.categorize_drift_severity(drift_amount, high_threshold) async def check_accuracy_drift( - self, current: float, baseline: Optional[float], alerts: List[DriftAlert] + self, current: float, baseline: float | None, alerts: list[DriftAlert] ) -> None: """Checks accuracy metrics against baseline and generates DriftAlert if drift exceeds threshold.""" if baseline is None: @@ -120,7 +120,7 @@ async def check_accuracy_drift( ) async def check_confidence_drift( - self, current: float, baseline: Optional[float], alerts: List[DriftAlert] + self, current: float, baseline: float | None, alerts: list[DriftAlert] ) -> None: """Checks confidence metrics against baseline and generates DriftAlert if drift exceeds threshold.""" if baseline is None: @@ -145,7 +145,7 @@ async def check_confidence_drift( ) async def check_latency_drift( - self, current: float, baseline: Optional[float], alerts: List[DriftAlert] + self, current: float, baseline: float | None, alerts: list[DriftAlert] ) -> None: """Checks latency metrics against baseline and generates DriftAlert if latency degrades beyond multiplier.""" if baseline is None: diff --git a/argus/ml/gemini_pattern_classifier.py b/argus/ml/gemini_pattern_classifier.py index a8812e5..1eee05e 100644 --- a/argus/ml/gemini_pattern_classifier.py +++ b/argus/ml/gemini_pattern_classifier.py @@ -21,8 +21,8 @@ import json from typing import Any -from argus.ml.gemini_api_client import GeminiAPIClient, GeminiResponse -from argus.pattern_detector.models import LogEntry, PatternMatch, PatternType, TimeWindow +from argus.ml.gemini_api_client import GeminiAPIClient +from argus.pattern_detector.models import PatternMatch, PatternType, TimeWindow class GeminiPatternClassifier: diff --git a/argus/ml/log_quality_validator.py b/argus/ml/log_quality_validator.py index 0dce695..1dbc00a 100644 --- a/argus/ml/log_quality_validator.py +++ b/argus/ml/log_quality_validator.py @@ -16,9 +16,9 @@ Log quality validator checking log completeness, noise, consistency, and duplicates. """ -import re from collections import Counter -from typing import Any, Dict, List, Optional +import re +from typing import Any from .validation_config import ( LogEntry, @@ -32,7 +32,7 @@ class LogQualityValidator: """Pre-processing validator ensuring high-quality input for AI analysis.""" - def __init__(self, thresholds: Optional[QualityThresholds] = None) -> None: + def __init__(self, thresholds: QualityThresholds | None = None) -> None: self.thresholds = thresholds or QualityThresholds() def _extract_message_pattern(self, message: str) -> str: @@ -54,7 +54,7 @@ def _is_noisy_log(self, log: LogEntry) -> bool: or ValidationRules.is_message_too_short(log.error_message) ) - def assess_log_quality(self, window: TimeWindow) -> Dict[str, Any]: + def assess_log_quality(self, window: TimeWindow) -> dict[str, Any]: """Calculates completeness, noise ratio, consistency, and duplicates.""" if not window.logs: return ValidationMetrics.empty_metrics() @@ -124,28 +124,28 @@ def validate_for_processing(self, window: TimeWindow) -> bool: """Returns True if the window passes all quality thresholds.""" return self.assess_log_quality(window)["passes_threshold"] - def get_quality_recommendations(self, metrics: Dict[str, Any]) -> List[str]: + def get_quality_recommendations(self, metrics: dict[str, Any]) -> list[str]: """Provides actionable recommendations based on quality metrics.""" recommendations = [] if metrics["completeness"] < self.thresholds.min_completeness: recommendations.append( - f"Ensure essential fields service_name, error_message, and severity are populated to improve completeness." + "Ensure essential fields service_name, error_message, and severity are populated to improve completeness." ) if metrics["noise_ratio"] > self.thresholds.max_noise_ratio: recommendations.append( - f"Filter out DEBUG/TRACE level logs and extremely short/empty messages to reduce noise." + "Filter out DEBUG/TRACE level logs and extremely short/empty messages to reduce noise." ) if metrics["consistency"] < self.thresholds.min_consistency: recommendations.append( - f"standardize log formats across services to improve pattern consistency." + "standardize log formats across services to improve pattern consistency." ) if metrics["duplicate_ratio"] > self.thresholds.max_duplicate_ratio: recommendations.append( - f"Apply deduplication or rate limiting to reduce duplicate error messages." + "Apply deduplication or rate limiting to reduce duplicate error messages." ) if not recommendations: diff --git a/argus/ml/model_performance_monitor.py b/argus/ml/model_performance_monitor.py index 3c5be0d..dac33fb 100644 --- a/argus/ml/model_performance_monitor.py +++ b/argus/ml/model_performance_monitor.py @@ -18,7 +18,7 @@ from collections import defaultdict from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from .drift_detector import DriftDetector, MetricsCalculator from .performance_config import DriftAlert, PerformanceConfig, PerformanceMetrics @@ -27,18 +27,18 @@ class ModelPerformanceMonitor: """Monitors model prediction accuracy, confidence, and latency over time to detect drift.""" - def __init__(self, config: Optional[PerformanceConfig] = None) -> None: + def __init__(self, config: PerformanceConfig | None = None) -> None: self.config = config or PerformanceConfig() - self.accuracy_history: List[float] = [] - self.confidence_history: List[float] = [] - self.latency_history: List[float] = [] - self.pattern_type_accuracy: Dict[str, List[float]] = defaultdict(list) + self.accuracy_history: list[float] = [] + self.confidence_history: list[float] = [] + self.latency_history: list[float] = [] + self.pattern_type_accuracy: dict[str, list[float]] = defaultdict(list) - self.baseline_accuracy: Optional[float] = None - self.baseline_confidence: Optional[float] = None - self.baseline_latency: Optional[float] = None + self.baseline_accuracy: float | None = None + self.baseline_confidence: float | None = None + self.baseline_latency: float | None = None - self.drift_alerts: List[DriftAlert] = [] + self.drift_alerts: list[DriftAlert] = [] self.drift_detector = DriftDetector(self.config) self.last_drift_check = datetime.now() @@ -48,7 +48,7 @@ def _trim_pattern_history(self, pattern: str) -> None: self.pattern_type_accuracy[pattern], self.config.max_pattern_history ) - def _calculate_pattern_accuracy(self) -> Dict[str, Dict[str, Any]]: + def _calculate_pattern_accuracy(self) -> dict[str, dict[str, Any]]: """Calculates accuracy summary for all tracked pattern types.""" result = {} for pattern, accuracies in self.pattern_type_accuracy.items(): @@ -134,7 +134,7 @@ async def track_prediction_accuracy( self.last_drift_check = datetime.now() - def get_performance_metrics(self) -> Dict[str, Any]: + def get_performance_metrics(self) -> dict[str, Any]: """Returns current performance metrics dashboard dict.""" if not self.accuracy_history: return PerformanceMetrics.empty_metrics() @@ -153,7 +153,7 @@ def get_performance_metrics(self) -> Dict[str, Any]: "pattern_accuracy": self._calculate_pattern_accuracy(), } - def get_drift_summary(self) -> Dict[str, Any]: + def get_drift_summary(self) -> dict[str, Any]: """Returns summarized drift statistics and severity counts.""" return MetricsCalculator.analyze_drift_alerts(self.drift_alerts) diff --git a/argus/ml/performance_config.py b/argus/ml/performance_config.py index 931dc31..32a8294 100644 --- a/argus/ml/performance_config.py +++ b/argus/ml/performance_config.py @@ -16,9 +16,9 @@ Performance configuration, drift alert, and metrics definitions. """ -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any @dataclass @@ -62,7 +62,7 @@ class PerformanceMetrics: """Utility methods for calculating performance metrics and categorizing drift.""" @staticmethod - def empty_metrics() -> Dict[str, Any]: + def empty_metrics() -> dict[str, Any]: """Returns default empty performance metrics.""" return { "overall_accuracy": 0.0, diff --git a/argus/ml/prompt_utils.py b/argus/ml/prompt_utils.py index 1383a16..76de698 100644 --- a/argus/ml/prompt_utils.py +++ b/argus/ml/prompt_utils.py @@ -23,7 +23,7 @@ import json import re -from typing import Any, Dict, List +from typing import Any def format_prompt(template: str, **kwargs: Any) -> str: @@ -40,7 +40,7 @@ def format_prompt(template: str, **kwargs: Any) -> str: return template.format(**kwargs) -def extract_variables(template: str) -> List[str]: +def extract_variables(template: str) -> list[str]: """ Extract variable names from a prompt template. @@ -124,7 +124,7 @@ def count_tokens_estimate(text: str) -> int: return len(text) // 4 -def merge_prompts(prompts: List[str], separator: str = "\n\n") -> str: +def merge_prompts(prompts: list[str], separator: str = "\n\n") -> str: """ Merge multiple prompts into a single prompt. @@ -138,7 +138,7 @@ def merge_prompts(prompts: List[str], separator: str = "\n\n") -> str: return separator.join(filter(None, prompts)) -def build_context_kwargs(context_data: Any) -> Dict[str, Any]: +def build_context_kwargs(context_data: Any) -> dict[str, Any]: """ Build keyword arguments for context-based prompts. @@ -149,22 +149,22 @@ def build_context_kwargs(context_data: Any) -> Dict[str, Any]: Formatted keyword arguments with normalized dict keys for JSON compatibility """ # Handle both PatternContext objects and dictionaries - if hasattr(context_data, '__dict__'): + if hasattr(context_data, "__dict__"): data = context_data.__dict__ else: data = context_data if isinstance(context_data, dict) else {} - + def normalize_dict_keys(obj: Any) -> Any: """Recursively normalize dict keys to strings for JSON compatibility.""" if isinstance(obj, dict): return {str(k): normalize_dict_keys(v) for k, v in obj.items()} elif isinstance(obj, list): return [normalize_dict_keys(item) for item in obj] - elif hasattr(obj, 'isoformat'): # datetime objects + elif hasattr(obj, "isoformat"): # datetime objects return obj.isoformat() else: return obj - + # Convert complex objects to JSON strings for template compatibility def safe_json_serialize(obj: Any) -> str: """Safely serialize objects to JSON strings.""" @@ -173,14 +173,14 @@ def safe_json_serialize(obj: Any) -> str: return json.dumps(normalized, ensure_ascii=False) except (TypeError, ValueError): return str(obj) - + # Handle list fields by joining with commas def safe_list_join(obj: Any, default: str = "Unknown") -> str: """Safely join list items or return default.""" if isinstance(obj, list) and obj: return ", ".join(str(item) for item in obj) return default - + return { "primary_service": data.get("primary_service", "Unknown"), "affected_services": safe_list_join(data.get("affected_services"), "No services identified"), @@ -198,7 +198,7 @@ def safe_list_join(obj: Any, default: str = "Unknown") -> str: } -def build_evidence_kwargs(evidence_data: Dict[str, Any]) -> Dict[str, Any]: +def build_evidence_kwargs(evidence_data: dict[str, Any]) -> dict[str, Any]: """ Build keyword arguments for evidence-based prompts. @@ -216,7 +216,7 @@ def normalize_dict_keys(obj: Any) -> Any: return [normalize_dict_keys(item) for item in obj] else: return obj - + # Convert complex objects to JSON strings for template compatibility def safe_json_serialize(obj: Any) -> str: """Safely serialize objects to JSON strings.""" @@ -225,7 +225,7 @@ def safe_json_serialize(obj: Any) -> str: return json.dumps(normalized, ensure_ascii=False) except (TypeError, ValueError): return str(obj) - + return { "log_completeness": evidence_data.get("log_completeness", 0.0), "timestamp_consistency": evidence_data.get("timestamp_consistency", "unknown"), @@ -251,6 +251,6 @@ def safe_json_serialize(obj: Any) -> str: class PatternContext: """Context for pattern-based prompt generation.""" - def __init__(self, pattern_type: str, data: Dict[str, Any]) -> None: + def __init__(self, pattern_type: str, data: dict[str, Any]) -> None: self.pattern_type = pattern_type self.data = data diff --git a/argus/ml/validation_config.py b/argus/ml/validation_config.py index 87938c5..6bdd83f 100644 --- a/argus/ml/validation_config.py +++ b/argus/ml/validation_config.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -26,12 +26,12 @@ class LogEntry: """Represents a single parsed log entry for quality validation.""" timestamp: datetime - service_name: Optional[str] = None - error_message: Optional[str] = None - severity: Optional[str] = None - trace_id: Optional[str] = None - span_id: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) + service_name: str | None = None + error_message: str | None = None + severity: str | None = None + trace_id: str | None = None + span_id: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: if self.metadata is None: @@ -55,7 +55,7 @@ class TimeWindow: start_time: datetime end_time: datetime - logs: List[LogEntry] + logs: list[LogEntry] def __post_init__(self) -> None: if self.start_time >= self.end_time: @@ -76,7 +76,7 @@ class ValidationMetrics: """Helper utilities for formatting and generating empty validation metrics.""" @staticmethod - def empty_metrics() -> Dict[str, Any]: + def empty_metrics() -> dict[str, Any]: """Returns default empty metrics dictionary.""" return { "completeness": 0.0, @@ -123,14 +123,14 @@ def is_essential_field_complete(log: LogEntry) -> bool: return bool(log.service_name and log.error_message and log.severity) @staticmethod - def is_noisy_severity(severity: Optional[str]) -> bool: + def is_noisy_severity(severity: str | None) -> bool: """Checks if severity is noisy (e.g. DEBUG, TRACE).""" if not severity: return False return severity.upper() in ("DEBUG", "TRACE") @staticmethod - def is_message_too_short(message: Optional[str]) -> bool: + def is_message_too_short(message: str | None) -> bool: """Checks if message is too short (less than 10 characters).""" if not message: return True diff --git a/argus/ml/workflow/workflow_analysis.py b/argus/ml/workflow/workflow_analysis.py index 7117f5b..65dbb3f 100644 --- a/argus/ml/workflow/workflow_analysis.py +++ b/argus/ml/workflow/workflow_analysis.py @@ -21,10 +21,10 @@ issue analysis, pattern detection, and context analysis. """ +from dataclasses import dataclass import logging import time -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any from ...core.interfaces import ProcessableComponent from ...core.types import ConfigDict, Timestamp @@ -52,18 +52,18 @@ class AnalysisResult: timestamp: Timestamp # Analysis results - detected_patterns: List[Dict[str, Any]] - insights: List[str] - recommendations: List[str] + detected_patterns: list[dict[str, Any]] + insights: list[str] + recommendations: list[str] confidence_score: float # Performance metrics analysis_duration: float patterns_analyzed: int success: bool - error_message: Optional[str] = None + error_message: str | None = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert analysis result to dictionary.""" return { "analysis_id": self.analysis_id, @@ -81,7 +81,7 @@ def to_dict(self) -> Dict[str, Any]: } -class WorkflowAnalysisEngine(ProcessableComponent[Dict[str, Any], AnalysisResult]): +class WorkflowAnalysisEngine(ProcessableComponent[dict[str, Any], AnalysisResult]): """ Analysis engine for workflow operations. @@ -93,7 +93,7 @@ def __init__( self, component_id: str = "workflow_analysis_engine", name: str = "Workflow Analysis Engine", - config: Optional[ConfigDict] = None, + config: ConfigDict | None = None, ) -> None: """ Initialize the workflow analysis engine. @@ -207,7 +207,7 @@ async def analyze_issue( except Exception as e: analysis_duration = time.time() - start_time - error_msg = f"Analysis failed: {str(e)}" + error_msg = f"Analysis failed: {e!s}" logger.error(f"Analysis {analysis_id} failed: {error_msg}") @@ -235,9 +235,9 @@ async def analyze_issue( async def analyze_patterns( self, - patterns: List[Dict[str, Any]], + patterns: list[dict[str, Any]], workflow_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, ) -> AnalysisResult: """ Analyze patterns for insights and recommendations. @@ -311,7 +311,7 @@ async def analyze_patterns( except Exception as e: analysis_duration = time.time() - start_time - error_msg = f"Pattern analysis failed: {str(e)}" + error_msg = f"Pattern analysis failed: {e!s}" logger.error(f"Pattern analysis {analysis_id} failed: {error_msg}") @@ -337,7 +337,7 @@ async def analyze_patterns( return result - def process(self, input_data: Dict[str, Any]) -> AnalysisResult: + def process(self, input_data: dict[str, Any]) -> AnalysisResult: """ Process analysis request (synchronous wrapper). @@ -406,7 +406,7 @@ def get_state(self, key: str, default: Any = None) -> Any: """ return self._state.get(key, default) - def clear_state(self, key: Optional[str] = None) -> None: + def clear_state(self, key: str | None = None) -> None: """ Clear state values. @@ -418,7 +418,7 @@ def clear_state(self, key: Optional[str] = None) -> None: else: self._state.pop(key, None) - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """ Collect component metrics. @@ -427,7 +427,7 @@ def collect_metrics(self) -> Dict[str, Any]: """ return self.get_analysis_metrics() - def _extract_patterns(self, analysis_response: Any) -> List[Dict[str, Any]]: + def _extract_patterns(self, analysis_response: Any) -> list[dict[str, Any]]: """ Extract patterns from analysis response. @@ -452,7 +452,7 @@ def _extract_patterns(self, analysis_response: Any) -> List[Dict[str, Any]]: logger.warning(f"Failed to extract patterns: {e}") return [] - def _extract_insights(self, analysis_response: Any) -> List[str]: + def _extract_insights(self, analysis_response: Any) -> list[str]: """ Extract insights from analysis response. @@ -477,7 +477,7 @@ def _extract_insights(self, analysis_response: Any) -> List[str]: logger.warning(f"Failed to extract insights: {e}") return [] - def _extract_recommendations(self, analysis_response: Any) -> List[str]: + def _extract_recommendations(self, analysis_response: Any) -> list[str]: """ Extract recommendations from analysis response. @@ -529,7 +529,7 @@ def _calculate_confidence(self, analysis_response: Any) -> float: logger.warning(f"Failed to calculate confidence: {e}") return 0.5 - def get_analysis_metrics(self) -> Dict[str, Any]: + def get_analysis_metrics(self) -> dict[str, Any]: """ Get analysis performance metrics. @@ -553,7 +553,7 @@ def get_analysis_metrics(self) -> Dict[str, Any]: ), } - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get the component's health status. diff --git a/argus/ml/workflow/workflow_generation.py b/argus/ml/workflow/workflow_generation.py index aa17b78..7609025 100644 --- a/argus/ml/workflow/workflow_generation.py +++ b/argus/ml/workflow/workflow_generation.py @@ -21,10 +21,10 @@ code generation, prompt generation, and solution generation. """ +from dataclasses import dataclass import logging import time -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any from ...core.interfaces import ProcessableComponent from ...core.types import ConfigDict, Timestamp @@ -53,18 +53,18 @@ class GenerationResult: # Generation results generated_content: str - code_patches: List[Dict[str, Any]] - prompts: List[str] - solutions: List[str] + code_patches: list[dict[str, Any]] + prompts: list[str] + solutions: list[str] confidence_score: float # Performance metrics generation_duration: float content_length: int success: bool - error_message: Optional[str] = None + error_message: str | None = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert generation result to dictionary.""" return { "generation_id": self.generation_id, @@ -83,7 +83,7 @@ def to_dict(self) -> Dict[str, Any]: } -class WorkflowGenerationEngine(ProcessableComponent[Dict[str, Any], GenerationResult]): +class WorkflowGenerationEngine(ProcessableComponent[dict[str, Any], GenerationResult]): """ Generation engine for workflow operations. @@ -95,7 +95,7 @@ def __init__( self, component_id: str = "workflow_generation_engine", name: str = "Workflow Generation Engine", - config: Optional[ConfigDict] = None, + config: ConfigDict | None = None, ) -> None: """ Initialize the workflow generation engine. @@ -215,7 +215,7 @@ async def generate_code( except Exception as e: generation_duration = time.time() - start_time - error_msg = f"Code generation failed: {str(e)}" + error_msg = f"Code generation failed: {e!s}" logger.error(f"Code generation {generation_id} failed: {error_msg}") @@ -243,7 +243,7 @@ async def generate_code( return result async def generate_prompts( - self, context: Dict[str, Any], workflow_id: str, prompt_type: str = "analysis" + self, context: dict[str, Any], workflow_id: str, prompt_type: str = "analysis" ) -> GenerationResult: """ Generate prompts for analysis or other operations. @@ -321,7 +321,7 @@ async def generate_prompts( except Exception as e: generation_duration = time.time() - start_time - error_msg = f"Prompt generation failed: {str(e)}" + error_msg = f"Prompt generation failed: {e!s}" logger.error(f"Prompt generation {generation_id} failed: {error_msg}") @@ -348,7 +348,7 @@ async def generate_prompts( return result - def process(self, input_data: Dict[str, Any]) -> GenerationResult: + def process(self, input_data: dict[str, Any]) -> GenerationResult: """ Process generation request (synchronous wrapper). @@ -417,7 +417,7 @@ def get_state(self, key: str, default: Any = None) -> Any: """ return self._state.get(key, default) - def clear_state(self, key: Optional[str] = None) -> None: + def clear_state(self, key: str | None = None) -> None: """ Clear state values. @@ -429,7 +429,7 @@ def clear_state(self, key: Optional[str] = None) -> None: else: self._state.pop(key, None) - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """ Collect component metrics. @@ -466,7 +466,7 @@ def _extract_generated_content(self, remediation_response: Any) -> str: logger.warning(f"Failed to extract generated content: {e}") return "" - def _extract_code_patches(self, remediation_response: Any) -> List[Dict[str, Any]]: + def _extract_code_patches(self, remediation_response: Any) -> list[dict[str, Any]]: """ Extract code patches from remediation response. @@ -492,7 +492,7 @@ def _extract_code_patches(self, remediation_response: Any) -> List[Dict[str, Any logger.warning(f"Failed to extract code patches: {e}") return [] - def _extract_prompts(self, remediation_response: Any) -> List[str]: + def _extract_prompts(self, remediation_response: Any) -> list[str]: """ Extract prompts from remediation response. @@ -518,7 +518,7 @@ def _extract_prompts(self, remediation_response: Any) -> List[str]: logger.warning(f"Failed to extract prompts: {e}") return [] - def _extract_solutions(self, remediation_response: Any) -> List[str]: + def _extract_solutions(self, remediation_response: Any) -> list[str]: """ Extract solutions from remediation response. @@ -570,7 +570,7 @@ def _calculate_confidence(self, remediation_response: Any) -> float: logger.warning(f"Failed to calculate confidence: {e}") return 0.5 - def get_generation_metrics(self) -> Dict[str, Any]: + def get_generation_metrics(self) -> dict[str, Any]: """ Get generation performance metrics. @@ -594,7 +594,7 @@ def get_generation_metrics(self) -> Dict[str, Any]: ), } - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get the component's health status. diff --git a/argus/ml/workflow/workflow_metrics.py b/argus/ml/workflow/workflow_metrics.py index 32494e8..a881828 100644 --- a/argus/ml/workflow/workflow_metrics.py +++ b/argus/ml/workflow/workflow_metrics.py @@ -21,10 +21,10 @@ including performance metrics, health metrics, and operational metrics. """ +from dataclasses import dataclass, field import logging import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any from ...core.interfaces import ProcessableComponent from ...core.types import ConfigDict, Timestamp @@ -44,11 +44,11 @@ class MetricData: name: str value: float timestamp: Timestamp - tags: Dict[str, str] = field(default_factory=dict) + tags: dict[str, str] = field(default_factory=dict) unit: str = "count" description: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert metric data to dictionary.""" return { "name": self.name, @@ -72,8 +72,8 @@ class WorkflowMetrics: # Workflow metadata workflow_id: str start_time: Timestamp - end_time: Optional[Timestamp] = None - duration: Optional[float] = None + end_time: Timestamp | None = None + duration: float | None = None # Performance metrics total_operations: int = 0 @@ -99,9 +99,9 @@ class WorkflowMetrics: cache_misses: int = 0 # Custom metrics - custom_metrics: List[MetricData] = field(default_factory=list) + custom_metrics: list[MetricData] = field(default_factory=list) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert workflow metrics to dictionary.""" return { "workflow_id": self.workflow_id, @@ -127,7 +127,7 @@ def to_dict(self) -> Dict[str, Any]: } -class WorkflowMetricsCollector(ProcessableComponent[Dict[str, Any], WorkflowMetrics]): +class WorkflowMetricsCollector(ProcessableComponent[dict[str, Any], WorkflowMetrics]): """ Metrics collector for workflow operations. @@ -139,7 +139,7 @@ def __init__( self, component_id: str = "workflow_metrics_collector", name: str = "Workflow Metrics Collector", - config: Optional[ConfigDict] = None, + config: ConfigDict | None = None, ) -> None: """ Initialize the workflow metrics collector. @@ -152,9 +152,9 @@ def __init__( super().__init__(component_id, name, config) # Metrics tracking - self.active_workflows: Dict[str, WorkflowMetrics] = {} - self.completed_workflows: List[WorkflowMetrics] = [] - self.metrics_history: List[MetricData] = [] + self.active_workflows: dict[str, WorkflowMetrics] = {} + self.completed_workflows: list[WorkflowMetrics] = [] + self.metrics_history: list[MetricData] = [] # Collection settings self.metrics_retention_days = ( @@ -186,7 +186,7 @@ def start_workflow_metrics(self, workflow_id: str) -> WorkflowMetrics: self.active_workflows[workflow_id] = metrics return metrics - def end_workflow_metrics(self, workflow_id: str) -> Optional[WorkflowMetrics]: + def end_workflow_metrics(self, workflow_id: str) -> WorkflowMetrics | None: """ End metrics collection for a workflow. @@ -315,7 +315,7 @@ def add_custom_metric( value: float, unit: str = "count", description: str = "", - tags: Optional[Dict[str, str]] = None, + tags: dict[str, str] | None = None, ) -> None: """ Add a custom metric to a workflow. @@ -345,7 +345,7 @@ def add_custom_metric( metrics.custom_metrics.append(metric_data) - def get_workflow_metrics(self, workflow_id: str) -> Optional[WorkflowMetrics]: + def get_workflow_metrics(self, workflow_id: str) -> WorkflowMetrics | None: """ Get metrics for a specific workflow. @@ -365,7 +365,7 @@ def get_workflow_metrics(self, workflow_id: str) -> Optional[WorkflowMetrics]: return None - def get_all_metrics(self) -> Dict[str, Any]: + def get_all_metrics(self) -> dict[str, Any]: """ Get all collected metrics. @@ -385,7 +385,7 @@ def get_all_metrics(self) -> Dict[str, Any]: ), } - def process(self, input_data: Dict[str, Any]) -> WorkflowMetrics: + def process(self, input_data: dict[str, Any]) -> WorkflowMetrics: """ Process metrics collection request (synchronous wrapper). @@ -454,7 +454,7 @@ def get_state(self, key: str, default: Any = None) -> Any: """ return self._state.get(key, default) - def clear_state(self, key: Optional[str] = None) -> None: + def clear_state(self, key: str | None = None) -> None: """ Clear state values. @@ -466,7 +466,7 @@ def clear_state(self, key: Optional[str] = None) -> None: else: self._state.pop(key, None) - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """ Collect component metrics. @@ -493,7 +493,7 @@ def _calculate_derived_metrics(self, metrics: WorkflowMetrics) -> None: if metrics.total_operations > 0 and metrics.duration: metrics.average_operation_time = metrics.duration / metrics.total_operations - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get the component's health status. diff --git a/argus/ml/workflow/workflow_validation.py b/argus/ml/workflow/workflow_validation.py index 39259ec..b64f2b9 100644 --- a/argus/ml/workflow/workflow_validation.py +++ b/argus/ml/workflow/workflow_validation.py @@ -21,10 +21,10 @@ code validation, prompt validation, and solution validation. """ +from dataclasses import dataclass import logging import time -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any from ...core.interfaces import ProcessableComponent from ...core.types import ConfigDict, Timestamp @@ -49,18 +49,18 @@ class ValidationResult: # Validation results is_valid: bool - errors: List[Dict[str, Any]] - warnings: List[Dict[str, Any]] - recommendations: List[str] + errors: list[dict[str, Any]] + warnings: list[dict[str, Any]] + recommendations: list[str] confidence_score: float # Performance metrics validation_duration: float items_validated: int success: bool - error_message: Optional[str] = None + error_message: str | None = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert validation result to dictionary.""" return { "validation_id": self.validation_id, @@ -79,7 +79,7 @@ def to_dict(self) -> Dict[str, Any]: } -class WorkflowValidationEngine(ProcessableComponent[Dict[str, Any], ValidationResult]): +class WorkflowValidationEngine(ProcessableComponent[dict[str, Any], ValidationResult]): """ Validation engine for workflow operations. @@ -91,7 +91,7 @@ def __init__( self, component_id: str = "workflow_validation_engine", name: str = "Workflow Validation Engine", - config: Optional[ConfigDict] = None, + config: ConfigDict | None = None, ) -> None: """ Initialize the workflow validation engine. @@ -201,7 +201,7 @@ async def validate_code( except Exception as e: validation_duration = time.time() - start_time - error_msg = f"Code validation failed: {str(e)}" + error_msg = f"Code validation failed: {e!s}" logger.error(f"Code validation {validation_id} failed: {error_msg}") @@ -230,7 +230,7 @@ async def validate_code( async def validate_prompts( self, - prompts: List[str], + prompts: list[str], workflow_id: str, validation_type: str = "completeness", ) -> ValidationResult: @@ -314,7 +314,7 @@ async def validate_prompts( except Exception as e: validation_duration = time.time() - start_time - error_msg = f"Prompt validation failed: {str(e)}" + error_msg = f"Prompt validation failed: {e!s}" logger.error(f"Prompt validation {validation_id} failed: {error_msg}") @@ -341,7 +341,7 @@ async def validate_prompts( return result - def process(self, input_data: Dict[str, Any]) -> ValidationResult: + def process(self, input_data: dict[str, Any]) -> ValidationResult: """ Process validation request (synchronous wrapper). @@ -410,7 +410,7 @@ def get_state(self, key: str, default: Any = None) -> Any: """ return self._state.get(key, default) - def clear_state(self, key: Optional[str] = None) -> None: + def clear_state(self, key: str | None = None) -> None: """ Clear state values. @@ -422,7 +422,7 @@ def clear_state(self, key: Optional[str] = None) -> None: else: self._state.pop(key, None) - def collect_metrics(self) -> Dict[str, Any]: + def collect_metrics(self) -> dict[str, Any]: """ Collect component metrics. @@ -431,7 +431,7 @@ def collect_metrics(self) -> Dict[str, Any]: """ return self.get_validation_metrics() - def _initialize_validation_rules(self) -> Dict[str, Any]: + def _initialize_validation_rules(self) -> dict[str, Any]: """ Initialize validation rules for different types of content. @@ -473,7 +473,7 @@ def _initialize_validation_rules(self) -> Dict[str, Any]: def _validate_syntax( self, code_content: str, file_path: str - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Validate code syntax. @@ -539,7 +539,7 @@ def _validate_syntax( def _validate_style( self, code_content: str, file_path: str - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Validate code style. @@ -570,7 +570,7 @@ def _validate_style( def _validate_best_practices( self, code_content: str, file_path: str - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Validate code best practices. @@ -631,7 +631,7 @@ def _validate_best_practices( def _validate_prompt_completeness( self, prompt: str, index: int - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Validate prompt completeness. @@ -657,7 +657,7 @@ def _validate_prompt_completeness( return errors - def _validate_prompt_quality(self, prompt: str, index: int) -> List[Dict[str, Any]]: + def _validate_prompt_quality(self, prompt: str, index: int) -> list[dict[str, Any]]: """ Validate prompt quality. @@ -684,7 +684,7 @@ def _validate_prompt_quality(self, prompt: str, index: int) -> List[Dict[str, An return warnings - def _validate_prompt_clarity(self, prompt: str, index: int) -> List[Dict[str, Any]]: + def _validate_prompt_clarity(self, prompt: str, index: int) -> list[dict[str, Any]]: """ Validate prompt clarity. @@ -716,9 +716,9 @@ def _validate_prompt_clarity(self, prompt: str, index: int) -> List[Dict[str, An def _generate_recommendations( self, code_content: str, - errors: List[Dict[str, Any]], - warnings: List[Dict[str, Any]], - ) -> List[str]: + errors: list[dict[str, Any]], + warnings: list[dict[str, Any]], + ) -> list[str]: """ Generate recommendations based on validation results. @@ -747,10 +747,10 @@ def _generate_recommendations( def _generate_prompt_recommendations( self, - prompts: List[str], - errors: List[Dict[str, Any]], - warnings: List[Dict[str, Any]], - ) -> List[str]: + prompts: list[str], + errors: list[dict[str, Any]], + warnings: list[dict[str, Any]], + ) -> list[str]: """ Generate recommendations for prompts based on validation results. @@ -776,7 +776,7 @@ def _generate_prompt_recommendations( return recommendations def _calculate_confidence( - self, errors: List[Dict[str, Any]], warnings: List[Dict[str, Any]] + self, errors: list[dict[str, Any]], warnings: list[dict[str, Any]] ) -> float: """ Calculate confidence score based on validation results. @@ -796,7 +796,7 @@ def _calculate_confidence( return 1.0 - def get_validation_metrics(self) -> Dict[str, Any]: + def get_validation_metrics(self) -> dict[str, Any]: """ Get validation performance metrics. @@ -820,7 +820,7 @@ def get_validation_metrics(self) -> Dict[str, Any]: ), } - def get_health_status(self) -> Dict[str, Any]: + def get_health_status(self) -> dict[str, Any]: """ Get the component's health status. diff --git a/argus/source_control/base.py b/argus/source_control/base.py index f113f1a..b69a2fd 100644 --- a/argus/source_control/base.py +++ b/argus/source_control/base.py @@ -18,11 +18,11 @@ Abstract base class for source control providers with async context manager support. """ -import asyncio -import logging from abc import ABC, abstractmethod +import asyncio from datetime import datetime -from typing import Any, Dict, List, Optional +import logging +from typing import Any from .models import ( BatchOperation, @@ -40,7 +40,7 @@ class SourceControlProvider(ABC): """Abstract base class defining the interface for source control providers.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize with configuration.""" self.config = config self._initialized = False @@ -80,23 +80,23 @@ async def get_health_status(self) -> ProviderHealth: # File operations @abstractmethod - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Retrieve content of a file at a specific path and reference.""" pass @abstractmethod - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get information about a file.""" pass @abstractmethod - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if a file exists at the given path.""" pass # Branch operations @abstractmethod - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create a new branch from the specified reference.""" pass @@ -106,19 +106,19 @@ async def delete_branch(self, name: str) -> bool: pass @abstractmethod - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List all branches in the repository.""" pass @abstractmethod - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get information about a specific branch.""" pass # Remediation operations @abstractmethod async def apply_remediation( - self, path: str, content: str, message: str, branch: Optional[str] = None + self, path: str, content: str, message: str, branch: str | None = None ) -> RemediationResult: """Apply a remediation to a file and commit the changes.""" pass @@ -140,7 +140,7 @@ async def create_merge_request( # Conflict resolution @abstractmethod async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check if applying content to a file would cause conflicts.""" pass @@ -153,15 +153,15 @@ async def resolve_conflicts( pass @abstractmethod - async def get_conflict_info(self, path: str) -> Optional[ConflictInfo]: + async def get_conflict_info(self, path: str) -> ConflictInfo | None: """Get detailed information about conflicts in a file.""" pass # Batch operations @abstractmethod async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute multiple operations as a batch.""" pass @@ -230,7 +230,7 @@ def get_config_value(self, key: str, default: Any = None) -> Any: """Get a configuration value with optional default.""" return self.config.get(key, default) - def update_config(self, updates: Dict[str, Any]) -> None: + def update_config(self, updates: dict[str, Any]) -> None: """Update configuration values.""" self.config.update(updates) diff --git a/argus/source_control/base_implementation.py b/argus/source_control/base_implementation.py index 76a1eef..5e1fb77 100644 --- a/argus/source_control/base_implementation.py +++ b/argus/source_control/base_implementation.py @@ -19,9 +19,10 @@ """ import asyncio -import logging +from collections.abc import Callable from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +import logging +from typing import Any from .base import SourceControlProvider from .error_handling import ( @@ -45,7 +46,7 @@ class BaseSourceControlProvider(SourceControlProvider): """Base implementation of SourceControlProvider with common functionality.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize with configuration.""" super().__init__(config) self._client = None @@ -120,7 +121,7 @@ async def _execute_resilient_operation( operation_name, func, *args, **kwargs ) - def _initialize_error_handling(self, provider_name: str, config: Optional[Dict[str, Any]] = None) -> None: + def _initialize_error_handling(self, provider_name: str, config: dict[str, Any] | None = None) -> None: """Initialize the advanced error handling system for a specific provider.""" try: self._error_handling_components = create_provider_error_handling( @@ -149,7 +150,7 @@ async def _execute_with_error_handling( async def handle_operation_failure(self, operation: str, error: Exception) -> bool: """Default implementation for handling operation failures.""" - self.logger.error(f"Operation {operation} failed: {str(error)}") + self.logger.error(f"Operation {operation} failed: {error!s}") # Check if this is a retryable error if self._is_retryable_error(error): @@ -158,7 +159,7 @@ async def handle_operation_failure(self, operation: str, error: Exception) -> bo return False async def retry_operation( - self, operation: str, max_retries: Optional[int] = None + self, operation: str, max_retries: int | None = None ) -> bool: """Retry a failed operation with exponential backoff.""" if max_retries is None: @@ -202,8 +203,8 @@ def _is_retryable_error(self, error: Exception) -> bool: return error_type in retryable_errors async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Default implementation for batch operations.""" results = [] @@ -228,7 +229,7 @@ async def batch_operations( OperationResult( operation_id=f"batch_{i}", success=False, - message=f"Operation failed: {str(e)}", + message=f"Operation failed: {e!s}", file_path=operation.file_path, error_details=str(e), additional_info={"operation_type": operation.operation_type}, @@ -297,7 +298,7 @@ async def get_health_status(self) -> ProviderHealth: return resilient_health async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Default implementation for conflict checking.""" # This is a simplified implementation @@ -329,7 +330,7 @@ async def resolve_conflicts( else: raise ValueError(f"Unknown conflict resolution strategy: {strategy}") - async def get_conflict_info(self, path: str) -> Optional[ConflictInfo]: + async def get_conflict_info(self, path: str) -> ConflictInfo | None: """Get detailed information about conflicts in a file.""" # This is a placeholder implementation # Real implementations would analyze the file for conflict markers @@ -341,10 +342,10 @@ def _create_remediation_result( message: str, file_path: str, operation_type: str, - commit_sha: Optional[str] = None, - pull_request_url: Optional[str] = None, - error_details: Optional[str] = None, - additional_info: Optional[Dict[str, Any]] = None, + commit_sha: str | None = None, + pull_request_url: str | None = None, + error_details: str | None = None, + additional_info: dict[str, Any] | None = None, ) -> RemediationResult: """Helper method to create a RemediationResult.""" return RemediationResult( @@ -363,9 +364,9 @@ def _create_operation_result( operation_id: str, success: bool, message: str, - file_path: Optional[str] = None, - error_details: Optional[str] = None, - additional_info: Optional[Dict[str, Any]] = None, + file_path: str | None = None, + error_details: str | None = None, + additional_info: dict[str, Any] | None = None, ) -> OperationResult: """Helper method to create an OperationResult.""" return OperationResult( @@ -378,7 +379,7 @@ def _create_operation_result( ) def _log_operation( - self, operation: str, success: bool, details: Optional[Dict[str, Any]] = None + self, operation: str, success: bool, details: dict[str, Any] | None = None ): """Log an operation with details.""" level = logging.INFO if success else logging.ERROR @@ -389,7 +390,7 @@ def _log_operation( self.logger.log(level, message) - async def get_comprehensive_health_status(self) -> Dict[str, Any]: + async def get_comprehensive_health_status(self) -> dict[str, Any]: """Get comprehensive health status including monitoring data.""" if not self.monitoring_manager: return {"error": "Monitoring not enabled"} @@ -420,7 +421,7 @@ async def get_comprehensive_health_status(self) -> Dict[str, Any]: "timestamp": datetime.now().isoformat(), } - async def get_metrics_summary(self) -> Dict[str, Any]: + async def get_metrics_summary(self) -> dict[str, Any]: """Get metrics summary for this provider.""" if not self.metrics_collector: return {"error": "Metrics collection not enabled"} @@ -428,8 +429,8 @@ async def get_metrics_summary(self) -> Dict[str, Any]: return await self.metrics_collector.get_metrics_summary() async def get_operation_statistics( - self, operation_name: Optional[str] = None, window_minutes: int = 60 - ) -> Dict[str, Any]: + self, operation_name: str | None = None, window_minutes: int = 60 + ) -> dict[str, Any]: """Get operation statistics for this provider.""" if not self.operation_metrics: return {"error": "Operation metrics not enabled"} diff --git a/argus/source_control/configured_provider.py b/argus/source_control/configured_provider.py index 560ef9a..7bd079b 100644 --- a/argus/source_control/configured_provider.py +++ b/argus/source_control/configured_provider.py @@ -18,7 +18,7 @@ Provider that uses Pydantic configuration models. """ -from typing import Any, Dict, Optional +from typing import Any from ..config.source_control_global import SourceControlGlobalConfig from ..config.source_control_repositories import RepositoryConfig @@ -31,7 +31,7 @@ class ConfiguredSourceControlProvider(BaseSourceControlProvider): def __init__( self, repository_config: RepositoryConfig, - global_config: Optional[SourceControlGlobalConfig] = None, + global_config: SourceControlGlobalConfig | None = None, ): """Initialize with validated Pydantic configs.""" # Convert Pydantic models to dictionary for base class @@ -104,7 +104,7 @@ def get_operation_timeout(self) -> int: return self.global_config.operation_timeout_seconds - def get_retry_config(self) -> Dict[str, Any]: + def get_retry_config(self) -> dict[str, Any]: """Get retry configuration from global config.""" if self.global_config is None: return {"max_retries": 3, "base_delay": 1.0, "max_delay": 60.0} @@ -115,7 +115,7 @@ def get_retry_config(self) -> Dict[str, Any]: "max_delay": 60.0, } - def get_rate_limit_config(self) -> Dict[str, Any]: + def get_rate_limit_config(self) -> dict[str, Any]: """Get rate limit configuration from global config.""" if self.global_config is None: return {"requests_per_minute": 60, "burst_size": 10} @@ -125,7 +125,7 @@ def get_rate_limit_config(self) -> Dict[str, Any]: "burst_size": self.global_config.rate_limit_burst_size, } - def get_cache_config(self) -> Dict[str, Any]: + def get_cache_config(self) -> dict[str, Any]: """Get cache configuration from global config.""" if self.global_config is None: return {"enabled": False, "ttl_seconds": 3600, "max_size_mb": 100} @@ -136,7 +136,7 @@ def get_cache_config(self) -> Dict[str, Any]: "max_size_mb": self.global_config.max_cache_size_mb, } - def get_security_config(self) -> Dict[str, Any]: + def get_security_config(self) -> dict[str, Any]: """Get security configuration from global config.""" if self.global_config is None: return { @@ -151,7 +151,7 @@ def get_security_config(self) -> Dict[str, Any]: "credential_rotation_interval_days": self.global_config.credential_rotation_interval_days, } - def get_monitoring_config(self) -> Dict[str, Any]: + def get_monitoring_config(self) -> dict[str, Any]: """Get monitoring configuration from global config.""" if self.global_config is None: return {"enable_metrics": True, "audit_logging": True} diff --git a/argus/source_control/enhanced_base_implementation.py b/argus/source_control/enhanced_base_implementation.py index 6e39e35..3d2e1b7 100644 --- a/argus/source_control/enhanced_base_implementation.py +++ b/argus/source_control/enhanced_base_implementation.py @@ -22,8 +22,9 @@ health checks, and metrics collection. """ +from collections.abc import Callable import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any from .base import SourceControlProvider from .error_handling import ( @@ -49,7 +50,7 @@ class EnhancedBaseSourceControlProvider(SourceControlProvider): """Enhanced base implementation with comprehensive error handling.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize with enhanced error handling configuration.""" super().__init__(config) self._client = None @@ -62,7 +63,7 @@ def __init__(self, config: Dict[str, Any]) -> None: # Initialize monitoring self._initialize_monitoring(config) - def _initialize_error_handling(self, config: Dict[str, Any]) -> None: + def _initialize_error_handling(self, config: dict[str, Any]) -> None: """Initialize the comprehensive error handling system.""" # Get error handling configuration error_handling_config = self.get_config_value("error_handling", {}) @@ -104,7 +105,7 @@ def _initialize_error_handling(self, config: Dict[str, Any]) -> None: ) def _create_circuit_breaker_config( - self, error_handling_config: Dict[str, Any] + self, error_handling_config: dict[str, Any] ) -> CircuitBreakerConfig: """Create circuit breaker configuration from config.""" circuit_config = error_handling_config.get("circuit_breaker", {}) @@ -116,13 +117,13 @@ def _create_circuit_breaker_config( ) def _create_operation_circuit_breaker_config( - self, error_handling_config: Dict[str, Any] + self, error_handling_config: dict[str, Any] ) -> OperationCircuitBreakerConfig: """Create operation-specific circuit breaker configuration.""" return OperationCircuitBreakerConfig() def _create_retry_config( - self, error_handling_config: Dict[str, Any] + self, error_handling_config: dict[str, Any] ) -> RetryConfig: """Create retry configuration from config.""" retry_config = error_handling_config.get("retry", {}) @@ -134,7 +135,7 @@ def _create_retry_config( jitter=retry_config.get("jitter", True), ) - def _initialize_monitoring(self, config: Dict[str, Any]) -> None: + def _initialize_monitoring(self, config: dict[str, Any]) -> None: """Initialize monitoring components.""" enable_monitoring = self.get_config_value("monitoring", {}).get("enabled", True) if enable_monitoring: @@ -260,7 +261,7 @@ async def get_health_status(self) -> ProviderHealth: except Exception as e: return ProviderHealth( status="unhealthy", - message=f"Health check failed: {str(e)}", + message=f"Health check failed: {e!s}", additional_info={"error": str(e)}, ) @@ -272,7 +273,7 @@ async def _get_basic_health_status(self) -> ProviderHealth: async def handle_operation_failure(self, operation: str, error: Exception) -> bool: """Enhanced operation failure handling with error classification.""" - self.logger.error(f"Operation {operation} failed: {str(error)}") + self.logger.error(f"Operation {operation} failed: {error!s}") # Classify the error error_classification = self.error_classifier.classify_error(error) @@ -296,8 +297,8 @@ async def handle_operation_failure(self, operation: str, error: Exception) -> bo return False async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Enhanced batch operations with comprehensive error handling.""" results = [] @@ -337,7 +338,7 @@ async def batch_operations( OperationResult( operation_id=f"batch_{i}", success=False, - message=f"Operation failed: {str(e)}", + message=f"Operation failed: {e!s}", file_path=operation.file_path, error_details=str(e), additional_info={ @@ -403,25 +404,25 @@ async def get_capabilities(self) -> Any: """Get provider capabilities. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement get_capabilities") - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement get_file_content") async def apply_remediation( - self, path: str, content: str, message: str, branch: Optional[str] = None + self, path: str, content: str, message: str, branch: str | None = None ) -> RemediationResult: """Apply remediation. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement apply_remediation") - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if file exists. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement file_exists") - async def get_file_info(self, path: str, ref: Optional[str] = None) -> Any: + async def get_file_info(self, path: str, ref: str | None = None) -> Any: """Get file information. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement get_file_info") - async def list_files(self, path: str = "", ref: Optional[str] = None) -> List[Any]: + async def list_files(self, path: str = "", ref: str | None = None) -> list[Any]: """List files. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement list_files") @@ -434,12 +435,12 @@ async def apply_patch(self, patch: str, file_path: str) -> bool: raise NotImplementedError("Subclasses must implement apply_patch") async def commit_changes( - self, file_path: str, content: str, message: str, branch: Optional[str] = None - ) -> Optional[str]: + self, file_path: str, content: str, message: str, branch: str | None = None + ) -> str | None: """Commit changes. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement commit_changes") - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create branch. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement create_branch") @@ -447,11 +448,11 @@ async def delete_branch(self, name: str) -> bool: """Delete branch. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement delete_branch") - async def list_branches(self) -> List[Any]: + async def list_branches(self) -> list[Any]: """List branches. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement list_branches") - async def get_branch_info(self, name: str) -> Optional[Any]: + async def get_branch_info(self, name: str) -> Any | None: """Get branch info. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement get_branch_info") @@ -464,7 +465,7 @@ async def get_repository_info(self) -> Any: raise NotImplementedError("Subclasses must implement get_repository_info") async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check conflicts. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement check_conflicts") @@ -475,7 +476,7 @@ async def resolve_conflicts( """Resolve conflicts. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement resolve_conflicts") - async def get_file_history(self, path: str, limit: int = 10) -> List[Any]: + async def get_file_history(self, path: str, limit: int = 10) -> list[Any]: """Get file history. Must be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement get_file_history") diff --git a/argus/source_control/error_handling/factory.py b/argus/source_control/error_handling/factory.py index c2c7d0e..53352ec 100644 --- a/argus/source_control/error_handling/factory.py +++ b/argus/source_control/error_handling/factory.py @@ -22,7 +22,7 @@ """ import logging -from typing import Any, Dict, Optional +from typing import Any from .core import ( CircuitBreakerConfig, @@ -40,15 +40,15 @@ class ErrorHandlingFactory: """Factory for creating error handling components.""" - def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, config: dict[str, Any] | None = None) -> None: """Initialize the factory with configuration.""" self.config = config or {} self.validator = ErrorHandlingConfigValidator() self.logger = logging.getLogger("ErrorHandlingFactory") def create_error_handling_system( - self, provider_name: str, config: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + self, provider_name: str, config: dict[str, Any] | None = None + ) -> dict[str, Any]: """ Create a complete error handling system for a provider. @@ -112,7 +112,7 @@ def create_error_handling_system( return components def _create_circuit_breaker_config( - self, config: Dict[str, Any] + self, config: dict[str, Any] ) -> CircuitBreakerConfig: """Create circuit breaker configuration.""" circuit_config = config.get("circuit_breaker", {}) @@ -124,12 +124,12 @@ def _create_circuit_breaker_config( ) def _create_operation_circuit_breaker_config( - self, config: Dict[str, Any] + self, config: dict[str, Any] ) -> OperationCircuitBreakerConfig: """Create operation-specific circuit breaker configuration.""" return OperationCircuitBreakerConfig() - def _create_retry_config(self, config: Dict[str, Any]) -> RetryConfig: + def _create_retry_config(self, config: dict[str, Any]) -> RetryConfig: """Create retry configuration.""" retry_config = config.get("retry", {}) return RetryConfig( @@ -140,25 +140,25 @@ def _create_retry_config(self, config: Dict[str, Any]) -> RetryConfig: jitter=retry_config.get("jitter", True), ) - def get_default_config(self) -> Dict[str, Any]: + def get_default_config(self) -> dict[str, Any]: """Get default error handling configuration.""" return self.validator.get_default_config() - def validate_config(self, config: Dict[str, Any]) -> tuple[bool, list[str]]: + def validate_config(self, config: dict[str, Any]) -> tuple[bool, list[str]]: """Validate error handling configuration.""" return self.validator.validate_error_handling_config(config) def create_error_handling_factory( - config: Optional[Dict[str, Any]] = None, + config: dict[str, Any] | None = None, ) -> ErrorHandlingFactory: """Create an error handling factory with the given configuration.""" return ErrorHandlingFactory(config) def create_provider_error_handling( - provider_name: str, config: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: + provider_name: str, config: dict[str, Any] | None = None +) -> dict[str, Any]: """ Create error handling components for a specific provider. @@ -280,14 +280,14 @@ def create_provider_error_handling( } -def get_provider_config(provider_name: str) -> Dict[str, Any]: +def get_provider_config(provider_name: str) -> dict[str, Any]: """Get provider-specific error handling configuration.""" return PROVIDER_CONFIGS.get(provider_name, PROVIDER_CONFIGS["github"]) def create_provider_error_handling_with_preset( - provider_name: str, custom_config: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: + provider_name: str, custom_config: dict[str, Any] | None = None +) -> dict[str, Any]: """ Create error handling components using provider-specific presets. diff --git a/argus/source_control/health_checks.py b/argus/source_control/health_checks.py index 42802e8..0fdd30b 100644 --- a/argus/source_control/health_checks.py +++ b/argus/source_control/health_checks.py @@ -22,10 +22,11 @@ """ import asyncio +from collections.abc import Awaitable, Callable +from datetime import datetime import logging import time -from datetime import datetime -from typing import Any, Awaitable, Callable, Dict, List, Optional +from typing import Any from .base import SourceControlProvider from .monitoring import HealthCheck, HealthStatus @@ -35,7 +36,7 @@ class HealthCheckRegistry: """Registry for health check implementations.""" def __init__(self) -> None: - self.checks: Dict[ + self.checks: dict[ str, Callable[[SourceControlProvider], Awaitable[HealthCheck]] ] = {} self.logger = logging.getLogger("HealthCheckRegistry") @@ -51,11 +52,11 @@ def register( def get_check( self, name: str - ) -> Optional[Callable[[SourceControlProvider], Awaitable[HealthCheck]]]: + ) -> Callable[[SourceControlProvider], Awaitable[HealthCheck]] | None: """Get a health check function by name.""" return self.checks.get(name) - def list_checks(self) -> List[str]: + def list_checks(self) -> list[str]: """List all registered health check names.""" return list(self.checks.keys()) @@ -396,7 +397,7 @@ def __init__(self) -> None: async def run_all_checks( self, provider: SourceControlProvider - ) -> List[HealthCheck]: + ) -> list[HealthCheck]: """Run all registered health checks on a provider.""" checks = [] check_names = self.registry.list_checks() @@ -430,8 +431,8 @@ async def run_all_checks( return checks async def run_specific_checks( - self, provider: SourceControlProvider, check_names: List[str] - ) -> List[HealthCheck]: + self, provider: SourceControlProvider, check_names: list[str] + ) -> list[HealthCheck]: """Run specific health checks on a provider.""" checks = [] @@ -467,11 +468,11 @@ async def run_specific_checks( return checks - def get_available_checks(self) -> List[str]: + def get_available_checks(self) -> list[str]: """Get list of available health check names.""" return self.registry.list_checks() - def get_check_summary(self, checks: List[HealthCheck]) -> Dict[str, Any]: + def get_check_summary(self, checks: list[HealthCheck]) -> dict[str, Any]: """Get a summary of health check results.""" if not checks: return { @@ -482,7 +483,7 @@ def get_check_summary(self, checks: List[HealthCheck]) -> Dict[str, Any]: "unknown": 0, } - summary: Dict[str, Any] = { + summary: dict[str, Any] = { "total": len(checks), "healthy": sum(1 for c in checks if c.status == HealthStatus.HEALTHY), "degraded": sum(1 for c in checks if c.status == HealthStatus.DEGRADED), diff --git a/argus/source_control/monitoring.py b/argus/source_control/monitoring.py index f0e7985..64ef159 100644 --- a/argus/source_control/monitoring.py +++ b/argus/source_control/monitoring.py @@ -21,13 +21,14 @@ health checks, performance monitoring, and alerting for source control operations. """ -import logging -import time from collections import defaultdict, deque +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional +import logging +import time +from typing import Any from .base import SourceControlProvider @@ -58,8 +59,8 @@ class Metric: value: float metric_type: MetricType timestamp: datetime - tags: Dict[str, str] = field(default_factory=dict) - unit: Optional[str] = None + tags: dict[str, str] = field(default_factory=dict) + unit: str | None = None @dataclass @@ -71,7 +72,7 @@ class HealthCheck: message: str timestamp: datetime duration_ms: float - details: Dict[str, Any] = field(default_factory=dict) + details: dict[str, Any] = field(default_factory=dict) @dataclass @@ -83,8 +84,8 @@ class Alert: message: str timestamp: datetime resolved: bool = False - resolved_at: Optional[datetime] = None - metadata: Dict[str, Any] = field(default_factory=dict) + resolved_at: datetime | None = None + metadata: dict[str, Any] = field(default_factory=dict) class MetricsCollector: @@ -92,10 +93,10 @@ class MetricsCollector: def __init__(self, max_metrics: int = 10000) -> None: self.metrics: deque = deque(maxlen=max_metrics) - self.counters: Dict[str, float] = defaultdict(float) - self.gauges: Dict[str, float] = defaultdict(float) - self.histograms: Dict[str, List[float]] = defaultdict(list) - self.timers: Dict[str, List[float]] = defaultdict(list) + self.counters: dict[str, float] = defaultdict(float) + self.gauges: dict[str, float] = defaultdict(float) + self.histograms: dict[str, list[float]] = defaultdict(list) + self.timers: dict[str, list[float]] = defaultdict(list) self.logger = logging.getLogger("MetricsCollector") def record_metric(self, metric: Metric) -> None: @@ -119,7 +120,7 @@ def get_gauge(self, name: str) -> float: """Get current gauge value.""" return self.gauges.get(name, 0.0) - def get_histogram_stats(self, name: str) -> Dict[str, float]: + def get_histogram_stats(self, name: str) -> dict[str, float]: """Get histogram statistics.""" values = self.histograms.get(name, []) if not values: @@ -136,11 +137,11 @@ def get_histogram_stats(self, name: str) -> Dict[str, float]: "p99": sorted_values[int(count * 0.99)] if count > 0 else 0, } - def get_timer_stats(self, name: str) -> Dict[str, float]: + def get_timer_stats(self, name: str) -> dict[str, float]: """Get timer statistics.""" return self.get_histogram_stats(name) - def get_metrics_summary(self) -> Dict[str, Any]: + def get_metrics_summary(self) -> dict[str, Any]: """Get a summary of all metrics.""" return { "counters": dict(self.counters), @@ -157,7 +158,7 @@ class HealthChecker: """Performs comprehensive health checks on source control providers.""" def __init__(self) -> None: - self.health_checks: Dict[ + self.health_checks: dict[ str, Callable[[SourceControlProvider], Awaitable[HealthCheck]] ] = {} self.logger = logging.getLogger("HealthChecker") @@ -172,7 +173,7 @@ def register_health_check( async def run_health_checks( self, provider: SourceControlProvider - ) -> List[HealthCheck]: + ) -> list[HealthCheck]: """Run all registered health checks on a provider.""" results = [] @@ -304,12 +305,12 @@ class AlertManager: """Manages alerts and notifications.""" def __init__(self) -> None: - self.alerts: List[Alert] = [] - self.alert_rules: Dict[str, Callable[[Dict[str, Any]], bool]] = {} - self.notification_handlers: List[Callable[[Alert], None]] = [] + self.alerts: list[Alert] = [] + self.alert_rules: dict[str, Callable[[dict[str, Any]], bool]] = {} + self.notification_handlers: list[Callable[[Alert], None]] = [] self.logger = logging.getLogger("AlertManager") - def add_alert_rule(self, name: str, rule_func: Callable[[Dict[str, Any]], bool]) -> None: + def add_alert_rule(self, name: str, rule_func: Callable[[dict[str, Any]], bool]) -> None: """Add an alert rule.""" self.alert_rules[name] = rule_func @@ -317,7 +318,7 @@ def add_notification_handler(self, handler: Callable[[Alert], None]) -> None: """Add a notification handler.""" self.notification_handlers.append(handler) - def check_alerts(self, metrics: Dict[str, Any]) -> List[Alert]: + def check_alerts(self, metrics: dict[str, Any]) -> list[Alert]: """Check for alert conditions based on metrics.""" new_alerts = [] @@ -346,7 +347,7 @@ def check_alerts(self, metrics: Dict[str, Any]) -> List[Alert]: return new_alerts - def resolve_alert(self, alert_name: str, resolved_at: Optional[datetime] = None) -> None: + def resolve_alert(self, alert_name: str, resolved_at: datetime | None = None) -> None: """Mark an alert as resolved.""" for alert in self.alerts: if alert.name == alert_name and not alert.resolved: @@ -354,7 +355,7 @@ def resolve_alert(self, alert_name: str, resolved_at: Optional[datetime] = None) alert.resolved_at = resolved_at or datetime.now() break - def get_active_alerts(self) -> List[Alert]: + def get_active_alerts(self) -> list[Alert]: """Get all active (unresolved) alerts.""" return [alert for alert in self.alerts if not alert.resolved] @@ -403,7 +404,7 @@ async def monitor_operation( operation_name: str, operation_func: Callable[[], Awaitable[Any]], provider_name: str, - tags: Optional[Dict[str, str]] = None, + tags: dict[str, str] | None = None, ) -> Any: """Monitor a source control operation.""" if not self.metrics_collector: @@ -465,8 +466,8 @@ async def monitor_operation( ) async def run_health_checks( - self, providers: List[SourceControlProvider] - ) -> Dict[str, List[HealthCheck]]: + self, providers: list[SourceControlProvider] + ) -> dict[str, list[HealthCheck]]: """Run health checks on multiple providers.""" if not self.health_checker: return {} @@ -491,7 +492,7 @@ async def run_health_checks( return results - async def check_alerts(self) -> List[Alert]: + async def check_alerts(self) -> list[Alert]: """Check for alert conditions.""" if not self.alert_manager or not self.metrics_collector: return [] @@ -505,8 +506,8 @@ async def check_alerts(self) -> List[Alert]: return self.alert_manager.check_alerts(derived_metrics) def _calculate_derived_metrics( - self, metrics_summary: Dict[str, Any] - ) -> Dict[str, Any]: + self, metrics_summary: dict[str, Any] + ) -> dict[str, Any]: """Calculate derived metrics from raw metrics.""" derived = {} @@ -535,7 +536,7 @@ def _calculate_derived_metrics( return derived - def get_monitoring_summary(self) -> Dict[str, Any]: + def get_monitoring_summary(self) -> dict[str, Any]: """Get a comprehensive monitoring summary.""" summary = { "timestamp": datetime.now().isoformat(), diff --git a/argus/source_control/providers/github/enhanced_github_provider.py b/argus/source_control/providers/github/enhanced_github_provider.py index 86fe54b..4bbe29c 100644 --- a/argus/source_control/providers/github/enhanced_github_provider.py +++ b/argus/source_control/providers/github/enhanced_github_provider.py @@ -22,7 +22,7 @@ degradation, health checks, and metrics collection. """ -from typing import Any, Dict, List, Optional +from typing import Any from github import Github from github.Repository import Repository @@ -48,7 +48,7 @@ class EnhancedGitHubProvider(EnhancedBaseSourceControlProvider): """Enhanced GitHub provider with comprehensive error handling.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the enhanced GitHub provider.""" super().__init__(config) @@ -57,13 +57,13 @@ def __init__(self, config: Dict[str, Any]) -> None: self.credentials = ( None # Will be set later when credential management is integrated ) - self.client: Optional[Github] = None - self.repo: Optional[Repository] = None + self.client: Github | None = None + self.repo: Repository | None = None # Initialize component modules - self.operations: Optional[GitHubOperations] = None - self.pull_requests: Optional[GitHubPullRequests] = None - self.utils: Optional[GitHubUtils] = None + self.operations: GitHubOperations | None = None + self.pull_requests: GitHubPullRequests | None = None + self.utils: GitHubUtils | None = None async def initialize(self) -> None: """Initialize the GitHub provider with error handling.""" @@ -161,7 +161,7 @@ async def get_capabilities(self) -> ProviderCapabilities: ) # File operations with error handling - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -171,7 +171,7 @@ async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: ) async def apply_remediation( - self, path: str, content: str, message: str, branch: Optional[str] = None + self, path: str, content: str, message: str, branch: str | None = None ) -> RemediationResult: """Apply remediation with error handling.""" if not self.operations: @@ -186,7 +186,7 @@ async def apply_remediation( branch, ) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if file exists with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -195,7 +195,7 @@ async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: "file_exists", self.operations.file_exists, path, ref ) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get file information with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -205,8 +205,8 @@ async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: ) async def list_files( - self, path: str = "", ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", ref: str | None = None + ) -> list[FileInfo]: """List files with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -234,8 +234,8 @@ async def apply_patch(self, patch: str, file_path: str) -> bool: ) async def commit_changes( - self, file_path: str, content: str, message: str, branch: Optional[str] = None - ) -> Optional[str]: + self, file_path: str, content: str, message: str, branch: str | None = None + ) -> str | None: """Commit changes with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -250,7 +250,7 @@ async def commit_changes( ) # Branch operations with error handling - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create branch with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -268,7 +268,7 @@ async def delete_branch(self, name: str) -> bool: "delete_branch", self.operations.delete_branch, name ) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List branches with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -277,7 +277,7 @@ async def list_branches(self) -> List[BranchInfo]: "list_branches", self.operations.list_branches ) - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get branch info with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -305,7 +305,7 @@ async def get_repository_info(self) -> RepositoryInfo: ) async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check conflicts with error handling.""" if not self.operations: @@ -331,7 +331,7 @@ async def resolve_conflicts( ) # Git operations with error handling - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get file history with error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") @@ -398,8 +398,8 @@ async def create_merge_request( # Enhanced batch operations with error handling async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute batch operations with comprehensive error handling.""" if not self.operations: raise RuntimeError("Provider not initialized") diff --git a/argus/source_control/providers/github/github_provider.py b/argus/source_control/providers/github/github_provider.py index 9cf4c79..ea977a8 100644 --- a/argus/source_control/providers/github/github_provider.py +++ b/argus/source_control/providers/github/github_provider.py @@ -22,7 +22,7 @@ """ import logging -from typing import Any, Dict, List, Optional +from typing import Any from github import Github from github.Repository import Repository @@ -52,7 +52,7 @@ class GitHubProvider(BaseSourceControlProvider): """GitHub implementation of the SourceControlProvider interface.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the GitHub provider with configuration.""" super().__init__(config) # Convert config dict back to GitHubRepositoryConfig for type safety @@ -61,17 +61,17 @@ def __init__(self, config: Dict[str, Any]) -> None: None # Will be set later when credential management is integrated ) self.logger = logging.getLogger("GitHubProvider") - self.client: Optional[Github] = None - self.repo: Optional[Repository] = None + self.client: Github | None = None + self.repo: Repository | None = None # Initialize component modules - self.operations: Optional[GitHubOperations] = None - self.pull_requests: Optional[GitHubPullRequests] = None - self.utils: Optional[GitHubUtils] = None + self.operations: GitHubOperations | None = None + self.pull_requests: GitHubPullRequests | None = None + self.utils: GitHubUtils | None = None # Initialize error handling system self.error_handling_factory = ErrorHandlingFactory() - self.error_handling_components: Optional[Dict[str, Any]] = None + self.error_handling_components: dict[str, Any] | None = None async def _setup_client(self) -> None: """Set up GitHub client and repository.""" @@ -133,7 +133,7 @@ async def _with_retry(self, operation_func, *args, **kwargs): return await self.utils._with_retry(operation_func, *args, **kwargs) # Delegate operations to component modules - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content from GitHub repository.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -146,7 +146,7 @@ async def apply_remediation( path: str, content: str, message: str, - branch: Optional[str] = None, + branch: str | None = None, ) -> RemediationResult: """Apply a remediation to a file.""" if not self.operations: @@ -160,7 +160,7 @@ async def apply_remediation( branch, ) - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create a new branch.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -176,7 +176,7 @@ async def delete_branch(self, name: str) -> bool: "delete_branch", self.operations.delete_branch, name ) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List all branches in the repository.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -189,7 +189,7 @@ async def get_repository_info(self) -> RepositoryInfo: return await self.operations.get_repository_info() async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check for merge conflicts between branches.""" if not self.operations: @@ -217,8 +217,8 @@ async def resolve_conflicts( return await self.operations.resolve_conflicts(path, content, strategy) async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute multiple operations in batch.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -333,19 +333,19 @@ async def create_merge_request( additional_info={}, ) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if a file exists in the repository.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") return await self.operations.file_exists(path, ref) - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get information about a specific branch.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") return await self.operations.get_branch_info(name) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get detailed information about a file with error handling.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -413,8 +413,8 @@ async def get_health_status(self) -> ProviderHealth: return await self.utils.get_health_status() async def list_files( - self, path: str = "", recursive: bool = True, ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", recursive: bool = True, ref: str | None = None + ) -> list[FileInfo]: """List files in the repository.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -439,14 +439,14 @@ async def commit_changes( file_path: str, content: str, message: str, - branch: Optional[str] = None, - ) -> Optional[str]: + branch: str | None = None, + ) -> str | None: """Commit changes to a file.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") return await self.operations.commit_changes(file_path, content, message, branch) - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get commit history for a file.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") @@ -464,19 +464,19 @@ async def get_current_branch(self) -> str: raise RuntimeError("GitHub operations not initialized") return await self.operations.get_current_branch() - async def get_status(self) -> Dict[str, Any]: + async def get_status(self) -> dict[str, Any]: """Get current status of the GitHub provider.""" if not self.utils: return {"error": "GitHub utils not initialized"} return await self.utils.get_status() - async def execute_git_command(self, command: str, **kwargs) -> Dict[str, Any]: + async def execute_git_command(self, command: str, **kwargs) -> dict[str, Any]: """Execute a git command (placeholder for GitHub API operations).""" if not self.utils: return {"error": "GitHub utils not initialized"} return await self.utils.execute_git_command(command, **kwargs) - async def get_conflict_info(self, path: str) -> Optional[ConflictInfo]: + async def get_conflict_info(self, path: str) -> ConflictInfo | None: """Get conflict information for a file.""" if not self.operations: raise RuntimeError("GitHub operations not initialized") diff --git a/argus/source_control/providers/gitlab/enhanced_gitlab_provider.py b/argus/source_control/providers/gitlab/enhanced_gitlab_provider.py index f18a87a..6b6c22e 100644 --- a/argus/source_control/providers/gitlab/enhanced_gitlab_provider.py +++ b/argus/source_control/providers/gitlab/enhanced_gitlab_provider.py @@ -22,7 +22,7 @@ degradation, health checks, and metrics collection. """ -from typing import Any, Dict, List, Optional +from typing import Any import gitlab @@ -48,19 +48,19 @@ class EnhancedGitLabProvider(EnhancedBaseSourceControlProvider): """Enhanced GitLab provider with comprehensive error handling.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the enhanced GitLab provider.""" super().__init__(config) self.repo_config = GitLabRepositoryConfig(**config) - self.credentials: Optional[GitLabCredentials] = None - self.gl: Optional[gitlab.Gitlab] = None - self.project: Optional[Any] = None + self.credentials: GitLabCredentials | None = None + self.gl: gitlab.Gitlab | None = None + self.project: Any | None = None # Initialize sub-modules (will be set after initialization) - self.file_ops: Optional[GitLabFileOperations] = None - self.branch_ops: Optional[GitLabBranchOperations] = None - self.mr_ops: Optional[GitLabMergeRequestOperations] = None + self.file_ops: GitLabFileOperations | None = None + self.branch_ops: GitLabBranchOperations | None = None + self.mr_ops: GitLabMergeRequestOperations | None = None async def initialize(self) -> None: """Initialize the GitLab provider with error handling.""" @@ -168,7 +168,7 @@ async def get_capabilities(self) -> ProviderCapabilities: return await self.mr_ops.get_capabilities() # File operations with error handling - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -178,7 +178,7 @@ async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: ) async def apply_remediation( - self, path: str, content: str, message: str, branch: Optional[str] = None + self, path: str, content: str, message: str, branch: str | None = None ) -> RemediationResult: """Apply remediation with error handling.""" if not self.file_ops: @@ -193,7 +193,7 @@ async def apply_remediation( branch, ) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if file exists with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -202,7 +202,7 @@ async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: "file_exists", self.file_ops.file_exists, path, ref ) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get file information with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -212,8 +212,8 @@ async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: ) async def list_files( - self, path: str = "", ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", ref: str | None = None + ) -> list[FileInfo]: """List files with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -241,8 +241,8 @@ async def apply_patch(self, patch: str, file_path: str) -> bool: ) async def commit_changes( - self, file_path: str, content: str, message: str, branch: Optional[str] = None - ) -> Optional[str]: + self, file_path: str, content: str, message: str, branch: str | None = None + ) -> str | None: """Commit changes with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -257,7 +257,7 @@ async def commit_changes( ) # Branch operations with error handling - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create branch with error handling.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") @@ -275,7 +275,7 @@ async def delete_branch(self, name: str) -> bool: "delete_branch", self.branch_ops.delete_branch, name ) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List branches with error handling.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") @@ -284,7 +284,7 @@ async def list_branches(self) -> List[BranchInfo]: "list_branches", self.branch_ops.list_branches ) - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get branch info with error handling.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") @@ -312,7 +312,7 @@ async def get_repository_info(self) -> RepositoryInfo: ) async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check conflicts with error handling.""" if not self.branch_ops: @@ -338,7 +338,7 @@ async def resolve_conflicts( ) # Git operations with error handling - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get file history with error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -418,8 +418,8 @@ async def create_merge_request( # Enhanced batch operations with error handling async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute batch operations with comprehensive error handling.""" if not self.file_ops: raise RuntimeError("Provider not initialized") diff --git a/argus/source_control/providers/gitlab/gitlab_provider.py b/argus/source_control/providers/gitlab/gitlab_provider.py index 296686d..968e3d5 100644 --- a/argus/source_control/providers/gitlab/gitlab_provider.py +++ b/argus/source_control/providers/gitlab/gitlab_provider.py @@ -17,7 +17,7 @@ """GitLab provider implementation for source control operations.""" import logging -from typing import Any, Dict, List, Optional +from typing import Any import gitlab @@ -46,23 +46,23 @@ class GitLabProvider(BaseSourceControlProvider): """GitLab provider implementation.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the GitLab provider.""" super().__init__(config) self.repo_config = GitLabRepositoryConfig(**config) - self.credentials: Optional[GitLabCredentials] = None - self.gl: Optional[gitlab.Gitlab] = None - self.project: Optional[Any] = None + self.credentials: GitLabCredentials | None = None + self.gl: gitlab.Gitlab | None = None + self.project: Any | None = None self.logger = logging.getLogger(__name__) # Initialize sub-modules (will be set after initialization) - self.file_ops: Optional[GitLabFileOperations] = None - self.branch_ops: Optional[GitLabBranchOperations] = None - self.mr_ops: Optional[GitLabMergeRequestOperations] = None + self.file_ops: GitLabFileOperations | None = None + self.branch_ops: GitLabBranchOperations | None = None + self.mr_ops: GitLabMergeRequestOperations | None = None # Initialize error handling system self.error_handling_factory = ErrorHandlingFactory() - self.error_handling_components: Optional[Dict[str, Any]] = None + self.error_handling_components: dict[str, Any] | None = None async def __aenter__(self): """Async context manager entry.""" @@ -177,7 +177,7 @@ async def get_health_status(self) -> ProviderHealth: ) # File operations - delegate to file_ops - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content from GitLab repository.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -190,7 +190,7 @@ async def apply_remediation( path: str, content: str, message: str, - branch: Optional[str] = None, + branch: str | None = None, ) -> RemediationResult: """Apply remediation to a file.""" if not self.file_ops: @@ -204,21 +204,21 @@ async def apply_remediation( branch, ) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if a file exists in the repository.""" if not self.file_ops: raise RuntimeError("Provider not initialized") return await self.file_ops.file_exists(path, ref) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get file information.""" if not self.file_ops: raise RuntimeError("Provider not initialized") return await self.file_ops.get_file_info(path, ref) async def list_files( - self, path: str = "", ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", ref: str | None = None + ) -> list[FileInfo]: """List files in a directory.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -241,15 +241,15 @@ async def commit_changes( file_path: str, content: str, message: str, - branch: Optional[str] = None, - ) -> Optional[str]: + branch: str | None = None, + ) -> str | None: """Commit changes to a file.""" if not self.file_ops: raise RuntimeError("Provider not initialized") return await self.file_ops.commit_changes(file_path, content, message, branch) # Branch operations - delegate to branch_ops - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create a new branch.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") @@ -261,13 +261,13 @@ async def delete_branch(self, name: str) -> bool: raise RuntimeError("Provider not initialized") return await self.branch_ops.delete_branch(name) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List all branches.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") return await self.branch_ops.list_branches() - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get information about a specific branch.""" if not self.branch_ops: raise RuntimeError("Provider not initialized") @@ -289,7 +289,7 @@ async def check_conflicts( self, path: str, content: str, - branch: Optional[str] = None, + branch: str | None = None, ) -> bool: """Check for conflicts between branches.""" if not self.branch_ops: @@ -319,7 +319,7 @@ async def resolve_conflicts( return await self.branch_ops.resolve_conflicts(path, content, strategy) # Git operations - delegate to file_ops - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get file commit history.""" if not self.file_ops: raise RuntimeError("Provider not initialized") @@ -383,8 +383,8 @@ async def create_pull_request( # Batch operations (simplified implementation) async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute multiple operations in batch.""" results = [] for operation in operations: diff --git a/argus/source_control/providers/local/enhanced_local_provider.py b/argus/source_control/providers/local/enhanced_local_provider.py index ccc02c2..a1d9458 100644 --- a/argus/source_control/providers/local/enhanced_local_provider.py +++ b/argus/source_control/providers/local/enhanced_local_provider.py @@ -23,7 +23,7 @@ """ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from ....config.source_control_repositories import LocalRepositoryConfig from ...enhanced_base_implementation import EnhancedBaseSourceControlProvider @@ -46,7 +46,7 @@ class EnhancedLocalProvider(EnhancedBaseSourceControlProvider): """Enhanced Local provider with comprehensive error handling.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the enhanced Local provider.""" super().__init__(config) @@ -167,7 +167,7 @@ async def get_capabilities(self) -> ProviderCapabilities: ) # File operations with error handling - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content with error handling.""" if not self.file_ops: raise RuntimeError("File operations not initialized") @@ -176,7 +176,7 @@ async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: ) async def apply_remediation( - self, path: str, content: str, message: str, branch: Optional[str] = None + self, path: str, content: str, message: str, branch: str | None = None ) -> RemediationResult: """Apply remediation with error handling.""" if not self.file_ops: @@ -185,7 +185,7 @@ async def apply_remediation( "apply_remediation", self.file_ops.apply_remediation, path, content, message ) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if file exists with error handling.""" if not self.file_ops: raise RuntimeError("File operations not initialized") @@ -193,7 +193,7 @@ async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: "file_exists", self.file_ops.file_exists, path ) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get file information with error handling.""" if not self.file_ops: raise RuntimeError("File operations not initialized") @@ -202,8 +202,8 @@ async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: ) async def list_files( - self, path: str = "", ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", ref: str | None = None + ) -> list[FileInfo]: """List files with error handling.""" if not self.file_ops: raise RuntimeError("File operations not initialized") @@ -228,8 +228,8 @@ async def apply_patch(self, patch: str, file_path: str) -> bool: ) async def commit_changes( - self, file_path: str, content: str, message: str, branch: Optional[str] = None - ) -> Optional[str]: + self, file_path: str, content: str, message: str, branch: str | None = None + ) -> str | None: """Commit changes with error handling.""" if not self.git_enabled: # For non-Git local operations, just write the file @@ -251,7 +251,7 @@ async def commit_changes( ) # Branch operations with error handling (only if Git is enabled) - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create branch with error handling.""" if not self.git_enabled: raise RuntimeError("Git operations are not enabled") @@ -273,7 +273,7 @@ async def delete_branch(self, name: str) -> bool: "delete_branch", self.git_ops.delete_branch, name ) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List branches with error handling.""" if not self.git_enabled: return [] @@ -284,7 +284,7 @@ async def list_branches(self) -> List[BranchInfo]: "list_branches", self.git_ops.list_branches ) - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get branch info with error handling.""" if not self.git_enabled: return None @@ -323,7 +323,7 @@ async def get_repository_info(self) -> RepositoryInfo: ) async def check_conflicts( - self, path: str, content: str, branch: Optional[str] = None + self, path: str, content: str, branch: str | None = None ) -> bool: """Check conflicts with error handling.""" if not self.git_enabled: @@ -349,7 +349,7 @@ async def resolve_conflicts( ) # Git operations with error handling (only if Git is enabled) - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get file history with error handling.""" if not self.git_enabled: return [] # No history for non-Git operations @@ -418,8 +418,8 @@ async def create_merge_request( # Enhanced batch operations with error handling async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute batch operations with comprehensive error handling.""" if not self.batch_ops: raise RuntimeError("Batch operations not initialized") diff --git a/argus/source_control/providers/local/local_provider.py b/argus/source_control/providers/local/local_provider.py index feda446..36df06c 100644 --- a/argus/source_control/providers/local/local_provider.py +++ b/argus/source_control/providers/local/local_provider.py @@ -17,7 +17,7 @@ """Local filesystem provider with Git integration and patch generation capabilities.""" from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from argus.config.source_control_repositories import LocalRepositoryConfig from argus.source_control.base_implementation import ( @@ -45,7 +45,7 @@ class LocalProvider(BaseSourceControlProvider): """Provider for local filesystem operations with Git integration and patch generation capabilities.""" - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: dict[str, Any]) -> None: """Initialize the local provider with configuration.""" super().__init__(config) # Convert config dict back to LocalRepositoryConfig for type safety @@ -139,7 +139,7 @@ async def get_health_status(self) -> ProviderHealth: ) # File operations - delegate to file_ops - async def get_file_content(self, path: str, ref: Optional[str] = None) -> str: + async def get_file_content(self, path: str, ref: str | None = None) -> str: """Get file content from local filesystem with error handling.""" # Use error handling if available if self._error_handling_components: @@ -164,7 +164,7 @@ async def apply_remediation( path: str, content: str, message: str, - branch: Optional[str] = None, + branch: str | None = None, ) -> RemediationResult: """Apply remediation to a file with error handling.""" # Use error handling if available @@ -198,17 +198,17 @@ async def apply_remediation( else: return await self.file_ops.apply_remediation(path, content, message) - async def file_exists(self, path: str, ref: Optional[str] = None) -> bool: + async def file_exists(self, path: str, ref: str | None = None) -> bool: """Check if a file exists.""" return await self.file_ops.file_exists(path) - async def get_file_info(self, path: str, ref: Optional[str] = None) -> FileInfo: + async def get_file_info(self, path: str, ref: str | None = None) -> FileInfo: """Get file information.""" return await self.file_ops.get_file_info(path) async def list_files( - self, path: str = "", ref: Optional[str] = None - ) -> List[FileInfo]: + self, path: str = "", ref: str | None = None + ) -> list[FileInfo]: """List files in a directory.""" return await self.file_ops.list_files(path) @@ -225,14 +225,14 @@ async def commit_changes( file_path: str, content: str, message: str, - branch: Optional[str] = None, - ) -> Optional[str]: + branch: str | None = None, + ) -> str | None: """Commit changes to a file.""" success = await self.file_ops.commit_changes(file_path, content, message) return "local_commit" if success else None # Branch operations - delegate to git_ops - async def create_branch(self, name: str, base_ref: Optional[str] = None) -> bool: + async def create_branch(self, name: str, base_ref: str | None = None) -> bool: """Create a new branch.""" return await self.git_ops.create_branch(name, base_ref) @@ -240,11 +240,11 @@ async def delete_branch(self, name: str) -> bool: """Delete a branch.""" return await self.git_ops.delete_branch(name) - async def list_branches(self) -> List[BranchInfo]: + async def list_branches(self) -> list[BranchInfo]: """List all branches.""" return await self.git_ops.list_branches() - async def get_branch_info(self, name: str) -> Optional[BranchInfo]: + async def get_branch_info(self, name: str) -> BranchInfo | None: """Get information about a specific branch.""" return await self.git_ops.get_branch_info(name) @@ -260,7 +260,7 @@ async def check_conflicts( self, path: str, content: str, - branch: Optional[str] = None, + branch: str | None = None, ) -> bool: """Check for conflicts between branches.""" # Use the default branch if no branch is specified @@ -286,7 +286,7 @@ async def resolve_conflicts( return await self.git_ops.resolve_conflicts(path, content, strategy) # Git operations - delegate to git_ops - async def get_file_history(self, path: str, limit: int = 10) -> List[CommitInfo]: + async def get_file_history(self, path: str, limit: int = 10) -> list[CommitInfo]: """Get file commit history.""" return await self.git_ops.get_file_history(path, limit) @@ -294,14 +294,14 @@ async def diff_between_commits(self, base_sha: str, head_sha: str) -> str: """Get diff between two commits.""" return await self.git_ops.diff_between_commits(base_sha, head_sha) - async def execute_git_command(self, command: List[str]) -> str: + async def execute_git_command(self, command: list[str]) -> str: """Execute a Git command and return output.""" return await self.git_ops.execute_git_command(command) # Batch operations - delegate to batch_ops async def batch_operations( - self, operations: List[BatchOperation] - ) -> List[OperationResult]: + self, operations: list[BatchOperation] + ) -> list[OperationResult]: """Execute multiple operations in batch.""" return await self.batch_ops.batch_operations(operations) diff --git a/tests/ingestion/adapters/test_aws_cloudwatch.py b/tests/ingestion/adapters/test_aws_cloudwatch.py index 1e2904b..a8a2e63 100644 --- a/tests/ingestion/adapters/test_aws_cloudwatch.py +++ b/tests/ingestion/adapters/test_aws_cloudwatch.py @@ -240,7 +240,7 @@ async def test_get_health_metrics(self, mock_boto3, adapter): assert "last_check_time" in metrics assert "region" in metrics assert "log_group_name" in metrics - + def test_get_config(self, adapter: str, config: str) -> None: """Test getting configuration.""" returned_config = adapter.get_config() diff --git a/tests/ingestion_config/simple_config_test.py b/tests/ingestion_config/simple_config_test.py index 7b3d6be..83593c9 100644 --- a/tests/ingestion_config/simple_config_test.py +++ b/tests/ingestion_config/simple_config_test.py @@ -518,4 +518,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/ingestion_config/test_config_system.py b/tests/ingestion_config/test_config_system.py index 88520d2..924d53e 100644 --- a/tests/ingestion_config/test_config_system.py +++ b/tests/ingestion_config/test_config_system.py @@ -350,4 +350,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/llm/test_factory_core.py b/tests/llm/test_factory_core.py index b11d3ca..15d3736 100644 --- a/tests/llm/test_factory_core.py +++ b/tests/llm/test_factory_core.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock, patch import pytest -from typing import Optional + # Mock the dependencies before importing the factory with patch.dict( "sys.modules", @@ -48,7 +48,7 @@ async def generate_structured(self, prompt, response_model, model=None, **kwargs return response_model() def generate_stream( - self, prompt: str, model: Optional[str] = None, **kwargs: str + self, prompt: str, model: str | None = None, **kwargs: str ) -> None: """ Generate Stream. @@ -74,7 +74,7 @@ def get_available_models(self) -> None: """ return ["mock-model"] - def estimate_cost(self, prompt: str, model: Optional[str] = None) -> None: + def estimate_cost(self, prompt: str, model: str | None = None) -> None: """ Estimate Cost. diff --git a/tests/llm/test_provider_interface.py b/tests/llm/test_provider_interface.py index 46a6a86..a1d3f55 100644 --- a/tests/llm/test_provider_interface.py +++ b/tests/llm/test_provider_interface.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock, patch from pydantic import BaseModel -from typing import Optional + # Mock the dependencies before importing the provider mock_prompt_class = MagicMock() with patch.dict( @@ -59,7 +59,7 @@ async def generate_structured(self, prompt, response_model, model=None, **kwargs return response_model(message="Test", confidence=0.95) def generate_stream( - self, prompt: str, model: Optional[str] = None, **kwargs: str + self, prompt: str, model: str | None = None, **kwargs: str ) -> None: """ Generate Stream. diff --git a/tests/llm/test_service.py b/tests/llm/test_service.py index b236724..f3df655 100644 --- a/tests/llm/test_service.py +++ b/tests/llm/test_service.py @@ -234,8 +234,7 @@ def test_missing_dependencies(self) -> None: """Test behavior when required dependencies are missing.""" with patch.dict( "sys.modules", {"instructor": None, "litellm": None, "mirascope": None} + ), pytest.raises( + ImportError, match="Required dependencies not installed" ): - with pytest.raises( - ImportError, match="Required dependencies not installed" - ): - pass + pass diff --git a/tests/source_control/test_base_sub_operation.py b/tests/source_control/test_base_sub_operation.py index fd3e93b..469ab45 100644 --- a/tests/source_control/test_base_sub_operation.py +++ b/tests/source_control/test_base_sub_operation.py @@ -26,7 +26,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from typing import Optional + from argus.source_control.providers.base_sub_operation import ( BaseSubOperation, ) @@ -41,8 +41,8 @@ class MockSubOperation(BaseSubOperation): def __init__( self, logger: str, - error_handling_components: Optional[str] = None, - config: Optional[str] = None, + error_handling_components: str | None = None, + config: str | None = None, ) -> None: super().__init__( logger=logger, diff --git a/tests/source_control/test_github_provider.py b/tests/source_control/test_github_provider.py index b3ef9de..dd5665a 100644 --- a/tests/source_control/test_github_provider.py +++ b/tests/source_control/test_github_provider.py @@ -416,25 +416,24 @@ async def test_batch_operations_success(self, github_provider): # Mock the apply_remediation method with patch.object( github_provider, "apply_remediation", return_value=MagicMock(success=True) - ): - with patch.object(github_provider, "create_branch", return_value=True): - operations = [ - BatchOperation( - operation_type="update_file", - path="file1.py", - content="content1", - message="Update file1", - ), - BatchOperation( - operation_type="create_branch", - parameters={"name": "new-branch"}, - ), - ] - - results = await github_provider.batch_operations(operations) - - assert len(results) == 2 - assert all(status == OperationStatus.SUCCESS for status in results) + ), patch.object(github_provider, "create_branch", return_value=True): + operations = [ + BatchOperation( + operation_type="update_file", + path="file1.py", + content="content1", + message="Update file1", + ), + BatchOperation( + operation_type="create_branch", + parameters={"name": "new-branch"}, + ), + ] + + results = await github_provider.batch_operations(operations) + + assert len(results) == 2 + assert all(status == OperationStatus.SUCCESS for status in results) @pytest.mark.asyncio async def test_batch_operations_with_failures(self, github_provider): diff --git a/tests/source_control/test_repository_access_security.py b/tests/source_control/test_repository_access_security.py index 67a3649..ef7e5aa 100644 --- a/tests/source_control/test_repository_access_security.py +++ b/tests/source_control/test_repository_access_security.py @@ -101,12 +101,11 @@ def test_credential_validation_for_repository_access( assert token == "valid_token" # Test that invalid credentials are rejected - with patch.dict("os.environ", {}, clear=True): - with pytest.raises( - ValueError, - match="At least one authentication method must be provided", - ): - CredentialConfig() + with patch.dict("os.environ", {}, clear=True), pytest.raises( + ValueError, + match="At least one authentication method must be provided", + ): + CredentialConfig() def test_repository_url_security_validation( self, mock_github_provider: str @@ -224,9 +223,8 @@ async def test_repository_access_error_handling(self, mock_github_provider): mock_github_provider, "test_connection", side_effect=Exception("Access denied"), - ): - with pytest.raises(Exception, match="Access denied"): - await mock_github_provider.test_connection() + ), pytest.raises(Exception, match="Access denied"): + await mock_github_provider.test_connection() def test_repository_access_timeout_handling( self, mock_github_provider: str diff --git a/tests/source_control/test_repository_permission_tests.py b/tests/source_control/test_repository_permission_tests.py index 9ab566c..a19c3f5 100644 --- a/tests/source_control/test_repository_permission_tests.py +++ b/tests/source_control/test_repository_permission_tests.py @@ -545,12 +545,11 @@ def test_credential_validation_for_permissions( assert len(token) > 0 # Test invalid credentials - with patch.dict("os.environ", {}, clear=True): - with pytest.raises( - ValueError, - match="At least one authentication method must be provided", - ): - CredentialConfig() + with patch.dict("os.environ", {}, clear=True), pytest.raises( + ValueError, + match="At least one authentication method must be provided", + ): + CredentialConfig() # Test expired credentials with patch.dict("os.environ", {"GITHUB_TOKEN": "expired_token"}): diff --git a/tests/source_control/test_setup.py b/tests/source_control/test_setup.py index 383b80f..cf377d6 100644 --- a/tests/source_control/test_setup.py +++ b/tests/source_control/test_setup.py @@ -107,25 +107,24 @@ async def test_setup_credential_manager_configuration(self, mock_config): @pytest.mark.asyncio async def test_setup_provider_registration(self, mock_config): """Test that providers are properly registered.""" - with patch("argus.source_control.setup.CredentialManager"): - with patch( - "argus.source_control.setup.ProviderFactory" - ) as mock_factory_class: - mock_factory = MagicMock() - mock_factory_class.return_value = mock_factory + with patch("argus.source_control.setup.CredentialManager"), patch( + "argus.source_control.setup.ProviderFactory" + ) as mock_factory_class: + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory - with patch( - "argus.source_control.setup.RepositoryManager" - ) as mock_repo_manager_class: - mock_repo_manager = AsyncMock() - mock_repo_manager_class.return_value = mock_repo_manager - await setup_repository_system(mock_config) - - # Verify providers were registered - assert mock_factory.register_provider.call_count == 1 - mock_factory.register_provider.assert_any_call( - "github", mock_factory.register_provider.call_args_list[0][0][1] - ) + with patch( + "argus.source_control.setup.RepositoryManager" + ) as mock_repo_manager_class: + mock_repo_manager = AsyncMock() + mock_repo_manager_class.return_value = mock_repo_manager + await setup_repository_system(mock_config) + + # Verify providers were registered + assert mock_factory.register_provider.call_count == 1 + mock_factory.register_provider.assert_any_call( + "github", mock_factory.register_provider.call_args_list[0][0][1] + ) class TestCreateDefaultConfig: diff --git a/tests/templates/test_core_template.py b/tests/templates/test_core_template.py index 116401b..b4e8ddd 100644 --- a/tests/templates/test_core_template.py +++ b/tests/templates/test_core_template.py @@ -14,8 +14,8 @@ """Core module tests.""" + from argus.core.exceptions import * -from typing import Optional from argus.core.interfaces import * from argus.core.types import * diff --git a/tests/test_base_prompt_template.py b/tests/test_base_prompt_template.py index 39f7067..f9911ef 100644 --- a/tests/test_base_prompt_template.py +++ b/tests/test_base_prompt_template.py @@ -164,9 +164,8 @@ def test_generate_prompt_error_handling(self) -> None: # Mock _get_context_variables to raise an exception with patch.object( template, "_get_context_variables", side_effect=Exception("Test error") - ): - with pytest.raises(ValueError, match="Failed to generate prompt"): - template.generate_prompt(Mock()) + ), pytest.raises(ValueError, match="Failed to generate prompt"): + template.generate_prompt(Mock()) def test_validate_context_success(self) -> None: """Test successful context validation.""" diff --git a/tests/test_capability_integration.py b/tests/test_capability_integration.py index e3274a1..05bf8e8 100644 --- a/tests/test_capability_integration.py +++ b/tests/test_capability_integration.py @@ -19,7 +19,7 @@ configuration loading, capability discovery, caching, and task-based selection. """ -from typing import Any, Dict +from typing import Any from unittest.mock import Mock, patch import pytest @@ -33,7 +33,7 @@ class MockLLMProvider(LLMProvider): """Mock LLM provider for testing.""" - def __init__(self, name: str, models: Dict[str, Any]) -> None: + def __init__(self, name: str, models: dict[str, Any]) -> None: self.name = name self.models = models self.config = Mock() diff --git a/tests/test_enhanced_analysis_agent.py b/tests/test_enhanced_analysis_agent.py index a0a1577..e2f2aca 100644 --- a/tests/test_enhanced_analysis_agent.py +++ b/tests/test_enhanced_analysis_agent.py @@ -85,31 +85,27 @@ def config(self) -> None: @pytest.fixture def agent(self, config: str) -> None: """Create test agent with mocked dependencies.""" - with patch("argus.ml.enhanced_analysis_agent.GenerativeModel"): - with patch( - "argus.ml.enhanced_analysis_agent.AdaptivePromptStrategy" - ): - with patch( - "argus.ml.enhanced_analysis_agent.MetaPromptGenerator" - ): - return EnhancedAnalysisAgent(config) + with patch("argus.ml.enhanced_analysis_agent.GenerativeModel"), patch( + "argus.ml.enhanced_analysis_agent.AdaptivePromptStrategy" + ), patch( + "argus.ml.enhanced_analysis_agent.MetaPromptGenerator" + ): + return EnhancedAnalysisAgent(config) def test_agent_initialization(self, config: str) -> None: """Test agent initialization.""" - with patch("argus.ml.enhanced_analysis_agent.GenerativeModel"): - with patch( - "argus.ml.enhanced_analysis_agent.AdaptivePromptStrategy" - ): - with patch( - "argus.ml.enhanced_analysis_agent.MetaPromptGenerator" - ): - agent = EnhancedAnalysisAgent(config) - - assert agent.config == config - assert agent.main_model is not None - assert agent.meta_model is not None - assert agent.adaptive_strategy is not None - assert agent.meta_prompt_generator is not None + with patch("argus.ml.enhanced_analysis_agent.GenerativeModel"), patch( + "argus.ml.enhanced_analysis_agent.AdaptivePromptStrategy" + ), patch( + "argus.ml.enhanced_analysis_agent.MetaPromptGenerator" + ): + agent = EnhancedAnalysisAgent(config) + + assert agent.config == config + assert agent.main_model is not None + assert agent.meta_model is not None + assert agent.adaptive_strategy is not None + assert agent.meta_prompt_generator is not None def test_classify_issue_type_database(self, agent: str) -> None: """Test issue type classification for database errors."""