diff --git a/.gitignore b/.gitignore index a132e287..abb66c31 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ *.swp *.bak real-values.yaml - +.cursor # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 00000000..3101a8f2 --- /dev/null +++ b/AGENT.md @@ -0,0 +1,64 @@ +# Agent Constraints + +**Read this file before making any changes.** These constraints apply to all AI-assisted work in this repository. This file is the permanent **constraints** layer; task-specific **Goal / Constraints / Output Format / Failure Conditions** belong in copies of the templates under [docs/specs/templates/](docs/specs/templates/). + +## Project Identity + +This repository is the **chutes.ai** platform **API and validator** code: HTTP API, related services, Docker/deployment assets, and integration with the broader Chutes ecosystem (see [README.md](README.md)). + +## Stack (Non-Negotiable) + +- **Language**: Python `>=3.10,<3.13` (3.12 is typical for local work) +- **Package manager / runner**: **[uv](https://github.com/astral-sh/uv)** — install with `uv sync`; run tools with `uv run …` (e.g. `uv run pytest`) +- **Build**: **hatchling** ([pyproject.toml](pyproject.toml)); installable packages: **`api`**, **`metasync`** +- **HTTP API**: **FastAPI** + **Uvicorn**, default **ORJSONResponse** ([api/main.py](api/main.py)) +- **Data**: **SQLAlchemy 2.x** + **asyncpg**, **Pydantic** / **pydantic-settings**; SQL migrations under [api/migrations/](api/migrations/) (timestamped `.sql`) +- **Ops / deps** (non-exhaustive): **Redis**, **loguru**, **httpx** / **aiohttp**, **Socket.IO** client, **aioboto3**, domain packages (`chutes`, Bittensor-related libs, attestation tooling as pulled in by pyproject) + +Do not replace this stack with alternate frameworks or ORMs unless explicitly agreed. Do not introduce extra dependencies without discussion. + +## Hard Rules + +- **Never add a new dependency** without explicit approval +- **Configuration**: use **`api.config.settings`** (pydantic-settings) and environment variables — **no hardcoded secrets**, connection strings, or API keys +- **Database schema changes**: add a new file under [api/migrations/](api/migrations/) **and** keep ORM models in [api/database/orms.py](api/database/orms.py) in sync; describe the migration plan in PRs/specs +- **Lint/format**: **Ruff** only — `make lint` and `make reformat` ([makefiles/lint.mk](makefiles/lint.mk)); there is **no enforced coverage percentage** in CI — still add or update tests when behavior changes +- **`nv-attest/`**: excluded from Ruff in pyproject — do not “fix” it via repo-wide lint refactors unless scoped to that subtree +- **Crypto / attestation-sensitive code**: follow existing patterns in the relevant modules (e.g. server/attestation paths); never hardcode keys or measurements + +## Patterns + +- **Async-first** in request paths: **`async def`** handlers, **async** SQLAlchemy sessions (**`get_session`** and related helpers in [api/database/](api/database/)); avoid blocking I/O in handlers +- **Domain layout**: routes in **`api/<domain>/router.py`**, shared logic often in **`api/<domain>/util.py`** (match neighboring domains) +- **Models and settings**: ORM in **`api/database/orms.py`**; app settings via **`api.config`** +- **Tests**: under **`tests/unit/`** and **`tests/integration/`**; use **`uv run pytest`**.
**Match test style** to the file you edit (this repo uses both plain **`def test_*`** and **`class Test*`** — stay consistent with surrounding tests) +- **Naming and structure**: follow existing modules in the same package; prefer small, focused changes over large unsolicited refactors + +## Architecture Overview + + +| Area | Purpose | | --- | --- | | **[api/main.py](api/main.py)** | Main FastAPI app: lifespan, mounts domain routers (users, chutes, instances, invocations, payments, miner, …) | | **`api/<domain>/`** | Feature modules: `router.py`, `util.py`, and related schemas/helpers per domain | | **[api/database/](api/database/)** | Async engine/session helpers; **[api/database/orms.py](api/database/orms.py)** ORM models | | **[api/migrations/](api/migrations/)** | Ordered SQL migrations consumed at startup (see `lifespan` / `tasks.py`) | | **[metasync/](metasync/)** | Metagraph / sync utilities (separate package in the same repo) | | **`api/payment/`**, **`api/socket_server.py`**, etc. | Auxiliary services or entrypoints alongside the main app — follow local patterns when touching them | | **[nv-attest/](nv-attest/)** | Attestation-related subtree (own tooling; Ruff-excluded at repo root) | | **`docker/`**, **[dev/dev.md](dev/dev.md)** | Local Docker Compose and dev bootstrap | + + +## Development Commands + +```bash +uv sync --extra dev # Install project + dev dependencies (from repo root) +uv run pytest # Run tests (add paths or -k as needed) +make lint # Ruff check + format check +make reformat # Ruff format (line length per makefile) +make infrastructure # Docker compose up for test infra (see makefiles/development.mk) +``` + +Local full stack: see **[dev/dev.md](dev/dev.md)** (Docker network, `docker compose`, optional GPU compose files). + +**Session handshake**: Before starting substantive work, confirm you have read this file and will follow it; for non-trivial features, bugfixes, or refactors, consider filling a copy of the appropriate template under [docs/specs/templates/](docs/specs/templates/).
\ No newline at end of file diff --git a/api/instance/util.py b/api/instance/util.py index 333ed546..ff40edac 100644 --- a/api/instance/util.py +++ b/api/instance/util.py @@ -31,7 +31,7 @@ from api.config import settings from api.job.schemas import Job from api.database import get_session -from api.util import has_legacy_private_billing, notify_deleted, semcomp +from api.util import has_legacy_private_billing, notify_deleted, notify_job_deleted, semcomp from api.user.service import chutes_user_id from api.bounty.util import create_bounty_if_not_exists, get_bounty_amount, send_bounty_notification from sqlalchemy.future import select @@ -43,7 +43,7 @@ from api.server.client import TeeServerClient from api.server.schemas import Server from api.server.exceptions import GetEvidenceError -from api.server.service import verify_quote, verify_gpu_evidence +from api.server.util import verify_quote, verify_gpu_evidence from api.server.util import get_public_key_hash # Define an alias for the Instance model to use in a subquery @@ -1191,3 +1191,48 @@ async def is_instance_in_thrash_penalty( instance_created_at = instance_created_at.replace(tzinfo=None) return await is_thrashing_miner(db, miner_hotkey, chute_id, instance_created_at) + + +async def purge(target, reason, valid_termination=False): + """Delete an instance from the database and clean up associated state.""" + async with get_session() as session: + await session.execute( + text("DELETE FROM instances WHERE instance_id = :instance_id"), + {"instance_id": target.instance_id}, + ) + await session.execute( + text( + "UPDATE instance_audit SET deletion_reason = :reason, valid_termination = :valid_termination WHERE instance_id = :instance_id" + ), + { + "instance_id": target.instance_id, + "reason": reason, + "valid_termination": valid_termination, + }, + ) + + job = ( + (await session.execute(select(Job).where(Job.instance_id == target.instance_id))) + .unique() + .scalar_one_or_none() + ) + if job and not job.finished_at: + job.status = "error" + job.error_detail = f"Instance failed monitoring probes: {reason=}" + job.miner_terminated = True + job.finished_at = func.now() + await notify_job_deleted(job) + + await session.commit() + + await cleanup_instance_conn_tracking(target.chute_id, target.instance_id) + + +async def purge_and_notify(target, reason, valid_termination=False): + """Purge an instance and broadcast a deletion notification.""" + await purge(target, reason=reason, valid_termination=valid_termination) + await notify_deleted( + target, + message=f"Instance {target.instance_id} of miner {target.miner_hotkey} deleted: {reason}", + ) + await invalidate_instance_cache(target.chute_id, instance_id=target.instance_id) diff --git a/api/migrations/20260403120000_server_maintenance.sql b/api/migrations/20260403120000_server_maintenance.sql new file mode 100644 index 00000000..364f3919 --- /dev/null +++ b/api/migrations/20260403120000_server_maintenance.sql @@ -0,0 +1,49 @@ +-- migrate:up + +CREATE TABLE IF NOT EXISTS tee_upgrade_windows ( + id VARCHAR PRIMARY KEY, + upgrade_window_start TIMESTAMPTZ NOT NULL, + upgrade_window_end TIMESTAMPTZ NOT NULL, + target_measurement_version TEXT NOT NULL, + max_concurrent_per_miner INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT uq_tee_upgrade_target UNIQUE (target_measurement_version), + CONSTRAINT chk_window_bounds CHECK (upgrade_window_end > upgrade_window_start) +); + +CREATE INDEX IF NOT EXISTS idx_tee_upgrade_window_bounds + ON tee_upgrade_windows 
(upgrade_window_start, upgrade_window_end); + +ALTER TABLE servers + ADD COLUMN IF NOT EXISTS maintenance_pending_window_id VARCHAR + REFERENCES tee_upgrade_windows(id) ON DELETE SET NULL; + +ALTER TABLE servers + ADD COLUMN IF NOT EXISTS version TEXT; + +CREATE INDEX IF NOT EXISTS idx_servers_maintenance_pending + ON servers (miner_hotkey) + WHERE maintenance_pending_window_id IS NOT NULL; + +ALTER TABLE boot_attestations + ADD COLUMN IF NOT EXISTS miner_hotkey VARCHAR; + +ALTER TABLE boot_attestations + ADD COLUMN IF NOT EXISTS vm_name VARCHAR; + +CREATE INDEX IF NOT EXISTS idx_boot_miner_vm + ON boot_attestations (miner_hotkey, vm_name); + +-- migrate:down + +DROP INDEX IF EXISTS idx_boot_miner_vm; + +ALTER TABLE boot_attestations DROP COLUMN IF EXISTS vm_name; +ALTER TABLE boot_attestations DROP COLUMN IF EXISTS miner_hotkey; + +DROP INDEX IF EXISTS idx_servers_maintenance_pending; + +ALTER TABLE servers DROP COLUMN IF EXISTS version; +ALTER TABLE servers DROP COLUMN IF EXISTS maintenance_pending_window_id; + +DROP TABLE IF EXISTS tee_upgrade_windows; diff --git a/api/miner/schemas.py b/api/miner/schemas.py index dc98478d..81037f9d 100644 --- a/api/miner/schemas.py +++ b/api/miner/schemas.py @@ -24,6 +24,8 @@ class MinerServer(BaseModel): name: str ip: str is_tee: bool + version: str | None = None + maintenance_pending: bool = False created_at: str | None = None updated_at: str | None = None gpus: list[MinerServerGpu] = Field(default_factory=list) @@ -38,6 +40,8 @@ def from_server(cls, server: Server) -> "MinerServer": name=server.name, ip=server.ip, is_tee=server.is_tee, + version=server.version, + maintenance_pending=server.maintenance_pending_window_id is not None, created_at=server.created_at.isoformat() if server.created_at else None, updated_at=server.updated_at.isoformat() if server.updated_at else None, gpus=[ diff --git a/api/server/router.py b/api/server/router.py index ceba57b1..f961be7b 100644 --- a/api/server/router.py +++ b/api/server/router.py @@ -25,6 +25,12 @@ BootAttestationResponse, RuntimeAttestationResponse, LuksPassphraseRequest, + PreflightResult, + ConfirmMaintenanceResult, + MaintenancePolicyResponse, + PendingServerInfo, + TeeUpgradeWindow, + UpgradeWindowInfo, ) from api.server.service import ( create_nonce, @@ -38,6 +44,10 @@ delete_server, validate_request_nonce, process_luks_passphrase_request, + get_active_upgrade_window, + preflight_maintenance, + confirm_maintenance, + _count_active_maintenance_slots, ) from api.server.util import ( decrypt_passphrase, @@ -332,6 +342,55 @@ async def create_server( ) +@router.get("/maintenance/policy", response_model=MaintenancePolicyResponse) +async def get_maintenance_policy( + db: AsyncSession = Depends(get_db_session), + hotkey: str | None = Header(None, alias=HOTKEY_HEADER), + _: User = Depends( + get_current_user(purpose="tee", raise_not_found=False, registered_to=settings.netuid) + ), +): + """Return the active upgrade window, concurrency limits, and the miner's pending servers.""" + if not hotkey: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Hotkey header required") + + active_window = await get_active_upgrade_window(db) + window_info: UpgradeWindowInfo | None = None + current_slots = 0 + pending_servers: list[PendingServerInfo] = [] + + if active_window is not None: + window_info = UpgradeWindowInfo( + id=active_window.id, + target_measurement_version=active_window.target_measurement_version, + upgrade_window_start=str(active_window.upgrade_window_start), + 
upgrade_window_end=str(active_window.upgrade_window_end), + max_concurrent_per_miner=active_window.max_concurrent_per_miner, + ) + current_slots = await _count_active_maintenance_slots(db, hotkey, active_window) + + query = select(Server).where( + Server.miner_hotkey == hotkey, + Server.maintenance_pending_window_id.isnot(None), + ) + result = await db.execute(query) + for srv in result.scalars().all(): + pending_servers.append( + PendingServerInfo( + server_id=srv.server_id, + name=srv.name, + version=srv.version, + target_version=active_window.target_measurement_version, + ) + ) + + return MaintenancePolicyResponse( + active_window=window_info, + current_slots=current_slots, + pending_servers=pending_servers, + ) + + @router.patch("/{server_id}", response_model=Dict[str, Any]) async def patch_server_name( server_id: str, @@ -386,13 +445,21 @@ async def get_server_details( try: server = await check_server_ownership(db, server_id, hotkey) - return { + response: dict = { "server_id": server.server_id, "name": server.name, "ip": server.ip, + "version": server.version, + "maintenance_pending_window_id": server.maintenance_pending_window_id, "created_at": server.created_at.isoformat(), "updated_at": server.updated_at.isoformat() if server.updated_at else None, } + if server.maintenance_pending_window_id is not None: + window = await db.get(TeeUpgradeWindow, server.maintenance_pending_window_id) + if window is not None: + response["target_version"] = window.target_measurement_version + + return response except ServerNotFoundError as e: raise e @@ -405,6 +472,34 @@ async def get_server_details( ) +@router.get("/{server_name_or_id}/maintenance/preflight", response_model=PreflightResult) +async def get_maintenance_preflight( + server_name_or_id: str, + db: AsyncSession = Depends(get_db_session), + hotkey: str | None = Header(None, alias=HOTKEY_HEADER), + _: User = Depends( + get_current_user(purpose="tee", raise_not_found=False, registered_to=settings.netuid) + ), +): + """Check maintenance eligibility for a server without entering maintenance.""" + server = await get_server_by_name_or_id(db, hotkey, server_name_or_id) + return await preflight_maintenance(db, server, hotkey) + + +@router.put("/{server_name_or_id}/maintenance", response_model=ConfirmMaintenanceResult) +async def put_confirm_maintenance( + server_name_or_id: str, + db: AsyncSession = Depends(get_db_session), + hotkey: str | None = Header(None, alias=HOTKEY_HEADER), + _: User = Depends( + get_current_user(purpose="tee", raise_not_found=False, registered_to=settings.netuid) + ), +): + """Enter maintenance: purge instances and mark server for upgrade.""" + server = await get_server_by_name_or_id(db, hotkey, server_name_or_id) + return await confirm_maintenance(db, server, hotkey) + + @router.delete("/{server_name_or_id}", response_model=Dict[str, str]) async def remove_server( server_name_or_id: str, diff --git a/api/server/schemas.py b/api/server/schemas.py index ff47d684..5dfe7cfe 100644 --- a/api/server/schemas.py +++ b/api/server/schemas.py @@ -11,10 +11,12 @@ String, DateTime, Boolean, + CheckConstraint, ForeignKey, Text, Index, ForeignKeyConstraint, + UniqueConstraint, ) from sqlalchemy.dialects.postgresql import JSONB from typing import Dict, Any, List, Optional @@ -117,6 +119,70 @@ class TeeChuteEvidence(BaseModel): ) +class MaintenanceReason(BaseModel): + """A single reason why maintenance eligibility was denied.""" + + reason: str + current_version: Optional[str] = None + target_version: Optional[str] = None + window_id: 
Optional[str] = None + current_slots: Optional[int] = None + limit: Optional[int] = None + blocking: Optional[List[dict]] = None + + +class SoleSurvivorBlock(BaseModel): + """An instance that is the sole active instance for its chute.""" + + chute_id: str + instance_id: str + + +class PreflightResult(BaseModel): + """Result of a maintenance preflight eligibility check.""" + + eligible: bool + denial_reasons: List[MaintenanceReason] = Field(default_factory=list) + blocking_chute_ids: List[SoleSurvivorBlock] = Field(default_factory=list) + current_slots: int = 0 + limit: int = 1 + + +class UpgradeWindowInfo(BaseModel): + """Summary of an upgrade window for API responses.""" + + id: str + target_measurement_version: str + upgrade_window_start: str + upgrade_window_end: str + max_concurrent_per_miner: int = 1 + + +class ConfirmMaintenanceResult(BaseModel): + """Result of confirming maintenance on a server.""" + + server_id: str + purged_instance_ids: List[str] = Field(default_factory=list) + window: UpgradeWindowInfo + + +class PendingServerInfo(BaseModel): + """A server with pending maintenance, shown in the policy response.""" + + server_id: str + name: Optional[str] = None + version: Optional[str] = None + target_version: str + + +class MaintenancePolicyResponse(BaseModel): + """Response for GET /servers/maintenance/policy.""" + + active_window: Optional[UpgradeWindowInfo] = None + current_slots: int = 0 + pending_servers: List[PendingServerInfo] = Field(default_factory=list) + + class BootAttestation(Base): """Track anonymous boot attestations (pre-registration).""" @@ -125,6 +191,8 @@ class BootAttestation(Base): attestation_id = Column(String, primary_key=True, default=generate_uuid) quote_data = Column(Text, nullable=False) # Base64 encoded quote server_ip = Column(String, nullable=True) # For later linking to server + miner_hotkey = Column(String, nullable=True) + vm_name = Column(String, nullable=True) verification_error = Column(String, nullable=True) measurement_version = Column( String, nullable=True @@ -136,6 +204,32 @@ class BootAttestation(Base): Index("idx_boot_server_id", "server_ip"), Index("idx_boot_created", "created_at"), Index("idx_boot_verified", "verified_at"), + Index("idx_boot_miner_vm", "miner_hotkey", "vm_name"), + ) + + +class TeeUpgradeWindow(Base): + """Validator-managed maintenance window: one row per coordinated TEE image cutover.""" + + __tablename__ = "tee_upgrade_windows" + + id = Column(String, primary_key=True, default=generate_uuid) + upgrade_window_start = Column(DateTime(timezone=True), nullable=False) + upgrade_window_end = Column(DateTime(timezone=True), nullable=False) + target_measurement_version = Column(Text, nullable=False) + max_concurrent_per_miner = Column(Integer, nullable=False, default=1, server_default="1") + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + + pending_servers = relationship( + "Server", + back_populates="pending_upgrade_window", + foreign_keys="Server.maintenance_pending_window_id", + ) + + __table_args__ = ( + UniqueConstraint("target_measurement_version", name="uq_tee_upgrade_target"), + CheckConstraint("upgrade_window_end > upgrade_window_start", name="chk_window_bounds"), + Index("idx_tee_upgrade_window_bounds", "upgrade_window_start", "upgrade_window_end"), ) @@ -156,16 +250,35 @@ class Server(Base): is_tee = Column(Boolean, default=False, server_default="false") + # Maintenance: set at confirm, cleared on successful boot completion or lazily when window closes. 
+ maintenance_pending_window_id = Column( + String, + ForeignKey("tee_upgrade_windows.id", ondelete="SET NULL"), + nullable=True, + ) + # Current attested measurement version, updated on every successful boot attestation. + version = Column(Text, nullable=True) + # Relationships nodes = relationship("Node", back_populates="server", cascade="all, delete-orphan") runtime_attestations = relationship( "ServerAttestation", back_populates="server", cascade="all, delete-orphan" ) miner = relationship("MetagraphNode", back_populates="servers") + pending_upgrade_window = relationship( + "TeeUpgradeWindow", + back_populates="pending_servers", + foreign_keys=[maintenance_pending_window_id], + ) __table_args__ = ( Index("idx_server_miner", "miner_hotkey"), Index("idx_servers_miner_name", "miner_hotkey", "name", unique=True), + Index( + "idx_servers_maintenance_pending", + "miner_hotkey", + postgresql_where=maintenance_pending_window_id.isnot(None), + ), ForeignKeyConstraint( ["netuid", "miner_hotkey"], ["metagraph_nodes.netuid", "metagraph_nodes.hotkey"] ), diff --git a/api/server/service.py b/api/server/service.py index ed13873f..9fc7302a 100644 --- a/api/server/service.py +++ b/api/server/service.py @@ -2,11 +2,9 @@ Core server management and TDX attestation logic. """ -import asyncio import pybase64 as base64 from datetime import datetime, timezone, timedelta import json -import tempfile from typing import Dict, Any, Optional from fastapi import HTTPException, Header, Request, status from loguru import logger @@ -19,7 +17,7 @@ from api.gpu import SUPPORTED_GPUS from api.node.util import _track_nodes from api.server.client import TeeServerClient -from api.server.quote import BootTdxQuote, RuntimeTdxQuote, TdxQuote, TdxVerificationResult +from api.server.quote import BootTdxQuote, RuntimeTdxQuote, TdxQuote from api.server.schemas import ( Server, ServerAttestation, @@ -27,12 +25,17 @@ BootAttestationArgs, RuntimeAttestationArgs, ServerArgs, + TeeUpgradeWindow, + MaintenanceReason, + SoleSurvivorBlock, + PreflightResult, + UpgradeWindowInfo, + ConfirmMaintenanceResult, ) from api.server.exceptions import ( AttestationError, GetEvidenceError, GpuEvidenceError, - InvalidClientCertError, InvalidGpuEvidenceError, InvalidQuoteError, MeasurementMismatchError, @@ -44,19 +47,18 @@ ) from api.server.util import ( _track_server, - extract_report_data, - verify_measurements, get_matching_measurement_config, generate_nonce, get_nonce_expiry_seconds, - verify_quote_signature, - verify_result, + verify_quote, + verify_gpu_evidence, sync_server_luks_passphrases, get_public_key_hash, cert_to_base64_der, validate_user_nonce, ) -from api.instance.schemas import Instance +from api.instance.schemas import Instance, instance_nodes +from api.instance.util import purge_and_notify from api.chute.schemas import Chute from api.node.schemas import Node from sqlalchemy.orm import joinedload @@ -179,29 +181,6 @@ async def _validate_request_nonce( return _validate_request_nonce -async def verify_quote( - quote: TdxQuote, expected_nonce: str, expected_cert_hash: str -) -> TdxVerificationResult: - # Validate nonce - nonce, cert_hash = extract_report_data(quote) - - if nonce != expected_nonce: - logger.info(f"Nonce error: {nonce} =/= {expected_nonce}") - raise NonceError("Quote nonce does not match expected nonce.") - - if cert_hash != expected_cert_hash: - raise InvalidClientCertError() - - # Verify the quote using DCAP - result = await verify_quote_signature(quote) - # Verify the quote against the result to ensure it was parsed 
properly - verify_result(quote, result) - # Verify the quote against configured MRTD/RMTRs - verify_measurements(quote) - - return result - - def validate_gpus_for_measurements(quote: TdxQuote, gpus: list[NodeArgs]) -> None: """ Validate that the provided GPUs match the expected GPUs for this measurement configuration. @@ -243,31 +222,6 @@ def validate_gpus_for_measurements(quote: TdxQuote, gpus: list[NodeArgs]) -> Non ) -async def verify_gpu_evidence(evidence: list[Dict[str, str]], expected_nonce: str) -> None: - try: - with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as fp: - json.dump(evidence, fp) - fp.flush() - - verify_gpus_cmd = ["chutes-nvattest", "--nonce", expected_nonce, "--evidence", fp.name] - - process = await asyncio.create_subprocess_exec(*verify_gpus_cmd) - - await asyncio.gather(process.wait()) - - if process.returncode != 0: - raise InvalidGpuEvidenceError() - - logger.info("GPU evidence verified successfully.") - - except FileNotFoundError as e: - logger.error(f"Failed to verify GPU evidence. chutes-nvattest command not found?:\n{e}") - raise GpuEvidenceError("Failed to verify GPU evidence.") - except Exception as e: - logger.error(f"Unexepected exception encoutnered verifying GPU evidence:\n{e}") - raise GpuEvidenceError("Encountered an unexpected exception verifying GPU evidence.") - - async def generate_and_store_boot_token(miner_hotkey: str, vm_name: str) -> str: """ Generate and store a boot token for a verified VM. @@ -328,6 +282,8 @@ async def process_boot_attestation( boot_attestation = BootAttestation( quote_data=args.quote, server_ip=server_ip, + miner_hotkey=args.miner_hotkey, + vm_name=args.vm_name, measurement_version=measurement_config.version, created_at=func.now(), verified_at=func.now(), @@ -339,6 +295,10 @@ async def process_boot_attestation( logger.success(f"Boot attestation successful: {boot_attestation.attestation_id}") + await _handle_boot_version_update( + db, args.miner_hotkey, args.vm_name, measurement_config.version + ) + # Generate boot token for this verified VM boot_token = await generate_and_store_boot_token(args.miner_hotkey, args.vm_name) @@ -360,6 +320,8 @@ async def process_boot_attestation( boot_attestation = BootAttestation( quote_data=args.quote, server_ip=server_ip, + miner_hotkey=args.miner_hotkey, + vm_name=args.vm_name, verification_error=str(e.detail), measurement_version=measurement_version, created_at=func.now(), @@ -372,6 +334,43 @@ async def process_boot_attestation( raise +async def _handle_boot_version_update( + db: AsyncSession, miner_hotkey: str, vm_name: str, measurement_version: str +) -> None: + """Update server.version on every successful boot; clear maintenance slot if target met.""" + try: + server = await get_server_by_name(db, miner_hotkey, vm_name) + except ServerNotFoundError: + return + + server.version = measurement_version + + if server.maintenance_pending_window_id is not None: + window = await db.get(TeeUpgradeWindow, server.maintenance_pending_window_id) + if ( + window is not None + and semcomp(measurement_version, window.target_measurement_version) >= 0 + ): + logger.info( + f"Maintenance complete for server {server.server_id}: " + f"version {measurement_version} meets target {window.target_measurement_version}" + ) + server.maintenance_pending_window_id = None + elif window is not None: + logger.warning( + f"Boot attestation for server {server.server_id} has version {measurement_version} " + f"but target is {window.target_measurement_version}; maintenance not complete" + ) + else: + 
logger.warning( + f"Server {server.server_id} has stale maintenance_pending_window_id " + f"pointing to missing window; clearing" + ) + server.maintenance_pending_window_id = None + + await db.commit() + + async def register_server(db: AsyncSession, args: ServerArgs, miner_hotkey: str): """ Register a TEE server: create Server, verify attestation (creating a ServerAttestation @@ -390,7 +389,11 @@ async def register_server(db: AsyncSession, args: ServerArgs, miner_hotkey: str) setattr(gpu, key, gpu_info.get(key)) # Start verification process (pass GPUs for validation) - await verify_server(db, server, miner_hotkey, gpus=args.gpus) + measurement_version = await verify_server(db, server, miner_hotkey, gpus=args.gpus) + + if measurement_version is not None: + server.version = measurement_version + await db.commit() # Track nodes once verified await _track_nodes(db, miner_hotkey, server.server_id, args.gpus, "0", func.now()) @@ -433,15 +436,11 @@ async def register_server(db: AsyncSession, args: ServerArgs, miner_hotkey: str) async def verify_server( db: AsyncSession, server: Server, miner_hotkey: str, gpus: list[NodeArgs] -) -> None: +) -> Optional[str]: """ Verify server attestation and validate GPUs match measurement configuration. - Args: - db: Database session - server: Server to verify - miner_hotkey: Miner hotkey - gpus: List of GPUs to validate against measurement configuration + Returns the measurement_version string on success, None on failure. """ failure_reason = "" quote = None @@ -484,6 +483,8 @@ async def verify_server( await db.commit() await db.refresh(server_attestation) + return measurement_config.version + except GetEvidenceError as e: failure_reason = "Failed to get attestation evidence." logger.error( @@ -1028,3 +1029,216 @@ async def get_chute_instances_evidence( failed_instance_ids.append(instance.instance_id) return (evidence_list, failed_instance_ids) + + +# --------------------------------------------------------------------------- +# TEE Maintenance Window +# --------------------------------------------------------------------------- + + +async def get_active_upgrade_window( + db: AsyncSession, +) -> Optional[TeeUpgradeWindow]: + """Return the active tee_upgrade_windows row (start <= now <= end), or None.""" + now = func.now() + query = ( + select(TeeUpgradeWindow) + .where( + TeeUpgradeWindow.upgrade_window_start <= now, + TeeUpgradeWindow.upgrade_window_end >= now, + ) + .order_by(TeeUpgradeWindow.created_at.desc()) + ) + result = await db.execute(query) + rows = result.scalars().all() + if not rows: + return None + if len(rows) > 1: + logger.warning( + f"Multiple overlapping tee_upgrade_windows rows active ({len(rows)}); " + f"using most recently created: {rows[0].id}" + ) + return rows[0] + + +async def _get_instances_on_server(db: AsyncSession, server_id: str) -> list[Instance]: + """Return all instances hosted on a server via Instance → instance_nodes → Node → Server.""" + query = ( + select(Instance) + .join(instance_nodes, Instance.instance_id == instance_nodes.c.instance_id) + .join(Node, instance_nodes.c.node_id == Node.uuid) + .where(Node.server_id == server_id) + .distinct() + ) + result = await db.execute(query) + return list(result.scalars().all()) + + +async def _find_sole_survivor_chutes( + db: AsyncSession, instances: list[Instance] +) -> list[SoleSurvivorBlock]: + """For each instance, check if it is the only active instance globally for its chute. + + Returns a list of SoleSurvivorBlock for blocking sole survivors. 
+ """ + blocking: list[SoleSurvivorBlock] = [] + seen_chutes: set[str] = set() + for inst in instances: + if inst.chute_id in seen_chutes: + continue + seen_chutes.add(inst.chute_id) + count_query = ( + select(func.count()) + .select_from(Instance) + .where( + Instance.chute_id == inst.chute_id, + Instance.active.is_(True), + Instance.instance_id != inst.instance_id, + ) + ) + result = await db.execute(count_query) + other_active = result.scalar() or 0 + if other_active == 0: + blocking.append(SoleSurvivorBlock(chute_id=inst.chute_id, instance_id=inst.instance_id)) + return blocking + + +async def _count_active_maintenance_slots( + db: AsyncSession, miner_hotkey: str, active_window: TeeUpgradeWindow +) -> int: + """Count servers for this miner that are in maintenance for the given active window.""" + query = ( + select(func.count()) + .select_from(Server) + .where( + Server.miner_hotkey == miner_hotkey, + Server.maintenance_pending_window_id == active_window.id, + ) + ) + result = await db.execute(query) + return result.scalar() or 0 + + +async def preflight_maintenance( + db: AsyncSession, server: Server, miner_hotkey: str +) -> PreflightResult: + """Read-only eligibility check for entering maintenance on a server.""" + denial_reasons: list[MaintenanceReason] = [] + blocking: list[SoleSurvivorBlock] = [] + limit = 1 + current_slots = 0 + active_window: Optional[TeeUpgradeWindow] = None + + if not server.is_tee: + denial_reasons.append(MaintenanceReason(reason="not_tee")) + + if not denial_reasons: + active_window = await get_active_upgrade_window(db) + if active_window is None: + denial_reasons.append(MaintenanceReason(reason="no_active_window")) + + if active_window is not None: + limit = active_window.max_concurrent_per_miner + + if ( + server.version is not None + and semcomp(server.version, active_window.target_measurement_version) >= 0 + ): + denial_reasons.append( + MaintenanceReason( + reason="already_at_target", + current_version=server.version, + target_version=active_window.target_measurement_version, + ) + ) + + if server.maintenance_pending_window_id is not None: + if server.maintenance_pending_window_id == active_window.id: + denial_reasons.append( + MaintenanceReason( + reason="maintenance_pending", + current_version=server.version, + target_version=active_window.target_measurement_version, + window_id=active_window.id, + ) + ) + else: + server.maintenance_pending_window_id = None + + current_slots = await _count_active_maintenance_slots(db, miner_hotkey, active_window) + if current_slots >= limit: + denial_reasons.append( + MaintenanceReason( + reason="concurrency_cap", + current_slots=current_slots, + limit=limit, + ) + ) + + instances = await _get_instances_on_server(db, server.server_id) + blocking = await _find_sole_survivor_chutes(db, instances) + if blocking: + denial_reasons.append( + MaintenanceReason( + reason="sole_survivor", + blocking=[b.model_dump() for b in blocking], + ) + ) + + return PreflightResult( + eligible=len(denial_reasons) == 0, + denial_reasons=denial_reasons, + blocking_chute_ids=blocking, + current_slots=current_slots, + limit=limit, + ) + + +async def confirm_maintenance( + db: AsyncSession, server: Server, miner_hotkey: str +) -> ConfirmMaintenanceResult: + """Enter maintenance: re-validate, set pending window, and auto-purge instances. + + Raises HTTPException (409/403) on failure. 
+ """ + preflight = await preflight_maintenance(db, server, miner_hotkey) + if not preflight.eligible: + reason_codes = {r.reason for r in preflight.denial_reasons} + conflict_reasons = {"sole_survivor", "concurrency_cap", "maintenance_pending"} + if reason_codes & conflict_reasons: + status_code = status.HTTP_409_CONFLICT + else: + status_code = status.HTTP_403_FORBIDDEN + raise HTTPException(status_code=status_code, detail=preflight.model_dump()) + + active_window = await get_active_upgrade_window(db) + server.maintenance_pending_window_id = active_window.id + await db.commit() + await db.refresh(server) + + instances = await _get_instances_on_server(db, server.server_id) + purged_ids: list[str] = [] + for inst in instances: + try: + await purge_and_notify( + inst, + reason="maintenance - server entering TEE upgrade window", + valid_termination=True, + ) + purged_ids.append(inst.instance_id) + except Exception: + logger.error( + f"Failed to purge instance {inst.instance_id} during maintenance", exc_info=True + ) + + return ConfirmMaintenanceResult( + server_id=server.server_id, + purged_instance_ids=purged_ids, + window=UpgradeWindowInfo( + id=active_window.id, + target_measurement_version=active_window.target_measurement_version, + upgrade_window_start=str(active_window.upgrade_window_start), + upgrade_window_end=str(active_window.upgrade_window_end), + max_concurrent_per_miner=active_window.max_concurrent_per_miner, + ), + ) diff --git a/api/server/util.py b/api/server/util.py index d9d98a3f..c6f519eb 100644 --- a/api/server/util.py +++ b/api/server/util.py @@ -2,8 +2,11 @@ TDX quote parsing, crypto operations, and server helper functions. """ +import asyncio import secrets import base64 +import json +import tempfile from typing import Dict, List, Optional from sqlalchemy import select from sqlalchemy.sql import func @@ -20,8 +23,11 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend from api.server.exceptions import ( - InvalidQuoteError, AttestationError, + GpuEvidenceError, + InvalidClientCertError, + InvalidGpuEvidenceError, + InvalidQuoteError, InvalidSignatureError, InvalidTdxConfiguration, MeasurementMismatchError, @@ -530,3 +536,47 @@ async def _track_server( await db.refresh(server) return server + + +async def verify_quote( + quote: TdxQuote, expected_nonce: str, expected_cert_hash: str +) -> TdxVerificationResult: + nonce, cert_hash = extract_report_data(quote) + + if nonce != expected_nonce: + logger.info(f"Nonce error: {nonce} =/= {expected_nonce}") + raise NonceError("Quote nonce does not match expected nonce.") + + if cert_hash != expected_cert_hash: + raise InvalidClientCertError() + + result = await verify_quote_signature(quote) + verify_result(quote, result) + verify_measurements(quote) + + return result + + +async def verify_gpu_evidence(evidence: list[Dict[str, str]], expected_nonce: str) -> None: + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as fp: + json.dump(evidence, fp) + fp.flush() + + verify_gpus_cmd = ["chutes-nvattest", "--nonce", expected_nonce, "--evidence", fp.name] + + process = await asyncio.create_subprocess_exec(*verify_gpus_cmd) + + await asyncio.gather(process.wait()) + + if process.returncode != 0: + raise InvalidGpuEvidenceError() + + logger.info("GPU evidence verified successfully.") + + except FileNotFoundError as e: + logger.error(f"Failed to verify GPU evidence. 
chutes-nvattest command not found?:\n{e}") + raise GpuEvidenceError("Failed to verify GPU evidence.") + except Exception as e: + logger.error(f"Unexepected exception encoutnered verifying GPU evidence:\n{e}") + raise GpuEvidenceError("Encountered an unexpected exception verifying GPU evidence.") diff --git a/chute_autoscaler.py b/chute_autoscaler.py index 34e7a24f..fb4119af 100644 --- a/chute_autoscaler.py +++ b/chute_autoscaler.py @@ -46,7 +46,7 @@ from api.instance.util import invalidate_instance_cache from api.metrics.util import reconcile_connection_counts from api.capacity_log.schemas import CapacityLog -from watchtower import purge, purge_and_notify # noqa +from api.instance.util import purge, purge_and_notify # noqa from api.constants import ( UNDERUTILIZED_CAP, UTILIZATION_SCALE_UP, @@ -409,7 +409,8 @@ async def instance_cleanup(): ) logger.warning(f" {instance.verified=} {instance.active=}") await purge_and_notify( - instance, reason="Instance failed to verify within a reasonable amount of time" + instance, + reason="autoscaler - instance failed to verify within a reasonable amount of time", ) total += 1 if total: @@ -3107,7 +3108,7 @@ async def execute_downsizing(to_downsize: List[Tuple[str, int, Set[str]]], db_no valid_candidates = [] for inst in active_instances: if len(inst.nodes) != (chute.node_selector.get("gpu_count") or 1): - await purge_and_notify(inst, "Instance node count mismatch") + await purge_and_notify(inst, "autoscaler - instance node count mismatch") num_to_remove -= 1 instances_removed += 1 elif db_now.replace(tzinfo=None) - inst.activated_at.replace( @@ -3140,7 +3141,7 @@ async def execute_downsizing(to_downsize: List[Tuple[str, int, Set[str]]], db_no f"Downscaling {chute_id}: removing {targeted_instance.instance_id} ({targeted_instance.nodes[0].gpu_identifier if targeted_instance.nodes else 'unknown'})" ) await purge_and_notify( - targeted_instance, "Autoscaler adjustment", valid_termination=True + targeted_instance, "autoscaler - downscale adjustment", valid_termination=True ) await invalidate_instance_cache(chute_id, targeted_instance.instance_id) instances_removed += 1 diff --git a/docs/specs/server-maintenance.md b/docs/specs/server-maintenance.md new file mode 100644 index 00000000..f7f59c75 --- /dev/null +++ b/docs/specs/server-maintenance.md @@ -0,0 +1,212 @@ +# Feature Spec: TEE server maintenance window + +Use the sections **Goal**, **Constraints**, **Output Format**, and **Failure Conditions** as a **Prompt Contract** for this task (see [AGENT.md](../../AGENT.md) at repo root). + +**Date**: 2026-04-02 (revised 2026-04-03) +**Status**: draft (refined) + +**Scope note (this iteration):** Implement **maintenance window + preflight/confirm + auto-purge + boot completion** (`Server.version` updated, `maintenance_pending_window_id` cleared on success). **Out of scope:** `scoring_penalty_multiplier`, cron-based penalties for outdated `measurement_version`, changes to `metasync` / `INSTANCES_QUERY`, and deferred `valid_termination` (see **Acknowledged abuse vector** below). + +--- + +## Context + +Miners upgrading a TEE host need to **remove instances** from routing before reboot. Today, **miner-initiated** `delete_instance` can trigger **thrash** (no `valid_termination`) and **last-instance** scoring penalties (`api/instance/router.py`). This spec adds a **validator-controlled global window**, a **preflight** endpoint (see below), and a **confirm** step. 
On **successful confirm**, the platform **automatically terminates all instances** on that server (validator-driven eviction) using the same **purge** machinery as **watchtower** (`watchtower.purge` / `purge_and_notify`), with **`valid_termination = true`** and a dedicated **`deletion_reason`**, so routing and caches update **proactively**—the miner does not rely on manual deletes to drain the box. **Policy** must align **last-instance** / bounty handling on these purges with the maintenance story (see below). + +- **Packages affected**: `api` (primary) +- **Key files**: + - `api/server/router.py`, `api/server/service.py`, `api/server/schemas.py` — server model, preflight + confirm maintenance, boot attestation **completion** handling (`Server.version` updated, `maintenance_pending_window_id` cleared) + - `api/config.py` (or equivalent settings module) — `tee_maintenance_max_miner_concurrency`; **not** primary store for window bounds / target (those live in **`tee_upgrade_windows`**) + - `watchtower.py` — `purge` / `purge_and_notify` (or factored shared helper): reuse for maintenance-initiated teardown (`valid_termination=True`, notifications, cache invalidation) + - `api/instance/router.py` — **Optional** fallback: if any code path still allows miner `delete_instance` while maintenance is active, keep protected behavior; primary drain is **confirm → auto-purge** + - `api/instance/util.py`, `api/node/schemas.py` — correlate `Instance` → `Node` → `Server` (forward-compatible with **planned** shared-IP TEE; see **TEE addressing: today vs planned** below) + - `api/migrations/*.sql` — new table **`tee_upgrade_windows`**; new columns on **`servers`** (`maintenance_pending_window_id`, `version`) + - `api/constants.py` — thrash constants (read-only context; no change required unless tests need it) +- **Dependencies**: Existing FastAPI, SQLAlchemy async sessions, `get_current_user` / miner auth patterns for server routes. + +### TEE addressing: today vs planned + +- **Today (production TEE):** The platform often has **one public IP per server**, but **IP alone is not sufficient** for maintenance completion once **multiple logical servers** or **IP changes** exist. **`Instance` → `Node` → `Server`** remains the primary path for instance teardown. +- **Planned / shared-IP:** **Multiple TEE servers may share one public IP** (e.g. NAT / shared egress). **Boot attestation must not rely on IP alone** to find the `servers` row for maintenance completion. +- **This spec:** Use **`(miner_hotkey, server name)`** for server identity wherever the API must pick a single **`Server`** row (aligned with **`servers` unique `(miner_hotkey, name)`** in [`api/server/schemas.py`](../../api/server/schemas.py)). +- **Boot attestation params (already present):** [`BootAttestationArgs`](../../api/server/schemas.py) includes **`vm_name`** and **`miner_hotkey`**. **Contract:** **`vm_name` must match `Server.name`** for that miner's registered server (same string the miner used at registration / LUKS linkage). **Maintenance completion** on successful boot must **resolve the `Server` row by `(args.miner_hotkey, args.vm_name)` → `servers.name`** (primary). Optionally **cross-check** request IP against `servers.ip` when present (log or soft-validate); do **not** use IP as the only key when **`vm_name`** is available. +- **First boot / no `Server` row yet:** Boot attestation can run **before** server registration. 
If **no** row matches **`(miner_hotkey, vm_name)`**, **skip** maintenance completion and `version` update (no-op): there is no **`servers`** row to update, and **preflight/confirm** only apply to **existing** servers anyway. After **`register_server`** (or equivalent), subsequent boots find the row and can clear maintenance if a slot was set. +- **Implementation:** Centralise in **`resolve_server_for_maintenance_boot_completion(db, miner_hotkey, vm_name, request_ip?)`** (or fold into `process_boot_attestation`): lookup by **hotkey + name**, return **`None`** if absent, then run `version` update + maintenance fields only when a **`Server`** exists. + +--- + +## Design Decisions + +- **No human admin (miner-facing):** Allow/deny is **fully automatic** inside the API for miners. **Validators** control windows by **rows in the database** (see **`tee_upgrade_windows`** below)—no miner-facing approve/deny routes. **This iteration:** no HTTP admin API for inserting windows (use migration seed, SQL runbook, or internal tooling); follow-up can add operator routes. +- **Global window = admission window (DB-backed):** The **active** upgrade window is the row in **`tee_upgrade_windows`** for which **`upgrade_window_start <= now() <= upgrade_window_end`** (see **Resolving the active window**). If **no** row qualifies, **deny** new **preflight** / **confirm** (**403** / **404**—document the choice). Servers that **confirmed** before the window closes may **finish** (reboot, boot attestation) **after** `upgrade_window_end`—we do not revoke an in-flight slot. **No new admits after close:** until validators **insert** a new window row (or extend an existing row's end time). Past rows remain for **history and audit**. +- **Why a table instead of env-only:** Environment variables give **no durable history** and encourage **silent overwrites**. A table yields **one row per coordinated target** (`target_measurement_version`), explicit **`upgrade_window_start` / `upgrade_window_end`**, **`created_at`**, and a full **audit trail** of past cutovers. **`GET …/policy`** (and internal checks) **read the active row from the DB**, optionally **cached** (short TTL Redis/in-process) to avoid hitting the DB on every preflight; cache must **invalidate or expire** quickly enough that window changes take effect promptly (or invalidate on write when an admin API exists). +- **Rollout identity = `tee_upgrade_windows` row:** **`target_measurement_version`** on the row is the **logical "which upgrade"** string (normalised once at insert). **`id`** distinguishes rows; enforce **`UNIQUE (target_measurement_version)`** so there is **at most one window definition per target** (adjust if you ever need a rare re-run for the same target—then drop uniqueness and key completion by **`id`** only). +- **Already at or above target (no pointless purge / anti-re-entry):** **Preflight** / **confirm** **deny** if `semcomp(server.version, active_window.target_measurement_version) >= 0`. **Rationale:** Maintenance exists to move hosts **onto** the mandated image; purging when already compliant would only create churn and scoring noise. **Source of truth:** **`Server.version`** — updated on **every successful boot attestation** (see boot completion below). **If `version` is `None`** (no boot recorded yet), **do not** treat as "already at target." 
This check also **prevents re-entry** after a completed maintenance cycle: once a server boots with `version >= target`, the version check blocks subsequent preflight/confirm for the same window's target. +- **No per-server deadline / grace period:** There is **no** `maintenance_deadline_at` or `grace_hours`. The server stays "in maintenance" (`maintenance_pending_window_id` set) until boot completion succeeds or the window closes. This is simpler and avoids false precision—the window's own `upgrade_window_end` provides the outer bound for the rollout. +- **Per-miner concurrency (validator capacity knob):** A miner may have at most **`tee_maintenance_max_miner_concurrency`** servers in an **active** maintenance state at once (**default `1`**). **Active** means **`maintenance_pending_window_id IS NOT NULL`** and the referenced window is the **currently active** one (stale slots for closed windows are lazily cleared—see below). Before accepting **confirm**, **count** distinct `servers` rows for that `miner_hotkey` meeting that condition; if **count ≥ limit**, return **409** with a clear message (and `current_slots` / `limit` in JSON). An active slot **ends** when **successful boot** clears `maintenance_pending_window_id`, or is **lazily cleared** when the referenced window is no longer active. **No miner `DELETE`** to clear maintenance. **Operator-only** override (if ever needed) is out of scope here. +- **Two-step flow (preflight + confirm):** + - **Preflight:** `GET /servers/{server_id}/maintenance/preflight` — read-only check whether **confirm** would succeed **right now**—same auth as confirm (miner + `server_id`). Returns **`eligible: true`** or **`eligible: false`** with structured reasons (no active DB window, **already at or above active target** via `Server.version`, sole-survivor `chute_id`s, concurrency cap, already **active** slot on this server, not TEE, etc.). When the server has a **pending maintenance** (`maintenance_pending_window_id` set), includes a structured entry with `current_version` and `target_version` so the miner can see the gap. **No DB writes**, no purges (reads **active window** from DB or cache). + - **Confirm:** `PUT /servers/{server_id}/maintenance` — **re-runs** all checks (must match preflight outcome unless state changed between calls), then **commits** maintenance slot and runs **auto-purge**. Idempotent **enter maintenance** semantics are acceptable for **PUT**. If preflight was **eligible** but state changed before confirm, confirm may **409**; clients should **re-preflight**. +- **TEE + ownership:** **Preflight / confirm** only for `Server.is_tee == true` and `server_id` + `HOTKEY_HEADER` passing existing ownership checks (`check_server_ownership`). +- **Sole-survivor rule (fixed policy):** If **any** active instance on that server is the **only** `active` instance globally for its `chute_id`, **`PUT …/maintenance` fails** with **409** and a JSON body listing blocking `{ chute_id, instance_id? }`. **Never** auto-terminate the globally last instance via maintenance: the network keeps that copy until another miner scales up. **Preflight** surfaces the same blocking set. **Auto-bounty on deny** remains **deferred** (follow-up). +- **Instance → server at delete time (forward-compatible):** **Do not** use `instances.host == servers.ip` as the **primary** link. That join is **consistent with today's one-IP-per-server TEE rule** but will become **ambiguous** when **planned** multi-server-per-IP TEE exists. 
**Primary** resolution: **`Instance` → `instance_nodes` → `Node` → `Server`** (`nodes.server_id` → `servers.server_id`). **The platform does not support** an instance attached to nodes belonging to **more than one** `server_id`; no cross-server merge or precedence rules are required. For logging, policy, and any "logical server" checks, treat **`(servers.miner_hotkey, servers.name)`** as the stable human-facing identity (unique per miner today). +- **Boot completion → `Server` row:** Use **`BootAttestationArgs.miner_hotkey`** + **`BootAttestationArgs.vm_name`** → **`Server`** where **`servers.name = vm_name`** (and **`servers.miner_hotkey`** matches). **If no row:** first-boot / pre-registration path—**no** maintenance completion, **no** `version` update. **Always** set **`Server.version = measurement_version`** from the matched measurement config on successful boot (regardless of maintenance state). If **`maintenance_pending_window_id`** is set and `semcomp(measurement_version, pending_window.target_measurement_version) >= 0`: **clear** `maintenance_pending_window_id = None`. If measurement is **below** target: log warning, leave `maintenance_pending_window_id` set (miner must try again). The updated `version` prevents re-entry via the "already at target" check. +- **Auto-terminate on confirm (primary path):** After **successful** confirm (DB commit sets `maintenance_pending_window_id`), **enumerate all instances** on that server via **Instance → Node → Server** (not host/IP as primary), then for each instance invoke shared **purge** logic (as watchtower does): delete `instances` row, update `instance_audit` with `valid_termination = true`, `deletion_reason` e.g. `tee maintenance`, fail jobs, `notify_deleted`, `invalidate_instance_cache`, etc. **Order:** persist slot **before** purges so concurrent logic can see maintenance. **Performance:** many instances may require **sequential async** purges or a **background task**; document whether confirm **HTTP** waits for all purges to finish or returns after scheduling (prefer **wait** for small N, **task** for large N with idempotent retry on failure). +- **Last-instance / bounty on auto-purge:** Because **confirm** is **denied** when a sole survivor exists on the server, the maintenance purge batch **must not** include a globally last instance for any `chute_id`. For each instance purged in this flow, use **`valid_termination = true`**, maintenance **`deletion_reason`**, and shared **`purge`** machinery; last-instance bounty / multiplier slash **does not apply** to these rows (they are not last-global by construction). Implement via a shared helper or **`purge()`** flags—avoid duplicating `delete_instance` penalty logic. +- **Miner-initiated delete while slot active:** Rare if auto-purge drained the server; if any instance remains (partial failure, race, or future edge case), **`delete_instance`** should still treat **active maintenance slot** (`maintenance_pending_window_id IS NOT NULL`) with `valid_termination`, same last-instance policy. +- **Stale pending slot cleanup:** If the upgrade window closes and `maintenance_pending_window_id` still points to it (miner never completed), the slot is stale. Lazy clear: when preflight/confirm encounters `maintenance_pending_window_id` referencing a window that is no longer active, treat as cleared (or clear it on read). This allows the miner to enter a new window without manual intervention. 
+- **Pending upgrade visibility (miner UX):** When a server has `maintenance_pending_window_id` set (especially after a boot that didn't reach the target), surface the version gap clearly: **preflight** includes a structured `maintenance_pending` reason with `current_version` and `target_version`; **policy** endpoint includes a `pending_servers` list for the calling miner; **`GET /servers/{server_id}`** includes `version` and maintenance status. +- **Explicit non-goals (this spec):** No `scoring_penalty_multiplier` column; no cron adjusting scores for outdated VMs; no `INSTANCES_QUERY` / miner stats query changes for penalties. + +### Acknowledged abuse vector (v1 accept) + +Auto-purge grants `valid_termination=True` at confirm time, **before** the upgrade is proven. A miner could confirm maintenance purely to get a penalty-free purge of all their instances, then never actually upgrade. **Mitigations:** economic cost (zero earnings while offline), sole-survivor check (can't kill last global instance), validator-controlled windows (can't enter at will), per-miner concurrency cap. **Closed fully by** the planned **scoring penalty follow-up** (`scoring_penalty_multiplier` + cron for outdated VMs post-window). **Deferred option (stronger):** purge with `valid_termination=False` at confirm, upgrade to `True` on successful boot completion. + +### Identity, per-window limits, and abuse model + +- **"One maintenance per server per global window"** is enforced by the **`Server.version >= target`** check: once a server completes maintenance (boots with new version), `version` is updated and the server cannot re-enter for the same target. `server_id` is often **ephemeral** (e.g. new Kubernetes node UID after reprovision). If the miner **deletes** the `servers` row and re-registers, or wipes storage and gets a **new** `server_id` / **new** `name`, the API sees a **new** server (with `version = NULL`): we **cannot** infer they already consumed a slot on a logically "same" machine unless we add **durable tracking** outside `servers`. +- **Rename / reprovision loop:** A miner could enter maintenance, tear down the VM, re-register under a **new** `server_id` (and possibly a **new** `name`), and be eligible again for the **same** configured target. **Why this may be weak abuse:** + - Each cycle implies **real downtime** and **lost compute / earnings** during reprovision and redeploy. + - Maintenance protection only affects **how deletes are classified** (thrash + last-instance treatment); it does not mint extra rewards. The "profit" is avoiding scoring penalties on churn, which is bounded by how much they actually delete and redeploy. +- **Residual risk:** A miner could seek **more** `valid_termination`-style deletes than intended by policy if they can cheaply rotate server identities. Mitigation is **economic** (outage cost), **per-miner concurrency**, **no miner cancel** after purge, plus optional **product** mitigations below. +- **Optional hardening (follow-up, if needed):** Append-only **`server_maintenance_events`** with columns like `(miner_hotkey, server_name, upgrade_window_id, confirmed_at)` and a **unique** constraint on `(upgrade_window_id, miner_hotkey, server_name)` — stops **reuse of the same name** in one rollout after row delete, but **does not** stop a miner who picks a **new name** each time. Stronger binding would need an **immutable** hardware or enrollment identifier (out of scope unless another feature provides it). 
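If that hardening follow-up is ever picked up, a minimal migration sketch could look like the following. This is illustrative only and **not** part of this iteration: the table and column names come from the bullet above, and the `VARCHAR` window id simply mirrors the `tee_upgrade_windows` migration in this PR.

```sql
-- Hypothetical follow-up, NOT in scope for this spec: durable per-window confirm log.
-- migrate:up
CREATE TABLE IF NOT EXISTS server_maintenance_events (
    id VARCHAR PRIMARY KEY,
    miner_hotkey VARCHAR NOT NULL,
    server_name VARCHAR NOT NULL,
    upgrade_window_id VARCHAR NOT NULL REFERENCES tee_upgrade_windows(id),
    confirmed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    -- One confirm per (window, miner, server name); survives deletion of the servers row.
    CONSTRAINT uq_maintenance_event UNIQUE (upgrade_window_id, miner_hotkey, server_name)
);

-- migrate:down
DROP TABLE IF EXISTS server_maintenance_events;
```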
+ +### Table `tee_upgrade_windows` (historical record, one row per target) + +| Column | Type | Notes | +|--------|------|--------| +| **`id`** | bigserial PK | Stable row identity for FKs from **`servers`**. | +| **`upgrade_window_start`** | timestamptz | Admission opens (new **preflight** / **confirm** allowed). | +| **`upgrade_window_end`** | timestamptz | Admission closes for **new** entries; in-flight slots may still finish. | +| **`target_measurement_version`** | text | Minimum attested measurement for this cutover; **normalised** on insert. **`UNIQUE`** recommended (one row per target version). | +| **`created_at`** | timestamptz | When the row was inserted (audit). Default `now()`. | + +**Operational pattern:** `INSERT` a new row when shipping a **new** mandated image line; **UPDATE** `upgrade_window_end` to **now** (or past) to **end** admits for that cutover before opening the next. Old rows **stay** in the table as history. + +### `servers` table — new columns + +| Column | Type | Notes | +|--------|------|--------| +| **`maintenance_pending_window_id`** | bigint, nullable, FK → `tee_upgrade_windows.id` | Set at **confirm**; "in maintenance" signal. Cleared on successful boot completion (version >= target) or lazily when the referenced window is no longer active. | +| **`version`** | text, nullable | Current attested measurement version. Updated on **every** successful boot attestation (regardless of maintenance). Used by the "already at target" check and exposed in server metadata for miners. | + +### Resolving the "active" window + +**Definition:** The **active** window is the single row (if any) such that **`upgrade_window_start <= now() <= upgrade_window_end`**. If **multiple** rows overlap (operator error), implementation must pick a **deterministic** rule (e.g. **highest `id`**, or **latest `created_at`**) and **log a warning**; validators should avoid overlaps. + +**No active row:** No new **preflight** / **confirm**; feature is "closed" until a new row qualifies. + +**Caching:** Load active row via a small helper used by preflight, confirm, policy, and boot completion; cache the result for a **short TTL** (or invalidate on writes) so policy GETs do not hammer the DB. + +### Rollout identity, single artifact, and what belongs in maintenance + +**Single published VM artifact:** The release pipeline exposes **only the latest** VM image. Once **0.3.1** is published, miners **cannot** fetch **0.3.0**. Any miner who **starts** an upgrade after that point is on the **current** image line. + +**What maintenance windows are for (policy):** Use **coordinated admission windows** primarily for **major / minor** (or **breaking / validator-mandated**) TEE image moves. **Patch** releases: miners upgrade **on their own schedule** **without** this API—**no** maintenance-scoped protections for that path in this spec. + +**Problem (minor bump mid-window):** A row exists with target **T0**; the pipeline publishes a newer image and the old one is **gone**. + +**Operator response (recommended — end and replace, no overlap):** + +1. **End** the current row's admission: set **`upgrade_window_end`** to **now** (or past) on that row so it is no longer "active." +2. **`INSERT`** a **new** row with **`target_measurement_version = T1`**, new **`upgrade_window_start` / `upgrade_window_end`**, and **`created_at`**. +3. Miners who **completed** the old cutover have `version >= T0` → if `T1 > T0`, the "already at target" check allows them to enter the **new** window (version < T1). 
Miners **in flight** (confirmed under old row, not yet booted): at boot, `Server.version` is updated; if it meets the **new** target too, the stale `maintenance_pending_window_id` (pointing at the old row) is lazily cleared on next preflight. If not, they remain pending until they boot with a sufficient version. + +**Overlapping concurrent windows:** **Out of scope** in v1—operators should **not** insert overlapping `[start, end]` ranges; if they do, deterministic resolution + warning applies. + +**Implementation note:** **`GET …/policy`** returns the **active** row's **`id`**, bounds, **`target_measurement_version`**, **`tee_maintenance_max_miner_concurrency`** from settings, the miner's active slot count, and a **`pending_servers`** list for the calling miner. + +--- + +## API Changes + +- **New endpoints** (names illustrative; align with existing `/servers` prefix): + - `GET /servers/{server_id}/maintenance/preflight` — miner auth, **no side effects**. Response e.g. `{ "eligible": bool, "reasons": [...], "blocking_chute_ids": [...], "current_slots": n, "limit": m, ... }`. When the server has pending maintenance, includes structured `maintenance_pending` reason with `current_version` and `target_version`. + - `PUT /servers/{server_id}/maintenance` — miner auth; body optional `{}` or `{ "ack": true }` if you want an explicit client ack. **Re-validates** all rules, then sets `maintenance_pending_window_id` and **auto-purges**. Returns server id, **list of `instance_id`s purged** (or async job id), echo of window if desired. On failure returns 403/409 with structured error matching preflight reasons. + - **No** `DELETE /servers/.../maintenance` for miners. + - `GET /servers/maintenance/policy` — **read-only** global JSON: **active** window **`id`**, **`upgrade_window_start` / `upgrade_window_end`**, **`target_measurement_version`**, **`tee_maintenance_max_miner_concurrency`** (from settings), miner's **current active slot count**, and **`pending_servers`** list (servers with `maintenance_pending_window_id` set, each with `server_id`, `name`, `version`, `target_version`). Served from DB (via cache). No secrets. +- **Updated endpoints:** + - `GET /servers/{server_id}` — include **`version`** and **maintenance status** (`maintenance_pending_window_id`, resolved `target_version` when pending) in the response. +- **Schema changes — new table `tee_upgrade_windows`:** As in the table above; add **`UNIQUE (target_measurement_version)`** if policy is strictly one row per target. +- **Schema changes (`servers` table):** Add 2 nullable columns: + - **`maintenance_pending_window_id` (FK → `tee_upgrade_windows.id`, nullable)** — set to the **active** window's **`id`** at **confirm**; cleared on successful boot completion or lazily when the referenced window is no longer active. + - **`version` (text, nullable)** — current attested measurement version; updated on every successful boot attestation; used by "already at target" check and exposed in server metadata. +- **Migrations:** New timestamped SQL under `api/migrations/` creating **`tee_upgrade_windows`** and altering **`servers`**; keep `api/server/schemas.py` models in sync (this repo holds `Server` in that module, not `orms.py`—follow local convention). 
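+
+To make the endpoint wiring concrete, a minimal sketch of how preflight and confirm might delegate to the service layer follows. The service-layer names (`get_server_by_name_or_id`, `preflight_maintenance`, `confirm_maintenance`) and schema classes match the unit tests added in this change; the session/auth dependencies and import locations are placeholders for the repo's existing router patterns, not the final implementation.
+
+```python
+# Sketch only: route wiring for preflight/confirm. Dependency helpers and
+# import paths are assumptions; service names match the unit tests in this diff.
+from fastapi import APIRouter, Depends, Header
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from api.database import get_session  # assumed location of the async-session dependency
+from api.server.schemas import ConfirmMaintenanceResult, PreflightResult
+from api.server.service import confirm_maintenance, preflight_maintenance
+from api.server.service import get_server_by_name_or_id  # import location assumed
+
+router = APIRouter(prefix="/servers")
+
+
+async def get_miner_hotkey(x_chutes_hotkey: str = Header(...)) -> str:
+    # Placeholder: the real routes reuse the repo's existing miner hotkey auth.
+    return x_chutes_hotkey
+
+
+@router.get("/{server_name_or_id}/maintenance/preflight", response_model=PreflightResult)
+async def get_maintenance_preflight(
+    server_name_or_id: str,
+    db: AsyncSession = Depends(get_session),
+    hotkey: str = Depends(get_miner_hotkey),
+) -> PreflightResult:
+    # Read-only: resolve the caller's server, then evaluate every eligibility rule.
+    server = await get_server_by_name_or_id(db, hotkey, server_name_or_id)
+    return await preflight_maintenance(db, server, hotkey)
+
+
+@router.put("/{server_name_or_id}/maintenance", response_model=ConfirmMaintenanceResult)
+async def put_confirm_maintenance(
+    server_name_or_id: str,
+    db: AsyncSession = Depends(get_session),
+    hotkey: str = Depends(get_miner_hotkey),
+) -> ConfirmMaintenanceResult:
+    # Re-validates the same rules, sets maintenance_pending_window_id, then auto-purges.
+    server = await get_server_by_name_or_id(db, hotkey, server_name_or_id)
+    return await confirm_maintenance(db, server, hotkey)
+```
+
+Confirm deliberately re-runs the same checks inside `confirm_maintenance`, so the two endpoints can only diverge when state changes between calls.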
+ +--- + +## Goal + +Success = a miner can **preflight** then **confirm** only when a **DB-backed active `tee_upgrade_windows` row** exists and `now` is inside **`[upgrade_window_start, upgrade_window_end]`**, and **not** when **already at or above** that row's **`target_measurement_version`** (checked via `Server.version`). Subject to **sole-survivor rule** (deny **409**—never purge the globally last instance) and **per-miner concurrent-slot limit** (default **one** server at a time); on successful **confirm** the API **auto-purges** with **`valid_termination`** (no globally last instances in batch—see sole-survivor rule); **successful boot** updates **`Server.version`** and clears **`maintenance_pending_window_id`** on the **correct** server row (resolved by `miner_hotkey` + `vm_name`); stale pending slots (window closed, never completed) are **lazily cleared**. Validators add **new table rows** (and end old rows) for each coordinated cutover—**history** remains in **`tee_upgrade_windows`**. **No** miner **DELETE**. **No** scoring / metasync changes for outdated versions in this iteration. + +Testable criteria: + +- Migration applies cleanly; `Server` ORM matches DB. +- Preflight returns **`eligible: false`** when **no active window row**, **already >= active target** (via `Server.version`), outside **`[start, end]`**, **sole-survivor** blocking any instance, at concurrency cap, or already in maintenance; confirm returns **403/409** consistently. +- After successful **confirm**, **all** targeted instances are **gone** from `instances` (or async job completes reliably), `instance_audit` shows **`valid_termination`** and maintenance reason, caches invalidated; **no** spurious thrash on miner redeploy after upgrade. +- Optional: miner `delete_instance` under active slot still correct if any instance left. +- Boot attestation success updates **`Server.version`** and clears **`maintenance_pending_window_id`** when measurement >= target; subsequent preflight returns **ineligible** (version >= target). +- Boot attestation with measurement **below** target still updates `Server.version` but leaves `maintenance_pending_window_id` set; preflight/policy surfaces the version gap. +- `GET` policy endpoint returns expected shape when window open/closed, includes `pending_servers`. +- **Preflight** and **confirm** agree when state is unchanged; after state change, confirm may fail until re-preflight. + +--- + +## Constraints + +- Follow [AGENT.md](../../AGENT.md): **no new dependencies**; **window bounds and targets** live in **`tee_upgrade_windows`** (not env); **settings** for concurrency limit; **async** handlers; **Ruff** clean; add **tests** where behavior is non-trivial. +- **Do not** add `scoring_penalty_multiplier`, penalty cron, or `INSTANCES_QUERY` edits in this task. +- **Do not** hardcode window times or target versions in application code—load from the **DB** (active row) or documented migration seeds. +- Keep changes **focused**: prefer small helpers in `api/server/` (e.g. `util.py` or `service.py`) over cross-cutting refactors. + +--- + +## Output Format + +1. `api/migrations/YYYYMMDDHHMMSS_server_maintenance.sql` — `CREATE TABLE tee_upgrade_windows (...)`; indexes to resolve **active** row quickly (e.g. on `(upgrade_window_start, upgrade_window_end)` or as justified by queries); `ALTER TABLE servers ADD COLUMN …` / FKs. +2. `api/server/schemas.py` — `TeeUpgradeWindow` (or equivalent) model + new columns (`maintenance_pending_window_id`, `version`) on **`Server`**. +3. 
`api/config.py` (or settings model) — **`tee_maintenance_max_miner_concurrency`**; **not** window start/end/target (those are DB rows). +4. `api/server/router.py` — New routes (policy, preflight, confirm); update `GET /servers/{server_id}` to include `version` and maintenance status; reuse `get_current_user` / hotkey patterns from existing server routes. +5. `api/server/service.py` (or new helper module) — **`get_active_upgrade_window()`** (DB + cache); `preflight_maintenance` / `confirm_maintenance` (already-at-target via `Server.version`, set **`maintenance_pending_window_id`** at confirm); **stale slot lazy clear** when referenced window is no longer active; **`resolve_server_for_maintenance_boot_completion(db, miner_hotkey, vm_name, …)`** using **`BootAttestationArgs`**; extend **`process_boot_attestation`** (or call hook after success) to set **`Server.version`** and clear **`maintenance_pending_window_id`** when measurement OK; **no-op** when **no `Server`** row (**first boot**). +6. `api/instance/util.py` or `watchtower.py` or small `api/instance/maintenance_purge.py` — **Shared** "maintenance purge one instance" used by **confirm** batch; wraps or extends `purge` with **`valid_termination=True`** and maintenance **`deletion_reason`** (batch excludes globally last instances—see sole-survivor rule). +7. `api/instance/router.py` — Keep protected `delete_instance` branch for edge cases (optional if auto-purge is exhaustive). +8. `tests/unit/` (and/or integration) — Preflight **denied** when **no row / outside window / already >= target**; **allowed** when eligible; confirm sets `maintenance_pending_window_id` and purges; boot completion updates `version` and clears pending slot; stale slot lazy clear; sole-survivor blocks confirm; **Instance → Node → Server** resolution. + +--- + +## Failure Conditions + +- Maintenance protection applies **outside** the global window (preflight/confirm must check window bounds). +- **Confirm** returns **success** when **any** blocking **sole-survivor** `chute_id` exists, when miner already holds **≥ limit** concurrent active slots, or when outside the active DB window (**confirm** must **re-validate** every check; outcomes must match preflight unless state legitimately changed). +- A miner-facing **`DELETE`** exists that **clears** maintenance (must **not** ship). +- Auto-purge uses **`valid_termination = false`** or omits last-instance protection → **thrash** or **wrong scoring** on redeploy. +- Instances **remain routable** after successful **confirm** (purge incomplete / wrong server scope). +- Boot success **omits** `Server.version` update or **wipes** `maintenance_pending_window_id` without checking measurement >= target. +- **Wrongly** allow preflight/confirm when the server is **already >= active target** (pointless purge — `Server.version` check must deny). +- **Confirm** or purge **succeeds** while a **globally sole-survivor** instance for any `chute_id` would be terminated (**must** remain **409** / no purge—see fixed sole-survivor rule). +- Boot completion runs on **wrong** `servers` row (must use **`(miner_hotkey, vm_name)` → `servers.name`**, not **IP-only**) or runs without measurement check. +- Maintenance completion runs when **no** `servers` row exists (**first boot**) and incorrectly mutates state (should **no-op**). +- Delete / enumerate path uses **IP-only** as **primary** correlation instead of **Instance → Node → Server** (breaks **planned** shared-IP TEE; weak even today). 
+- Schema drift: migration applied but `Server` model missing columns (or reverse). +- **Any** dependency added without explicit approval. +- Scoring / metasync penalty code added despite scope. + +--- + +## Rollout Notes + +- **Database:** Document **`tee_upgrade_windows`** and the runbook: **`INSERT`** a row to open a cutover (**`target_measurement_version`**, **`upgrade_window_start` / `upgrade_window_end`**, **`created_at`**); **`UPDATE`** `upgrade_window_end` to end admits; **never** delete old rows if you want history (or archive separately). Optionally document in `dev/dev.md`. +- **Settings:** `TEE_MAINTENANCE_MAX_MINER_CONCURRENCY`—final name follows `settings` naming. +- **Deploy order:** Migrate DB (table + server columns) → deploy API → **`INSERT`** first window row when ready; if **no** row is active for `now`, preflight/confirm deny new entry. +- **Operational:** Miners **`GET …/preflight` → `PUT …/maintenance`**; **`GET …/policy`** reflects the **active DB row** (cached) and shows `pending_servers`. Validators manage **rows**, not env window clocks. + +--- + +## Follow-ups (not this spec) + +- Auto-bounty when confirm is blocked (sole survivor). +- `scoring_penalty_multiplier` + cron + `INSTANCES_QUERY` / miner stats alignment for outdated VMs post-window. +- **Deferred `valid_termination`**: purge with `valid_termination=False` at confirm, upgrade to `True` on successful boot completion (closes the "free purge" abuse vector). +- **`server_maintenance_events`** (or similar) if per-rollout limits must survive server row deletion or **reuse of `server_name`**. +- **Optional:** **`server_id` in `BootAttestationArgs`** if product wants an explicit key beyond **`vm_name`**; **`vm_name` is already required** today ([`BootAttestationArgs`](../../api/server/schemas.py)). +- **Audit view:** `CREATE VIEW v_server_maintenance` joining `servers` → `tee_upgrade_windows` to expose derived columns (`maintenance_declared_at`, `last_maintenance_completed_at`, `target_measurement_version`) without widening the `servers` table. diff --git a/docs/specs/templates/bugfix.md b/docs/specs/templates/bugfix.md new file mode 100644 index 00000000..93eb2a54 --- /dev/null +++ b/docs/specs/templates/bugfix.md @@ -0,0 +1,78 @@ +# Bugfix Spec: [Short Description] + +Use the sections **Goal**, **Constraints**, **Output Format**, and **Failure Conditions** as a **Prompt Contract** for this task (see [AGENT.md](../../../AGENT.md) at repo root). + +**Date**: YYYY-MM-DD +**Status**: draft | investigating | in progress | done + +--- + +## Symptoms + + + +- +- + +--- + +## Reproduction Steps + + + +1. +2. +3. + +--- + +## Root Cause + + + +- + +--- + +## Goal + + + + +Success = + +--- + +## Constraints + + + +- +- + +--- + +## Output Format + + + + +1. +2. + +--- + +## Failure Conditions + + + +- +- + +--- + +## Regression Prevention + + + +- diff --git a/docs/specs/templates/feature.md b/docs/specs/templates/feature.md new file mode 100644 index 00000000..2f7b2b34 --- /dev/null +++ b/docs/specs/templates/feature.md @@ -0,0 +1,85 @@ +# Feature Spec: [Feature Name] + +Use the sections **Goal**, **Constraints**, **Output Format**, and **Failure Conditions** as a **Prompt Contract** for this task (see [AGENT.md](../../../AGENT.md) at repo root). 
+ +**Date**: YYYY-MM-DD +**Status**: draft | in progress | done + +--- + +## Context + + + + +- **Packages affected**: +- **Key files**: +- **Dependencies**: + +--- + +## Design Decisions + + + +- +- + +--- + +## API Changes + + + +- **New endpoints**: +- **Schema changes**: +- **Migrations**: + +--- + +## Goal + + + + +Success = + +--- + +## Constraints + + + + +- +- + +--- + +## Output Format + + + + +1. +2. +3. + +--- + +## Failure Conditions + + + +- +- +- + +--- + +## Rollout Notes + + + +- +- diff --git a/docs/specs/templates/refactor.md b/docs/specs/templates/refactor.md new file mode 100644 index 00000000..5c97ce6b --- /dev/null +++ b/docs/specs/templates/refactor.md @@ -0,0 +1,83 @@ +# Refactor Spec: [Short Description] + +Use the sections **Goal**, **Constraints**, **Output Format**, and **Failure Conditions** as a **Prompt Contract** for this task (see [AGENT.md](../../../AGENT.md) at repo root). + +**Date**: YYYY-MM-DD +**Status**: draft | in progress | done + +--- + +## Motivation + + + +- +- + +--- + +## Scope + + + +- **Components**: +- **Key files**: + +--- + +## Before / After + + + +| Before | After | +|--------|-------| +| | | +| | | + +--- + +## Goal + + + + +Success = + +--- + +## Constraints + + + +- +- + +--- + +## Output Format + + + + +1. +2. +3. + +--- + +## Failure Conditions + + + +- +- + +--- + +## Migration Strategy + + + + +- +- diff --git a/tests/unit/test_server_maintenance_boot.py b/tests/unit/test_server_maintenance_boot.py new file mode 100644 index 00000000..0882cd28 --- /dev/null +++ b/tests/unit/test_server_maintenance_boot.py @@ -0,0 +1,222 @@ +""" +Unit tests for Phase 4: boot version update hook and registration version population. +""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +from api.server.schemas import Server, TeeUpgradeWindow +from api.server.exceptions import ServerNotFoundError +from api.server.service import _handle_boot_version_update + +TEST_SERVER_ID = "server-abc-123" +TEST_HOTKEY = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY" +TEST_VM_NAME = "my-tee-vm" +TEST_WINDOW_ID = "window-abc-123" +TEST_VERSION_OLD = "0.2.0" +TEST_VERSION_TARGET = "0.3.1" +TEST_VERSION_ABOVE = "0.4.0" +TEST_WINDOW_START = datetime(2026, 4, 1, tzinfo=timezone.utc) +TEST_WINDOW_END = datetime(2026, 4, 7, tzinfo=timezone.utc) + + +def _make_server(**overrides): + defaults = dict( + server_id=TEST_SERVER_ID, + ip="10.0.0.1", + miner_hotkey=TEST_HOTKEY, + name=TEST_VM_NAME, + netuid=64, + is_tee=True, + version=None, + maintenance_pending_window_id=None, + ) + defaults.update(overrides) + return Server(**defaults) + + +def _make_window(**overrides): + defaults = dict( + id=TEST_WINDOW_ID, + upgrade_window_start=TEST_WINDOW_START, + upgrade_window_end=TEST_WINDOW_END, + target_measurement_version=TEST_VERSION_TARGET, + max_concurrent_per_miner=1, + created_at=datetime(2026, 4, 1, tzinfo=timezone.utc), + ) + defaults.update(overrides) + return TeeUpgradeWindow(**defaults) + + +# --------------------------------------------------------------------------- +# _handle_boot_version_update +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_no_server_row_is_noop(mock_get): + """First boot before registration: no server found, no error.""" + mock_get.side_effect = ServerNotFoundError("not found") + db = AsyncMock() + await 
_handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_TARGET) + db.commit.assert_not_awaited() + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_updates_version_no_maintenance(mock_get): + """Server exists, no pending maintenance: version updated, commit called.""" + server = _make_server() + mock_get.return_value = server + db = AsyncMock() + + await _handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_TARGET) + + assert server.version == TEST_VERSION_TARGET + assert server.maintenance_pending_window_id is None + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_meets_target_clears_maintenance(mock_get): + """Boot version >= target: version updated, maintenance_pending_window_id cleared.""" + server = _make_server(maintenance_pending_window_id=TEST_WINDOW_ID) + mock_get.return_value = server + window = _make_window() + + db = AsyncMock() + db.get = AsyncMock(return_value=window) + + await _handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_TARGET) + + assert server.version == TEST_VERSION_TARGET + assert server.maintenance_pending_window_id is None + db.get.assert_awaited_once_with(TeeUpgradeWindow, TEST_WINDOW_ID) + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_above_target_clears_maintenance(mock_get): + """Boot version > target: still clears maintenance.""" + server = _make_server(maintenance_pending_window_id=TEST_WINDOW_ID) + mock_get.return_value = server + window = _make_window() + + db = AsyncMock() + db.get = AsyncMock(return_value=window) + + await _handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_ABOVE) + + assert server.version == TEST_VERSION_ABOVE + assert server.maintenance_pending_window_id is None + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_below_target_keeps_maintenance(mock_get): + """Boot version < target: version updated, but maintenance slot stays.""" + server = _make_server(maintenance_pending_window_id=TEST_WINDOW_ID) + mock_get.return_value = server + window = _make_window() + + db = AsyncMock() + db.get = AsyncMock(return_value=window) + + await _handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_OLD) + + assert server.version == TEST_VERSION_OLD + assert server.maintenance_pending_window_id == TEST_WINDOW_ID + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("api.server.service.get_server_by_name", new_callable=AsyncMock) +async def test_boot_stale_window_cleared(mock_get): + """maintenance_pending_window_id points to missing window: cleared.""" + server = _make_server(maintenance_pending_window_id="missing-window-id") + mock_get.return_value = server + + db = AsyncMock() + db.get = AsyncMock(return_value=None) + + await _handle_boot_version_update(db, TEST_HOTKEY, TEST_VM_NAME, TEST_VERSION_TARGET) + + assert server.version == TEST_VERSION_TARGET + assert server.maintenance_pending_window_id is None + db.commit.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# register_server sets version from verify_server return +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio 
+@patch("api.server.service._track_nodes", new_callable=AsyncMock) +@patch("api.server.service.verify_server", new_callable=AsyncMock) +@patch("api.server.service._track_server", new_callable=AsyncMock) +@patch( + "api.server.service.SUPPORTED_GPUS", + {"gpu-a100": {"processors": 1, "max_threads_per_processor": 1}}, +) +async def test_register_server_sets_version(mock_track_server, mock_verify, mock_track_nodes): + from api.server.service import register_server + from api.server.schemas import ServerArgs + from api.node.schemas import NodeArgs + + server = _make_server() + mock_track_server.return_value = server + mock_verify.return_value = TEST_VERSION_TARGET + + db = AsyncMock() + + gpu = MagicMock(spec=NodeArgs) + gpu.gpu_identifier = "gpu-a100" + args = MagicMock(spec=ServerArgs) + args.id = TEST_SERVER_ID + args.name = TEST_VM_NAME + args.host = "10.0.0.1" + args.gpus = [gpu] + + await register_server(db, args, TEST_HOTKEY) + + assert server.version == TEST_VERSION_TARGET + db.commit.assert_awaited() + + +@pytest.mark.asyncio +@patch("api.server.service._track_nodes", new_callable=AsyncMock) +@patch("api.server.service.verify_server", new_callable=AsyncMock) +@patch("api.server.service._track_server", new_callable=AsyncMock) +@patch( + "api.server.service.SUPPORTED_GPUS", + {"gpu-a100": {"processors": 1, "max_threads_per_processor": 1}}, +) +async def test_register_server_version_none_when_verify_returns_none( + mock_track_server, mock_verify, mock_track_nodes +): + from api.server.service import register_server + from api.server.schemas import ServerArgs + from api.node.schemas import NodeArgs + + server = _make_server() + mock_track_server.return_value = server + mock_verify.return_value = None + + db = AsyncMock() + + gpu = MagicMock(spec=NodeArgs) + gpu.gpu_identifier = "gpu-a100" + args = MagicMock(spec=ServerArgs) + args.id = TEST_SERVER_ID + args.name = TEST_VM_NAME + args.host = "10.0.0.1" + args.gpus = [gpu] + + await register_server(db, args, TEST_HOTKEY) + + assert server.version is None diff --git a/tests/unit/test_server_maintenance_models.py b/tests/unit/test_server_maintenance_models.py new file mode 100644 index 00000000..998508a5 --- /dev/null +++ b/tests/unit/test_server_maintenance_models.py @@ -0,0 +1,93 @@ +""" +Unit tests for TEE maintenance ORM models and config settings (Phase 1). 
+""" + +from datetime import datetime, timezone + +from api.server.schemas import TeeUpgradeWindow, Server + +TEST_SERVER_ID = "node-abc-123" +TEST_IP = "10.0.0.1" +TEST_HOTKEY = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY" +TEST_VM_NAME = "my-tee-vm" +TEST_WINDOW_ID = "window-abc-123" +TEST_VERSION = "0.3.1" +TEST_WINDOW_START = datetime(2026, 4, 1, tzinfo=timezone.utc) +TEST_WINDOW_END = datetime(2026, 4, 7, tzinfo=timezone.utc) + + +def _make_server(**overrides): + defaults = dict( + server_id=TEST_SERVER_ID, + ip=TEST_IP, + miner_hotkey=TEST_HOTKEY, + name=TEST_VM_NAME, + netuid=64, + is_tee=True, + ) + defaults.update(overrides) + return Server(**defaults) + + +def test_tee_upgrade_window_columns(): + window = TeeUpgradeWindow( + id=TEST_WINDOW_ID, + upgrade_window_start=TEST_WINDOW_START, + upgrade_window_end=TEST_WINDOW_END, + target_measurement_version=TEST_VERSION, + ) + assert window.id == TEST_WINDOW_ID + assert window.upgrade_window_start == TEST_WINDOW_START + assert window.upgrade_window_end == TEST_WINDOW_END + assert window.target_measurement_version == TEST_VERSION + + +def test_server_maintenance_pending_window_id_defaults_to_none(): + assert _make_server().maintenance_pending_window_id is None + + +def test_server_version_defaults_to_none(): + assert _make_server().version is None + + +def test_server_maintenance_pending_window_id_can_be_set(): + server = _make_server(maintenance_pending_window_id=TEST_WINDOW_ID) + assert server.maintenance_pending_window_id == TEST_WINDOW_ID + + +def test_server_version_can_be_set(): + server = _make_server(version=TEST_VERSION) + assert server.version == TEST_VERSION + + +def test_server_existing_columns_unaffected(): + server = _make_server() + assert server.server_id == TEST_SERVER_ID + assert server.ip == TEST_IP + assert server.miner_hotkey == TEST_HOTKEY + assert server.name == TEST_VM_NAME + assert server.is_tee is True + + +def test_server_has_pending_upgrade_window_relationship(): + assert hasattr(Server, "pending_upgrade_window") + + +def test_window_has_pending_servers_relationship(): + assert hasattr(TeeUpgradeWindow, "pending_servers") + + +def test_window_max_concurrent_per_miner_has_column_default(): + col = TeeUpgradeWindow.__table__.c.max_concurrent_per_miner + assert col.server_default.arg == "1" + assert col.nullable is False + + +def test_window_max_concurrent_per_miner_can_be_set(): + w = TeeUpgradeWindow( + upgrade_window_start=datetime(2026, 4, 1, tzinfo=timezone.utc), + upgrade_window_end=datetime(2026, 4, 7, tzinfo=timezone.utc), + target_measurement_version="0.3.0", + max_concurrent_per_miner=3, + ) + assert w.max_concurrent_per_miner == 3 diff --git a/tests/unit/test_server_maintenance_routes.py b/tests/unit/test_server_maintenance_routes.py new file mode 100644 index 00000000..8917c75a --- /dev/null +++ b/tests/unit/test_server_maintenance_routes.py @@ -0,0 +1,290 @@ +""" +Unit tests for TEE maintenance route handler functions (Phase 3). + +These tests call the route handler functions directly with mocked dependencies, +matching the test style used elsewhere in this project. 
+""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import HTTPException + +from api.server.schemas import ( + Server, + TeeUpgradeWindow, + MaintenanceReason, + PreflightResult, + ConfirmMaintenanceResult, + UpgradeWindowInfo, + MaintenancePolicyResponse, +) +from api.server.router import ( + get_maintenance_policy, + get_maintenance_preflight, + put_confirm_maintenance, + get_server_details, +) + +TEST_SERVER_NAME_OR_ID = "my-tee-vm" + +TEST_SERVER_ID = "server-abc-123" +TEST_HOTKEY = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY" +TEST_VM_NAME = "my-tee-vm" +TEST_WINDOW_ID = "window-abc-123" +TEST_VERSION_OLD = "0.2.0" +TEST_VERSION_TARGET = "0.3.1" +TEST_WINDOW_START = datetime(2026, 4, 1, tzinfo=timezone.utc) +TEST_WINDOW_END = datetime(2026, 4, 7, tzinfo=timezone.utc) +DEFAULT_CONCURRENCY_LIMIT = 1 + + +def _make_window(**overrides): + defaults = dict( + id=TEST_WINDOW_ID, + upgrade_window_start=TEST_WINDOW_START, + upgrade_window_end=TEST_WINDOW_END, + target_measurement_version=TEST_VERSION_TARGET, + max_concurrent_per_miner=DEFAULT_CONCURRENCY_LIMIT, + created_at=datetime(2026, 4, 1, tzinfo=timezone.utc), + ) + defaults.update(overrides) + return TeeUpgradeWindow(**defaults) + + +def _make_server(**overrides): + defaults = dict( + server_id=TEST_SERVER_ID, + ip="10.0.0.1", + miner_hotkey=TEST_HOTKEY, + name=TEST_VM_NAME, + netuid=64, + is_tee=True, + version=None, + maintenance_pending_window_id=None, + ) + defaults.update(overrides) + return Server(**defaults) + + +def _mock_scalars_result(rows): + scalars_mock = MagicMock() + scalars_mock.all.return_value = rows + result_mock = MagicMock() + result_mock.scalars.return_value = scalars_mock + return result_mock + + +# --------------------------------------------------------------------------- +# GET /servers/maintenance/policy +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.router._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.router.get_active_upgrade_window", new_callable=AsyncMock) +async def test_policy_returns_active_window(mock_get_window, _mock_slots): + window = _make_window() + mock_get_window.return_value = window + + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([]) + + result = await get_maintenance_policy(db=db, hotkey=TEST_HOTKEY, _=None) + assert isinstance(result, MaintenancePolicyResponse) + assert result.active_window is not None + assert result.active_window.id == TEST_WINDOW_ID + assert result.active_window.max_concurrent_per_miner == DEFAULT_CONCURRENCY_LIMIT + + +@pytest.mark.asyncio +@patch("api.server.router.get_active_upgrade_window", new_callable=AsyncMock, return_value=None) +async def test_policy_returns_null_when_no_window(mock_get_window): + db = AsyncMock() + + result = await get_maintenance_policy(db=db, hotkey=TEST_HOTKEY, _=None) + assert result.active_window is None + assert result.current_slots == 0 + assert result.pending_servers == [] + + +@pytest.mark.asyncio +async def test_policy_rejects_missing_hotkey(): + db = AsyncMock() + with pytest.raises(HTTPException) as exc_info: + await get_maintenance_policy(db=db, hotkey=None, _=None) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +@patch("api.server.router._count_active_maintenance_slots", new_callable=AsyncMock, return_value=1) +@patch("api.server.router.get_active_upgrade_window", 
new_callable=AsyncMock) +async def test_policy_includes_pending_servers(mock_get_window, _mock_slots): + window = _make_window() + mock_get_window.return_value = window + + pending_server = _make_server( + maintenance_pending_window_id=TEST_WINDOW_ID, + version=TEST_VERSION_OLD, + ) + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([pending_server]) + + result = await get_maintenance_policy(db=db, hotkey=TEST_HOTKEY, _=None) + assert result.current_slots == 1 + assert len(result.pending_servers) == 1 + assert result.pending_servers[0].server_id == TEST_SERVER_ID + assert result.pending_servers[0].version == TEST_VERSION_OLD + assert result.pending_servers[0].target_version == TEST_VERSION_TARGET + + +# --------------------------------------------------------------------------- +# GET /servers/{server_id}/maintenance/preflight +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.router.preflight_maintenance", new_callable=AsyncMock) +@patch("api.server.router.get_server_by_name_or_id", new_callable=AsyncMock) +async def test_preflight_route_returns_result(mock_lookup, mock_preflight): + server = _make_server() + mock_lookup.return_value = server + expected = PreflightResult(eligible=True, current_slots=0, limit=1) + mock_preflight.return_value = expected + db = AsyncMock() + + result = await get_maintenance_preflight( + server_name_or_id=TEST_SERVER_NAME_OR_ID, db=db, hotkey=TEST_HOTKEY, _=None + ) + assert result is expected + mock_lookup.assert_awaited_once_with(db, TEST_HOTKEY, TEST_SERVER_NAME_OR_ID) + mock_preflight.assert_awaited_once_with(db, server, TEST_HOTKEY) + + +@pytest.mark.asyncio +@patch("api.server.router.preflight_maintenance", new_callable=AsyncMock) +@patch("api.server.router.get_server_by_name_or_id", new_callable=AsyncMock) +async def test_preflight_route_returns_ineligible(mock_lookup, mock_preflight): + server = _make_server() + mock_lookup.return_value = server + expected = PreflightResult( + eligible=False, + denial_reasons=[MaintenanceReason(reason="concurrency_cap", current_slots=1, limit=1)], + current_slots=1, + limit=1, + ) + mock_preflight.return_value = expected + db = AsyncMock() + + result = await get_maintenance_preflight( + server_name_or_id=TEST_SERVER_NAME_OR_ID, db=db, hotkey=TEST_HOTKEY, _=None + ) + assert result.eligible is False + assert len(result.denial_reasons) == 1 + + +# --------------------------------------------------------------------------- +# PUT /servers/{server_id}/maintenance +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.router.confirm_maintenance", new_callable=AsyncMock) +@patch("api.server.router.get_server_by_name_or_id", new_callable=AsyncMock) +async def test_confirm_route_returns_result(mock_lookup, mock_confirm): + server = _make_server() + mock_lookup.return_value = server + expected = ConfirmMaintenanceResult( + server_id=TEST_SERVER_ID, + purged_instance_ids=["inst-1"], + window=UpgradeWindowInfo( + id=TEST_WINDOW_ID, + target_measurement_version=TEST_VERSION_TARGET, + upgrade_window_start=str(TEST_WINDOW_START), + upgrade_window_end=str(TEST_WINDOW_END), + ), + ) + mock_confirm.return_value = expected + db = AsyncMock() + + result = await put_confirm_maintenance( + server_name_or_id=TEST_SERVER_NAME_OR_ID, db=db, hotkey=TEST_HOTKEY, _=None + ) + assert result is expected + mock_lookup.assert_awaited_once_with(db, TEST_HOTKEY, TEST_SERVER_NAME_OR_ID) + 
mock_confirm.assert_awaited_once_with(db, server, TEST_HOTKEY) + + +@pytest.mark.asyncio +@patch("api.server.router.confirm_maintenance", new_callable=AsyncMock) +@patch("api.server.router.get_server_by_name_or_id", new_callable=AsyncMock) +async def test_confirm_route_propagates_409(mock_lookup, mock_confirm): + server = _make_server() + mock_lookup.return_value = server + mock_confirm.side_effect = HTTPException(status_code=409, detail="conflict") + db = AsyncMock() + + with pytest.raises(HTTPException) as exc_info: + await put_confirm_maintenance( + server_name_or_id=TEST_SERVER_NAME_OR_ID, db=db, hotkey=TEST_HOTKEY, _=None + ) + assert exc_info.value.status_code == 409 + + +@pytest.mark.asyncio +@patch("api.server.router.confirm_maintenance", new_callable=AsyncMock) +@patch("api.server.router.get_server_by_name_or_id", new_callable=AsyncMock) +async def test_confirm_route_propagates_403(mock_lookup, mock_confirm): + server = _make_server() + mock_lookup.return_value = server + mock_confirm.side_effect = HTTPException(status_code=403, detail="forbidden") + db = AsyncMock() + + with pytest.raises(HTTPException) as exc_info: + await put_confirm_maintenance( + server_name_or_id=TEST_SERVER_NAME_OR_ID, db=db, hotkey=TEST_HOTKEY, _=None + ) + assert exc_info.value.status_code == 403 + + +# --------------------------------------------------------------------------- +# GET /servers/{server_id} — includes version + maintenance info +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.router.check_server_ownership", new_callable=AsyncMock) +async def test_get_server_details_includes_version(mock_ownership): + server = _make_server(version=TEST_VERSION_OLD) + server.created_at = TEST_WINDOW_START + server.updated_at = None + mock_ownership.return_value = server + db = AsyncMock() + + result = await get_server_details(server_id=TEST_SERVER_ID, db=db, hotkey=TEST_HOTKEY, _=None) + assert result["version"] == TEST_VERSION_OLD + assert result["maintenance_pending_window_id"] is None + assert "target_version" not in result + + +@pytest.mark.asyncio +@patch("api.server.router.check_server_ownership", new_callable=AsyncMock) +async def test_get_server_details_includes_target_when_pending(mock_ownership): + server = _make_server( + version=TEST_VERSION_OLD, + maintenance_pending_window_id=TEST_WINDOW_ID, + ) + server.created_at = TEST_WINDOW_START + server.updated_at = None + mock_ownership.return_value = server + + window = _make_window() + db = AsyncMock() + db.get = AsyncMock(return_value=window) + + result = await get_server_details(server_id=TEST_SERVER_ID, db=db, hotkey=TEST_HOTKEY, _=None) + assert result["version"] == TEST_VERSION_OLD + assert result["maintenance_pending_window_id"] == TEST_WINDOW_ID + assert result["target_version"] == TEST_VERSION_TARGET + db.get.assert_awaited_once_with(TeeUpgradeWindow, TEST_WINDOW_ID) diff --git a/tests/unit/test_server_maintenance_service.py b/tests/unit/test_server_maintenance_service.py new file mode 100644 index 00000000..025f34d5 --- /dev/null +++ b/tests/unit/test_server_maintenance_service.py @@ -0,0 +1,459 @@ +""" +Unit tests for TEE maintenance service functions (Phase 2). 
+""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import HTTPException + +from api.server.schemas import ( + Server, + TeeUpgradeWindow, + MaintenanceReason, + SoleSurvivorBlock, + PreflightResult, +) +from api.server.service import ( + get_active_upgrade_window, + preflight_maintenance, + confirm_maintenance, + _get_instances_on_server, + _find_sole_survivor_chutes, + _count_active_maintenance_slots, +) + +TEST_SERVER_ID = "server-abc-123" +TEST_HOTKEY = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY" +TEST_VM_NAME = "my-tee-vm" +TEST_WINDOW_ID = "window-abc-123" +TEST_WINDOW_ID_2 = "window-def-456" +TEST_VERSION_OLD = "0.2.0" +TEST_VERSION_TARGET = "0.3.1" +TEST_VERSION_ABOVE = "0.4.0" +TEST_WINDOW_START = datetime(2026, 4, 1, tzinfo=timezone.utc) +TEST_WINDOW_END = datetime(2026, 4, 7, tzinfo=timezone.utc) +TEST_CHUTE_ID = "chute-abc-123" +TEST_INSTANCE_ID = "inst-abc-123" +DEFAULT_CONCURRENCY_LIMIT = 1 + + +def _make_window(**overrides): + defaults = dict( + id=TEST_WINDOW_ID, + upgrade_window_start=TEST_WINDOW_START, + upgrade_window_end=TEST_WINDOW_END, + target_measurement_version=TEST_VERSION_TARGET, + max_concurrent_per_miner=DEFAULT_CONCURRENCY_LIMIT, + created_at=datetime(2026, 4, 1, tzinfo=timezone.utc), + ) + defaults.update(overrides) + return TeeUpgradeWindow(**defaults) + + +def _make_server(**overrides): + defaults = dict( + server_id=TEST_SERVER_ID, + ip="10.0.0.1", + miner_hotkey=TEST_HOTKEY, + name=TEST_VM_NAME, + netuid=64, + is_tee=True, + version=None, + maintenance_pending_window_id=None, + ) + defaults.update(overrides) + return Server(**defaults) + + +def _make_instance(**overrides): + from api.instance.schemas import Instance + + defaults = dict( + instance_id=TEST_INSTANCE_ID, + chute_id=TEST_CHUTE_ID, + miner_hotkey=TEST_HOTKEY, + active=True, + ) + defaults.update(overrides) + return Instance(**defaults) + + +def _make_preflight(eligible=True, denial_reasons=None, blocking=None, current_slots=0, limit=1): + return PreflightResult( + eligible=eligible, + denial_reasons=denial_reasons or [], + blocking_chute_ids=blocking or [], + current_slots=current_slots, + limit=limit, + ) + + +def _mock_scalars_result(rows): + """Build a mock result whose .scalars().all() returns the given rows.""" + scalars_mock = MagicMock() + scalars_mock.all.return_value = rows + result_mock = MagicMock() + result_mock.scalars.return_value = scalars_mock + return result_mock + + +def _mock_scalar_result(value): + """Build a mock result whose .scalar() returns the given value.""" + result_mock = MagicMock() + result_mock.scalar.return_value = value + return result_mock + + +def _mock_scalar_one_or_none_result(value): + """Build a mock result whose .scalar_one_or_none() returns the given value.""" + result_mock = MagicMock() + result_mock.scalar_one_or_none.return_value = value + return result_mock + + +# --------------------------------------------------------------------------- +# get_active_upgrade_window +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_active_upgrade_window_returns_none_when_no_rows(): + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([]) + result = await get_active_upgrade_window(db) + assert result is None + + +@pytest.mark.asyncio +async def test_get_active_upgrade_window_returns_single_active_row(): + window = _make_window() + db = AsyncMock() + db.execute.return_value = 
_mock_scalars_result([window]) + result = await get_active_upgrade_window(db) + assert result is window + + +@pytest.mark.asyncio +async def test_get_active_upgrade_window_picks_most_recent_when_overlapping(): + newer = _make_window(id=TEST_WINDOW_ID, created_at=datetime(2026, 4, 2, tzinfo=timezone.utc)) + older = _make_window(id=TEST_WINDOW_ID_2, created_at=datetime(2026, 4, 1, tzinfo=timezone.utc)) + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([newer, older]) + result = await get_active_upgrade_window(db) + assert result is newer + + +# --------------------------------------------------------------------------- +# _get_instances_on_server +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_instances_on_server_returns_instances(): + inst = _make_instance() + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([inst]) + result = await _get_instances_on_server(db, TEST_SERVER_ID) + assert result == [inst] + + +@pytest.mark.asyncio +async def test_get_instances_on_server_returns_empty_list(): + db = AsyncMock() + db.execute.return_value = _mock_scalars_result([]) + result = await _get_instances_on_server(db, TEST_SERVER_ID) + assert result == [] + + +# --------------------------------------------------------------------------- +# _find_sole_survivor_chutes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_find_sole_survivor_chutes_no_blocking(): + inst = _make_instance() + db = AsyncMock() + db.execute.return_value = _mock_scalar_result(1) + result = await _find_sole_survivor_chutes(db, [inst]) + assert result == [] + + +@pytest.mark.asyncio +async def test_find_sole_survivor_chutes_blocks_sole_instance(): + inst = _make_instance() + db = AsyncMock() + db.execute.return_value = _mock_scalar_result(0) + result = await _find_sole_survivor_chutes(db, [inst]) + assert len(result) == 1 + assert isinstance(result[0], SoleSurvivorBlock) + assert result[0].chute_id == TEST_CHUTE_ID + assert result[0].instance_id == TEST_INSTANCE_ID + + +@pytest.mark.asyncio +async def test_find_sole_survivor_chutes_deduplicates_by_chute(): + inst_a = _make_instance(instance_id="inst-1", chute_id=TEST_CHUTE_ID) + inst_b = _make_instance(instance_id="inst-2", chute_id=TEST_CHUTE_ID) + db = AsyncMock() + db.execute.return_value = _mock_scalar_result(0) + result = await _find_sole_survivor_chutes(db, [inst_a, inst_b]) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# _count_active_maintenance_slots +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_count_active_maintenance_slots(): + window = _make_window() + db = AsyncMock() + db.execute.return_value = _mock_scalar_result(2) + result = await _count_active_maintenance_slots(db, TEST_HOTKEY, window) + assert result == 2 + + +@pytest.mark.asyncio +async def test_count_active_maintenance_slots_returns_zero_when_null(): + window = _make_window() + db = AsyncMock() + db.execute.return_value = _mock_scalar_result(None) + result = await _count_active_maintenance_slots(db, TEST_HOTKEY, window) + assert result == 0 + + +# --------------------------------------------------------------------------- +# preflight_maintenance +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio 
+@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_not_tee(_mock_window): + server = _make_server(is_tee=False) + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert isinstance(result, PreflightResult) + assert result.eligible is False + assert any(r.reason == "not_tee" for r in result.denial_reasons) + + +@pytest.mark.asyncio +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock, return_value=None) +async def test_preflight_no_active_window(_mock_window): + server = _make_server() + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is False + assert any(r.reason == "no_active_window" for r in result.denial_reasons) + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_already_at_target( + mock_window, _mock_slots, _mock_instances, _mock_survivors +): + window = _make_window() + mock_window.return_value = window + server = _make_server(version=TEST_VERSION_ABOVE) + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is False + assert any(r.reason == "already_at_target" for r in result.denial_reasons) + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_maintenance_pending( + mock_window, _mock_slots, _mock_instances, _mock_survivors +): + window = _make_window() + mock_window.return_value = window + server = _make_server(maintenance_pending_window_id=TEST_WINDOW_ID) + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is False + assert any(r.reason == "maintenance_pending" for r in result.denial_reasons) + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_stale_window_gets_cleared( + mock_window, _mock_slots, _mock_instances, _mock_survivors +): + window = _make_window(id="new-window-id") + mock_window.return_value = window + server = _make_server(maintenance_pending_window_id="old-stale-window-id") + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is True + assert server.maintenance_pending_window_id is None + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) 
+@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_concurrency_cap(mock_window, mock_slots, _mock_instances, _mock_survivors): + window = _make_window() + mock_window.return_value = window + mock_slots.return_value = DEFAULT_CONCURRENCY_LIMIT + server = _make_server() + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is False + assert any(r.reason == "concurrency_cap" for r in result.denial_reasons) + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_sole_survivor_blocks( + mock_window, _mock_slots, mock_instances, mock_survivors +): + window = _make_window() + mock_window.return_value = window + inst = _make_instance() + mock_instances.return_value = [inst] + blocking = [SoleSurvivorBlock(chute_id=TEST_CHUTE_ID, instance_id=TEST_INSTANCE_ID)] + mock_survivors.return_value = blocking + server = _make_server() + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is False + assert any(r.reason == "sole_survivor" for r in result.denial_reasons) + assert result.blocking_chute_ids == blocking + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_eligible_version_none( + mock_window, _mock_slots, _mock_instances, _mock_survivors +): + """Server with version=None should not be denied as 'already at target'.""" + + window = _make_window() + mock_window.return_value = window + server = _make_server(version=None) + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is True + assert result.denial_reasons == [] + + +@pytest.mark.asyncio +@patch("api.server.service._find_sole_survivor_chutes", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock, return_value=[]) +@patch("api.server.service._count_active_maintenance_slots", new_callable=AsyncMock, return_value=0) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +async def test_preflight_eligible_old_version( + mock_window, _mock_slots, _mock_instances, _mock_survivors +): + window = _make_window() + mock_window.return_value = window + server = _make_server(version=TEST_VERSION_OLD) + db = AsyncMock() + result = await preflight_maintenance(db, server, TEST_HOTKEY) + assert result.eligible is True + assert result.denial_reasons == [] + + +# --------------------------------------------------------------------------- +# confirm_maintenance +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) 
+@patch("api.server.service.preflight_maintenance", new_callable=AsyncMock) +async def test_confirm_raises_on_ineligible(mock_preflight, mock_window): + mock_preflight.return_value = _make_preflight( + eligible=False, + denial_reasons=[MaintenanceReason(reason="concurrency_cap", current_slots=1, limit=1)], + current_slots=1, + ) + server = _make_server() + db = AsyncMock() + with pytest.raises(HTTPException) as exc_info: + await confirm_maintenance(db, server, TEST_HOTKEY) + assert exc_info.value.status_code == 409 + + +@pytest.mark.asyncio +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +@patch("api.server.service.preflight_maintenance", new_callable=AsyncMock) +async def test_confirm_raises_403_for_non_tee(mock_preflight, mock_window): + mock_preflight.return_value = _make_preflight( + eligible=False, + denial_reasons=[MaintenanceReason(reason="not_tee")], + ) + server = _make_server(is_tee=False) + db = AsyncMock() + with pytest.raises(HTTPException) as exc_info: + await confirm_maintenance(db, server, TEST_HOTKEY) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +@patch("api.server.service.purge_and_notify", new_callable=AsyncMock) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +@patch("api.server.service.preflight_maintenance", new_callable=AsyncMock) +async def test_confirm_success(mock_preflight, mock_window, mock_instances, mock_purge): + window = _make_window() + mock_preflight.return_value = _make_preflight(eligible=True) + mock_window.return_value = window + inst = _make_instance() + mock_instances.return_value = [inst] + server = _make_server() + db = AsyncMock() + result = await confirm_maintenance(db, server, TEST_HOTKEY) + + assert server.maintenance_pending_window_id == TEST_WINDOW_ID + assert result.server_id == TEST_SERVER_ID + assert inst.instance_id in result.purged_instance_ids + assert result.window.id == TEST_WINDOW_ID + mock_purge.assert_awaited_once_with( + inst, + reason="tee maintenance", + valid_termination=True, + ) + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("api.server.service.purge_and_notify", new_callable=AsyncMock) +@patch("api.server.service._get_instances_on_server", new_callable=AsyncMock) +@patch("api.server.service.get_active_upgrade_window", new_callable=AsyncMock) +@patch("api.server.service.preflight_maintenance", new_callable=AsyncMock) +async def test_confirm_purge_failure_does_not_crash( + mock_preflight, mock_window, mock_instances, mock_purge +): + window = _make_window() + mock_preflight.return_value = _make_preflight(eligible=True) + mock_window.return_value = window + mock_instances.return_value = [_make_instance()] + mock_purge.side_effect = RuntimeError("purge failed") + server = _make_server() + db = AsyncMock() + result = await confirm_maintenance(db, server, TEST_HOTKEY) + assert result.purged_instance_ids == [] diff --git a/tests/unit/test_server_service.py b/tests/unit/test_server_service.py index 75c02d04..2545fd7b 100644 --- a/tests/unit/test_server_service.py +++ b/tests/unit/test_server_service.py @@ -18,7 +18,6 @@ process_boot_attestation, process_runtime_attestation, register_server, - verify_server, check_server_ownership, get_server_by_name, update_server_name, @@ -107,7 +106,7 @@ def mock_util_functions(): patch("api.server.service.generate_nonce", return_value=TEST_GPU_NONCE) as mock_gen, 
patch("api.server.service.get_nonce_expiry_seconds", return_value=600) as mock_exp, patch( - "api.server.service.extract_report_data", + "api.server.util.extract_report_data", return_value=(TEST_GPU_NONCE, TEST_CERT_HASH), ) as mock_extract, patch("api.server.service.verify_gpu_evidence") as mock_verify_gpu, @@ -238,6 +237,7 @@ def _sample_node_args(): def server_args(): """Sample ServerArgs for testing.""" return ServerArgs( + id="test-server-123", host=TEST_SERVER_IP, name="test-vm-name", gpus=[_sample_node_args()], @@ -279,7 +279,7 @@ def sample_server_attestation(): def mock_verify_quote_signature(sample_verification_result): """Mock verify_quote_signature function.""" with patch( - "api.server.service.verify_quote_signature", return_value=sample_verification_result + "api.server.util.verify_quote_signature", return_value=sample_verification_result ) as mock: yield mock @@ -287,7 +287,7 @@ def mock_verify_quote_signature(sample_verification_result): @pytest.fixture def mock_verify_measurements(): """Mock verify_measurements function.""" - with patch("api.server.service.verify_measurements", return_value=True) as mock: + with patch("api.server.util.verify_measurements", return_value=True) as mock: yield mock @@ -453,9 +453,15 @@ def mock_refresh(obj): mock_db_session.refresh.side_effect = mock_refresh - with patch( - "api.server.service.generate_and_store_boot_token", - return_value="test-boot-token", + with ( + patch( + "api.server.service.generate_and_store_boot_token", + return_value="test-boot-token", + ), + patch( + "api.server.service._handle_boot_version_update", + new_callable=AsyncMock, + ), ): result = await process_boot_attestation( mock_db_session, @@ -594,80 +600,37 @@ async def test_process_runtime_attestation_server_not_found( @pytest.mark.asyncio -async def test_register_server_success( - mock_db_session, server_args, sample_server, sample_runtime_quote -): +async def test_register_server_success(mock_db_session, server_args, sample_server): """Test successful server registration.""" miner_hotkey = "5FTestHotkey123" - def mock_refresh(obj): - obj.server_id = "test-server-123" - - mock_db_session.refresh.side_effect = mock_refresh - - with patch("api.server.service.TeeServerClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.get_evidence.return_value = ( - sample_runtime_quote, - {}, - TEST_CERT_HASH, - ) - mock_client_class.return_value = mock_client - with patch("api.server.service.verify_quote") as mock_verify_quote: - mock_verify_quote.return_value = TdxVerificationResult( - mrtd="a" * 96, - rtmr0="d" * 96, - rtmr1="e" * 96, - rtmr2="f" * 96, - rtmr3="0" * 96, - user_data="test", - parsed_at=datetime.now(timezone.utc), - status="UpToDate", - advisory_ids=[], - td_attributes="0000001000000000", - ) - await verify_server(mock_db_session, sample_server, miner_hotkey, server_args.gpus) + with patch("api.server.service._track_server", return_value=sample_server): + with patch("api.server.service._track_nodes", new_callable=AsyncMock): + with patch( + "api.server.service.verify_server", new_callable=AsyncMock, return_value="1.0.0" + ): + await register_server(mock_db_session, server_args, miner_hotkey) - # Verify database operations - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_db_session.refresh.assert_called_once() + assert sample_server.version == "1.0.0" + mock_db_session.commit.assert_called() @pytest.mark.asyncio -async def test_register_server_integrity_error( - mock_db_session, server_args, 
sample_server, sample_runtime_quote -): - """Test server registration with database integrity error.""" +async def test_register_server_integrity_error(mock_db_session, server_args, sample_server): + """Test server registration handles IntegrityError from _track_nodes.""" miner_hotkey = "5FTestHotkey123" - mock_db_session.commit.side_effect = IntegrityError("Duplicate key", None, None) - with patch("api.server.service._track_server", return_value=sample_server): - with patch("api.server.service._track_nodes", new_callable=AsyncMock): - with patch("api.server.service.TeeServerClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.get_evidence.return_value = ( - sample_runtime_quote, - {}, - TEST_CERT_HASH, - ) - mock_client_class.return_value = mock_client - with patch("api.server.service.verify_quote") as mock_verify_quote: - mock_verify_quote.return_value = TdxVerificationResult( - mrtd="a" * 96, - rtmr0="d" * 96, - rtmr1="e" * 96, - rtmr2="f" * 96, - rtmr3="0" * 96, - user_data="test", - parsed_at=datetime.now(timezone.utc), - status="UpToDate", - advisory_ids=[], - td_attributes="0000001000000000", - ) - with pytest.raises(ServerRegistrationError): - await register_server(mock_db_session, server_args, miner_hotkey) + with patch( + "api.server.service._track_nodes", + new_callable=AsyncMock, + side_effect=IntegrityError("Duplicate key", None, None), + ): + with patch( + "api.server.service.verify_server", new_callable=AsyncMock, return_value="1.0.0" + ): + with pytest.raises(ServerRegistrationError): + await register_server(mock_db_session, server_args, miner_hotkey) mock_db_session.rollback.assert_called_once() @@ -927,39 +890,21 @@ async def test_validate_nonce_invalid_format(mock_settings): @pytest.mark.asyncio -async def test_register_server_general_exception( - mock_db_session, server_args, sample_server, sample_runtime_quote -): - """Test server verification with general exception on commit.""" +async def test_register_server_general_exception(mock_db_session, server_args, sample_server): + """Test server registration handles unexpected exceptions.""" miner_hotkey = "5FTestHotkey123" - mock_db_session.commit.side_effect = Exception("Database error") - with patch("api.server.service._track_server", return_value=sample_server): - with patch("api.server.service._track_nodes", new_callable=AsyncMock): - with patch("api.server.service.TeeServerClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.get_evidence.return_value = ( - sample_runtime_quote, - {}, - TEST_CERT_HASH, - ) - mock_client_class.return_value = mock_client - with patch("api.server.service.verify_quote") as mock_verify_quote: - mock_verify_quote.return_value = TdxVerificationResult( - mrtd="a" * 96, - rtmr0="d" * 96, - rtmr1="e" * 96, - rtmr2="f" * 96, - rtmr3="0" * 96, - user_data="test", - parsed_at=datetime.now(timezone.utc), - status="UpToDate", - advisory_ids=[], - td_attributes="0000001000000000", - ) - with pytest.raises(ServerRegistrationError): - await register_server(mock_db_session, server_args, miner_hotkey) + with patch( + "api.server.service._track_nodes", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ): + with patch( + "api.server.service.verify_server", new_callable=AsyncMock, return_value="1.0.0" + ): + with pytest.raises(ServerRegistrationError): + await register_server(mock_db_session, server_args, miner_hotkey) mock_db_session.rollback.assert_called_once() @@ -1022,7 +967,7 @@ async def test_full_boot_flow_end_to_end(mock_db_session, 
mock_settings, mock_ve ) with patch("api.server.service.BootTdxQuote.from_base64", return_value=boot_quote): - with patch("api.server.service.verify_quote_signature") as mock_verify: + with patch("api.server.util.verify_quote_signature") as mock_verify: mock_verify.return_value = TdxVerificationResult( mrtd="a" * 96, rtmr0="b" * 96, @@ -1042,9 +987,15 @@ def mock_refresh(obj): mock_db_session.refresh.side_effect = mock_refresh - with patch( - "api.server.service.generate_and_store_boot_token", - return_value="test-boot-token", + with ( + patch( + "api.server.service.generate_and_store_boot_token", + return_value="test-boot-token", + ), + patch( + "api.server.service._handle_boot_version_update", + new_callable=AsyncMock, + ), ): result = await process_boot_attestation( mock_db_session, @@ -1095,7 +1046,7 @@ async def test_full_runtime_flow_end_to_end( with patch("api.server.service.check_server_ownership", return_value=sample_server): with patch("api.server.service.RuntimeTdxQuote.from_base64", return_value=runtime_quote): - with patch("api.server.service.verify_quote_signature") as mock_verify: + with patch("api.server.util.verify_quote_signature") as mock_verify: mock_verify.return_value = TdxVerificationResult( mrtd="a" * 96, rtmr0="d" * 96, @@ -1130,46 +1081,16 @@ def mock_refresh(obj): @pytest.mark.asyncio -async def test_server_lifecycle_flow( - mock_db_session, sample_server, server_args, sample_runtime_quote -): +async def test_server_lifecycle_flow(mock_db_session, sample_server, server_args): """Test complete server lifecycle: register -> check ownership -> delete.""" miner_hotkey = "5FTestHotkey123" - def mock_refresh(obj): - obj.server_id = "test-server-123" - if hasattr(obj, "attestation_id"): - obj.attestation_id = "runtime-attest-123" - if hasattr(obj, "verified_at"): - obj.verified_at = datetime.now(timezone.utc) - - mock_db_session.refresh.side_effect = mock_refresh - with patch("api.server.service._track_server", return_value=sample_server): with patch("api.server.service._track_nodes", new_callable=AsyncMock): - with patch("api.server.service.TeeServerClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.get_evidence.return_value = ( - sample_runtime_quote, - {}, - TEST_CERT_HASH, - ) - mock_client_class.return_value = mock_client - with patch("api.server.service.verify_quote") as mock_verify_quote: - mock_verify_quote.return_value = TdxVerificationResult( - mrtd="a" * 96, - rtmr0="d" * 96, - rtmr1="e" * 96, - rtmr2="f" * 96, - rtmr3="0" * 96, - user_data="test", - parsed_at=datetime.now(timezone.utc), - status="UpToDate", - advisory_ids=[], - td_attributes="0000001000000000", - ) - await register_server(mock_db_session, server_args, miner_hotkey) - mock_db_session.add.assert_called() + with patch( + "api.server.service.verify_server", new_callable=AsyncMock, return_value="1.0.0" + ): + await register_server(mock_db_session, server_args, miner_hotkey) mock_db_session.commit.assert_called() # Step 2: Check ownership @@ -1366,7 +1287,7 @@ async def test_verify_quote_boot_vs_runtime_different_settings(mock_settings): td_attributes="0000001000000000", ) - with patch("api.server.service.verify_quote_signature") as mock_sig: + with patch("api.server.util.verify_quote_signature") as mock_sig: mock_sig.side_effect = [boot_dcap_result, runtime_dcap_result] await verify_quote(boot_quote, TEST_NONCE, TEST_CERT_HASH) await verify_quote(runtime_quote, TEST_NONCE, TEST_CERT_HASH) @@ -1559,7 +1480,7 @@ async def 
test_verify_quote_with_different_quote_types(mock_verify_measurements) raw_bytes=b"runtime", ) - with patch("api.server.service.verify_quote_signature") as mock_sig: + with patch("api.server.util.verify_quote_signature") as mock_sig: mock_sig.side_effect = [boot_result, runtime_result] boot_verify_result = await verify_quote(boot_quote, TEST_NONCE, TEST_CERT_HASH) runtime_verify_result = await verify_quote(runtime_quote, TEST_NONCE, TEST_CERT_HASH) diff --git a/watchtower.py b/watchtower.py index d96b8cba..6c49cb93 100644 --- a/watchtower.py +++ b/watchtower.py @@ -22,20 +22,17 @@ decrypt_instance_response, encrypt_instance_request, semcomp, - notify_deleted, - notify_job_deleted, ) from api.database import get_session from api.chute.schemas import Chute from api.image.schemas import Image -from api.job.schemas import Job from api.exceptions import EnvdumpMissing from sqlalchemy import text, update, func, select from sqlalchemy.orm import joinedload, selectinload import api.database.orms # noqa import api.miner_client as miner_client from api.instance.schemas import Instance, LaunchConfig -from api.instance.util import invalidate_instance_cache, cleanup_instance_conn_tracking +from api.instance.util import purge, purge_and_notify # noqa: F401 from api.chute.codecheck import is_bad_code @@ -99,58 +96,6 @@ async def load_chute_instances(chute_id): return instances -async def purge(target, reason="miner failed watchtower probes", valid_termination=False): - """ - Purge an instance. - """ - async with get_session() as session: - await session.execute( - text("DELETE FROM instances WHERE instance_id = :instance_id"), - {"instance_id": target.instance_id}, - ) - await session.execute( - text( - "UPDATE instance_audit SET deletion_reason = :reason, valid_termination = :valid_termination WHERE instance_id = :instance_id" - ), - { - "instance_id": target.instance_id, - "reason": reason, - "valid_termination": valid_termination, - }, - ) - - # Fail associated jobs. - job = ( - (await session.execute(select(Job).where(Job.instance_id == target.instance_id))) - .unique() - .scalar_one_or_none() - ) - if job and not job.finished_at: - job.status = "error" - job.error_detail = f"Instance failed monitoring probes: {reason=}" - job.miner_terminated = True - job.finished_at = func.now() - await notify_job_deleted(job) - - await session.commit() - - await cleanup_instance_conn_tracking(target.chute_id, target.instance_id) - - -async def purge_and_notify( - target, reason="miner failed watchtower probes", valid_termination=False -): - """ - Purge an instance and send a notification with the reason. - """ - await purge(target, reason=reason, valid_termination=valid_termination) - await notify_deleted( - target, - message=f"Instance {target.instance_id} of miner {target.miner_hotkey} deleted by watchtower {reason=}", - ) - await invalidate_instance_cache(target.chute_id, instance_id=target.instance_id) - - async def do_slurp(instance, payload, encrypted_slurp): """ Slurp a remote file. @@ -574,7 +519,9 @@ async def increment_soft_fail(instance, chute): f"miner {instance.miner_hotkey} " f"chute {chute.name} reached max soft fails: {fail_count}" ) - await purge_and_notify(instance) + await purge_and_notify( + instance, reason=f"watchtower - max consecutive soft fails ({fail_count})" + ) def get_expected_command(chute, miner_hotkey: str, seed: int = None): @@ -835,7 +782,8 @@ async def check_chute(chute_id): # Delete failed checks. 
         if failed_envdump:
             await purge_and_notify(
-                instance, reason="Instance failed env dump signature or process checks."
+                instance,
+                reason="watchtower - failed env dump signature or process checks",
             )
             bad_env.add(instance.instance_id)
             failed_count = await settings.redis_client.incr(
@@ -886,7 +834,7 @@ async def check_chute(chute_id):
             f"miner {instance.miner_hotkey} "
             f"chute {chute.name} due to hard fail"
         )
-        await purge_and_notify(instance)
+        await purge_and_notify(instance, reason="watchtower - hard probe failure")
 
     # Limit "soft" fails to max consecutive failures, allowing some downtime but not much.
     for instance in soft_failed:
@@ -1100,7 +1048,7 @@ async def procs_check():
             if reason:
                 logger.warning(reason)
                 await purge_and_notify(
-                    instance, reason="miner failed watchtower probes"
+                    instance, reason="watchtower - miner failed probes"
                 )
             else:
                 logger.success(