diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0af7291f5..87444d60b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,7 @@ repos: rev: 1.7.0 hooks: - id: interrogate - exclude: ^(setup.py|test/|earth2studio/models/nn/) + exclude: ^(setup.py|test/|earth2studio/models/nn/|serve/server/example_workflows/) args: [--config=pyproject.toml] - repo: https://github.com/igorshubovych/markdownlint-cli diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index a27049741..65b78a6c2 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -31,6 +31,7 @@ from .mrms import MRMS from .ncar import NCAR_ERA5 from .planetary_computer import ( + GeoCatalogClient, PlanetaryComputerECMWFOpenDataIFS, PlanetaryComputerGOES, PlanetaryComputerMODISFire, diff --git a/earth2studio/data/planetary_computer.py b/earth2studio/data/planetary_computer.py index 219208af5..2eedf643d 100644 --- a/earth2studio/data/planetary_computer.py +++ b/earth2studio/data/planetary_computer.py @@ -18,19 +18,25 @@ import asyncio import hashlib +import json +import logging import os import pathlib import shutil +import time as _time_module from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timedelta, timezone +from time import perf_counter from typing import Any, TypeVar from urllib.parse import urlparse +from uuid import uuid4 import nest_asyncio import netCDF4 import numpy as np import pygrib +import requests import xarray as xr from loguru import logger from tqdm import tqdm @@ -1223,3 +1229,270 @@ def _validate_time(self, times: list[datetime]) -> None: raise ValueError( f"Requested date time {time} is after {self._satellite} was retired ({end_date})" ) + + +class GeoCatalogClient: + """Client for ingesting STAC features into a Planetary Computer GeoCatalog. 
+ + Workflow-specific behavior (templates, tile settings, render options, step size, + start-time parameter key) is loaded from JSON files in the caller-provided config + directory. This class is independent of concrete workflows. + + Expected files per workflow (workflow_name is passed by the caller): + - parameters-{workflow_name}.json: step_size_hours, start_time_parameter_key + - template-collection-{suffix}.json, template-feature-{suffix}.json + - tile-settings-{suffix}.json, render-options-{suffix}.json + + Parameters + ---------- + workflow_name : str + Workflow name used in filenames (e.g. "fcn3", "fcn3-stormscope-goes"). + config_dir : str | pathlib.Path + Directory containing the GeoCatalog JSON config and template files. + """ + + APPLICATION_URL = "https://geocatalog.spatio.azure.com/" + REQUESTS_TIMEOUT = 30 + CREATION_TIMEOUT = 300 + + def __init__( + self, + workflow_name: str, + config_dir: str | pathlib.Path, + ) -> None: + try: + from azure.identity import DefaultAzureCredential as _DefaultAzureCredential + except ImportError as e: + raise ImportError( + "GeoCatalogClient requires 'azure-identity'. " + "Install with the serve extra or pip install azure-identity." 
+ ) from e + self._DefaultAzureCredential = _DefaultAzureCredential + self._workflow_name = workflow_name + self._config_dir = pathlib.Path(config_dir) + self._parameters: dict[str, Any] = {} + self._load_parameters() + self.headers: dict | None = None + + def _load_parameters(self) -> None: + path = self._config_dir / f"parameters-{self._workflow_name}.json" + with open(path) as f: + self._parameters = json.load(f) + + def update_headers(self) -> None: + """Refresh the Authorization header using a new Azure credential token.""" + credential = self._DefaultAzureCredential() + token = credential.get_token(self.APPLICATION_URL) + self.headers = {"Authorization": f"Bearer {token.token}"} + + def _get(self, url: str) -> Any: + return requests.get( + url, + headers=self.headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _post(self, url: str, body: dict | None = None) -> Any: + return requests.post( + url, + json=body, + headers=self.headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _put(self, url: str, body: dict | None = None) -> Any: + return requests.put( + url, + json=body, + headers=self.headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _create_element(self, url: str, stac_config: dict) -> bool: + response = self._post(url, body=stac_config) + log = logging.getLogger("planetary_computer.geocatalog") + if response.status_code not in {200, 201, 202}: + log.error( + "POST to '%s' failed: %s - %s", url, response.status_code, response.text + ) + return False + location = response.headers["location"] + log.info("Creating '%s'...", stac_config["id"]) + start = perf_counter() + while True: + if (perf_counter() - start) > self.CREATION_TIMEOUT: + log.error("Creation of '%s' timed out", stac_config["id"]) + return False + response = self._get(location) + if response.status_code not in {200, 201, 202}: + log.warning( + 
"Polling '%s' returned %s, retrying...", + location, + response.status_code, + ) + _time_module.sleep(5) + continue + try: + status = response.json()["status"] + except (ValueError, KeyError) as exc: + log.warning("Unexpected polling response: %s", exc) + _time_module.sleep(5) + continue + log.info(status) + if status not in {"Pending", "Running"}: + break + _time_module.sleep(5) + if status == "Succeeded": + log.info("Successfully created '%s'", stac_config["id"]) + return True + else: + log.error("Failed to create '%s': %s", stac_config["id"], response.text) + return False + + def _get_collection_json(self, collection_id: str | None) -> dict: + path = self._config_dir / f"template-collection-{self._workflow_name}.json" + with open(path) as f: + stac_config = json.load(f) + if collection_id is None: + stac_config["id"] = stac_config["id"].format(uuid=uuid4()) + else: + stac_config["id"] = collection_id + return stac_config + + def _get_feature_json( + self, + start_time: datetime, + end_time: datetime, + blob_url: str, + ) -> dict: + path = self._config_dir / f"template-feature-{self._workflow_name}.json" + with open(path) as f: + stac_config = json.load(f) + iso_start = start_time.isoformat() + iso_end = end_time.isoformat() + stac_config["id"] = stac_config["id"].format( + start_time=iso_start[:13], uuid=uuid4() + ) + stac_config["properties"]["datetime"] = iso_start + stac_config["properties"]["start_datetime"] = iso_start + stac_config["properties"]["end_datetime"] = iso_end + stac_config["assets"]["data"]["href"] = blob_url + stac_config["assets"]["data"]["description"] = stac_config["assets"]["data"][ + "description" + ].format(start_time=iso_start, end_time=iso_end) + return stac_config + + def _update_tile_settings(self, geocatalog_url: str, collection_id: str) -> None: + path = self._config_dir / f"tile-settings-{self._workflow_name}.json" + with open(path) as f: + tile_settings = json.load(f) + response = self._put( + 
f"{geocatalog_url}/stac/collections/{collection_id}/configurations/tile-settings", + body=tile_settings, + ) + if response.status_code not in {200, 201}: + log = logging.getLogger("planetary_computer.geocatalog") + log.error( + "Could not update tile settings: Error %s - %s", + response.status_code, + response.text, + ) + + def _update_render_options(self, geocatalog_url: str, collection_id: str) -> None: + path = self._config_dir / f"render-options-{self._workflow_name}.json" + with open(path) as f: + render_params = json.load(f) + log = logging.getLogger("planetary_computer.geocatalog") + for params in render_params: + render_option = { + "id": f"auto-{params['id']}", + "name": params["id"], + "type": "raster-tile", + "options": ( + f"assets=data&subdataset_name={params['id']}" + "&sel=time=2100-01-01&sel=ensemble=0&sel_method=nearest" + f"&rescale={params['scale'][0]},{params['scale'][1]}" + f"&colormap_name={params['cmap']}" + ), + "minZoom": 0, + } + response = self._post( + f"{geocatalog_url}/stac/collections/{collection_id}/configurations/render-options", + body=render_option, + ) + if response.status_code not in {200, 201}: + log.error( + "Could not update render options: Error %s - %s", + response.status_code, + response.text, + ) + + def _create_collection( + self, + geocatalog_url: str, + collection_id: str | None, + ) -> str: + stac_config = self._get_collection_json(collection_id) + success = self._create_element( + url=f"{geocatalog_url}/stac/collections", + stac_config=stac_config, + ) + if not success: + raise RuntimeError(f"Failed to create collection '{stac_config['id']}'") + self._update_tile_settings(geocatalog_url, stac_config["id"]) + self._update_render_options(geocatalog_url, stac_config["id"]) + return stac_config["id"] + + def _ensure_collection_exists(self, geocatalog_url: str, collection_id: str) -> str: + response = self._get(f"{geocatalog_url}/stac/collections/{collection_id}") + if response.status_code == 200: + return collection_id + 
if response.status_code != 404: + raise RuntimeError( + f"Failed to retrieve collection: Error {response.status_code} - {response.text}" + ) + return self._create_collection(geocatalog_url, collection_id) + + def _resolve_start_time(self, parameters: dict) -> datetime: + key = self._parameters["start_time_parameter_key"] + raw = parameters.get(key) + if raw is None: + raise ValueError( + f"Missing {key!r} in parameters for workflow {self._workflow_name!r}" + ) + if isinstance(raw, str): + normalized = raw.replace("Z", "+00:00") + return datetime.fromisoformat(normalized) + if isinstance(raw, datetime): + return raw + if hasattr(raw, "isoformat"): + return datetime.fromisoformat(raw.isoformat()) + raise TypeError(f"start time must be str or datetime, got {type(raw)}") + + def create_feature( + self, + geocatalog_url: str, + collection_id: str | None, + parameters: dict, + blob_url: str, + ) -> tuple[str, str]: + """Ingest a new STAC feature into the collection.""" + self.update_headers() + if collection_id is None: + collection_id = self._create_collection(geocatalog_url, None) + else: + self._ensure_collection_exists(geocatalog_url, collection_id) + start_time = self._resolve_start_time(parameters) + step_hours = self._parameters["step_size_hours"] + end_time = start_time + timedelta(hours=step_hours) + stac_config = self._get_feature_json(start_time, end_time, blob_url) + success = self._create_element( + url=f"{geocatalog_url}/stac/collections/{collection_id}/items", + stac_config=stac_config, + ) + if not success: + raise RuntimeError(f"Failed to create feature '{stac_config['id']}'") + return collection_id, stac_config["id"] diff --git a/earth2studio/serve/client/e2client.py b/earth2studio/serve/client/e2client.py index 59a17aebf..d50019735 100644 --- a/earth2studio/serve/client/e2client.py +++ b/earth2studio/serve/client/e2client.py @@ -192,6 +192,10 @@ def as_dataset(self) -> xr.Dataset: zarr_path = "/".join(result_path.split("/")[1:]) mapper = 
fsspec_utils.get_mapper(request_result, zarr_path) ds = xr.open_zarr(mapper, consolidated=True, **self.workflow.xr_args) + + elif request_result.storage_type == StorageType.AZURE: + mapper = fsspec_utils.get_mapper(request_result, "results.zarr") + ds = xr.open_zarr(mapper, consolidated=True, **self.workflow.xr_args) elif request_result.storage_type == StorageType.SERVER: result_url = urljoin( self.workflow.base_url + "/", diff --git a/earth2studio/serve/client/fsspec_utils.py b/earth2studio/serve/client/fsspec_utils.py index c121e7108..9649fe8e6 100644 --- a/earth2studio/serve/client/fsspec_utils.py +++ b/earth2studio/serve/client/fsspec_utils.py @@ -269,6 +269,192 @@ def create_cloudfront_mapper(signed_url: str, zarr_path: str = "") -> Any: return mapper +# Create a custom filesystem that replaces * with the filename +class AzureSignedURLFileSystem(fsspec.AbstractFileSystem): + """Wrapper that replaces wildcard * with filename and appends SAS token.""" + + def __init__( + self, base_fs: Any, query_params: dict[str, str], base_url: str + ) -> None: + super().__init__() + self._fs = base_fs + self._query_params = query_params + self._base_url = base_url + self._query_string = urlencode(query_params, safe="~") + + def _make_signed_path(self, path: str) -> str: + """Replace wildcard * with the actual path and append SAS token.""" + if path.startswith("http"): + full_url = path + else: + clean_path = path.lstrip("/") + # Replace * in base_url with the actual path + if "*" in self._base_url: + # Replace the * wildcard with the cleaned path + full_url = self._base_url.replace("*", clean_path) + else: + full_url = ( + f"{self._base_url}/{clean_path}" if clean_path else self._base_url + ) + separator = "&" if "?" in full_url else "?" 
+ return f"{full_url}{separator}{self._query_string}" + + def _handle_403(self, e: BaseException, path: str) -> "NoReturn": + """Convert 403 errors to FileNotFoundError.""" + error_str = str(e).lower() + if "403" in str(e) or "forbidden" in error_str: + raise FileNotFoundError(f"File not found: {path}") from None + raise e + + def _open(self, path: str, mode: str = "rb", **kwargs: Any) -> Any: + try: + return self._fs._open(self._make_signed_path(path), mode=mode, **kwargs) + except Exception as e: + self._handle_403(e, path) + + def cat_file( + self, path: str, start: int | None = None, end: int | None = None, **kwargs: Any + ) -> Any: + """ + Read file contents with signed URL; 403 is converted to FileNotFoundError. + + Parameters + ---------- + path : str + Path or URL to read. + start : int, optional + Start byte offset. + end : int, optional + End byte offset. + **kwargs : Any + Passed to the underlying filesystem. + + Returns + ------- + Any + File contents (typically bytes). + + Raises + ------ + FileNotFoundError + If the server returns 403. + """ + try: + return self._fs.cat_file( + self._make_signed_path(path), start=start, end=end, **kwargs + ) + except Exception as e: + self._handle_403(e, path) + + def _cat_file( + self, path: str, start: int | None = None, end: int | None = None, **kwargs: Any + ) -> Any: + """Async version used by zarr.""" + return self.cat_file(path, start=start, end=end, **kwargs) + + def info(self, path: str, **kwargs: Any) -> Any: + """ + Return metadata for path with signed URL; 403 becomes FileNotFoundError. + + Parameters + ---------- + path : str + Path or URL. + **kwargs : Any + Passed to the underlying filesystem. + + Returns + ------- + Any + Metadata dict for the path. + + Raises + ------ + FileNotFoundError + If the server returns 403. 
+ """ + try: + return self._fs.info(self._make_signed_path(path), **kwargs) + except Exception as e: + self._handle_403(e, path) + + def exists(self, path: str, **kwargs: Any) -> bool: + """ + Return True if path exists; 403 is treated as not found. + + Parameters + ---------- + path : str + Path or URL to check. + **kwargs : Any + Passed to the underlying filesystem. + + Returns + ------- + bool + True if path exists, False if not found or 403. + """ + try: + return self._fs.exists(self._make_signed_path(path), **kwargs) + except Exception as e: + try: + self._handle_403(e, path) + except FileNotFoundError: + return False + + +def create_azure_mapper(signed_url: str, zarr_path: str = "") -> Any: + """ + Create an fsspec mapper for an Azure Blob Storage signed URL with wildcard. + + The Azure signed URL contains a wildcard (*) that needs to be replaced with + the actual filename when accessing files. + + Parameters + ---------- + signed_url : str + Azure signed URL with wildcard (*) and SAS token query params. + zarr_path : str, optional + Path to the zarr store within the signed URL prefix. 
+ + Returns + ------- + mapper : fsspec.mapping.FSMap + A mapper suitable for use with xarray.open_zarr() + """ + # Parse the URL + parsed = urlparse(signed_url) + + # Extract query parameters (SAS token) + query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()} + + # Get base path - keep the * wildcard if present + base_path = parsed.path + + # If zarr_path is provided, append it before the wildcard + if zarr_path != "": + # Replace * with zarr_path if * is at the end, otherwise append + if base_path.endswith("/*"): + base_path = base_path[:-2] + f"/{zarr_path}/*" + elif base_path.endswith("*"): + base_path = base_path[:-1] + f"{zarr_path}/*" + else: + base_path = f"{base_path}/{zarr_path}/*" + elif not base_path.endswith("*"): + # Ensure there's a wildcard at the end if no zarr_path + base_path = base_path.rstrip("/") + "/*" + + # Reconstruct base URL with wildcard + base_url = f"{parsed.scheme}://{parsed.netloc}{base_path}" + + # Create HTTP filesystem + fs = fsspec.filesystem("https") + signed_fs = AzureSignedURLFileSystem(fs, query_params, base_url) + mapper = fsspec.mapping.FSMap(root="", fs=signed_fs, check=False, create=False) + + return mapper + + def get_mapper( request_result: InferenceRequestResults, zarr_path: str = "" ) -> Any | None: @@ -296,6 +482,10 @@ def get_mapper( if request_result.signed_url is None: raise ValueError("S3 storage type requires a signed URL") return create_cloudfront_mapper(request_result.signed_url, zarr_path) + elif request_result.storage_type == StorageType.AZURE: + if request_result.signed_url is None: + raise ValueError("Azure storage type requires a signed URL") + return create_azure_mapper(request_result.signed_url, zarr_path) elif request_result.storage_type == StorageType.SERVER: return None else: diff --git a/earth2studio/serve/client/models.py b/earth2studio/serve/client/models.py index 60e3f9234..bf7bbb618 100644 --- a/earth2studio/serve/client/models.py +++ b/earth2studio/serve/client/models.py @@ -29,6 
+29,7 @@ class StorageType(str, Enum): SERVER = "server" S3 = "s3" + AZURE = "azure" class RequestStatus(Enum): diff --git a/earth2studio/serve/server/config.py b/earth2studio/serve/server/config.py index 707fd044e..3cb49737d 100644 --- a/earth2studio/serve/server/config.py +++ b/earth2studio/serve/server/config.py @@ -49,6 +49,7 @@ class QueueConfig: name: str = "inference" result_zip_queue_name: str = "result_zip" object_storage_queue_name: str = "object_storage" + geocatalog_ingestion_queue_name: str = "geocatalog_ingestion" finalize_metadata_queue_name: str = "finalize_metadata" max_size: int = 10 default_timeout: str = "1h" @@ -103,9 +104,10 @@ class CORSConfig: @dataclass class ObjectStorageConfig: - """Object storage configuration for S3/CloudFront""" + """Object storage configuration for S3/CloudFront and Azure Blob Storage""" enabled: bool = False + storage_type: Literal["s3", "azure"] = "s3" # Storage provider type # S3 configuration bucket: str | None = None region: str = "us-east-1" @@ -125,6 +127,31 @@ class ObjectStorageConfig: cloudfront_private_key: str | None = None # PEM private key content # Signed URL settings signed_url_expires_in: int = 86400 # Default 24 hours + # Azure Blob Storage configuration + azure_connection_string: str | None = None # Azure connection string + azure_account_name: str | None = None # Azure storage account name + azure_account_key: str | None = ( + None # Azure storage account key (for SAS token generation) + ) + azure_container_name: str | None = ( + None # Azure container name (falls back to bucket if not set) + ) + # Azure Planetary Computer / GeoCatalog ingestion (optional) + azure_geocatalog_url: str | None = ( + None # When set, triggers PC ingestion after upload + ) + + +@dataclass +class WorkflowExposureConfig: + """Configuration for controlling which workflows are exposed via API endpoints""" + + exposed_workflows: list[str] = field( + default_factory=lambda: [] + ) # Empty list means all workflows are exposed 
+ warmup_workflows: list[str] = field( + default_factory=lambda: ["example_user_workflow"] + ) # Workflows accessible for warmup even if not exposed @dataclass @@ -138,6 +165,9 @@ class AppConfig: server: ServerConfig = field(default_factory=ServerConfig) cors: CORSConfig = field(default_factory=CORSConfig) object_storage: ObjectStorageConfig = field(default_factory=ObjectStorageConfig) + workflow_exposure: WorkflowExposureConfig = field( + default_factory=WorkflowExposureConfig + ) class ConfigManager: @@ -220,6 +250,9 @@ def _dict_to_config(self, cfg_dict: dict) -> AppConfig: server=ServerConfig(**cfg_dict.get("server", {})), cors=CORSConfig(**cfg_dict.get("cors", {})), object_storage=ObjectStorageConfig(**cfg_dict.get("object_storage", {})), + workflow_exposure=WorkflowExposureConfig( + **cfg_dict.get("workflow_exposure", {}) + ), ) def _create_default_config_object(self) -> AppConfig: @@ -232,6 +265,7 @@ def _create_default_config_object(self) -> AppConfig: server=ServerConfig(), cors=CORSConfig(), object_storage=ObjectStorageConfig(), + workflow_exposure=WorkflowExposureConfig(), ) def _apply_env_overrides(self) -> None: @@ -276,6 +310,12 @@ def _apply_env_overrides(self) -> None: self._config.paths.results_zip_dir = os.getenv( "RESULTS_ZIP_DIR", default=self._config.paths.results_zip_dir ) + if os.getenv("OUTPUT_FORMAT"): + output_format = os.getenv("OUTPUT_FORMAT", "").lower() + if output_format in ["zarr", "netcdf4"]: + self._config.paths.output_format = cast( + Literal["zarr", "netcdf4"], output_format + ) # Logging overrides if os.getenv("LOG_LEVEL"): @@ -307,6 +347,13 @@ def _apply_env_overrides(self) -> None: self._config.object_storage.enabled = ( os.getenv("OBJECT_STORAGE_ENABLED", "").lower() == "true" ) + if os.getenv("OBJECT_STORAGE_TYPE"): + storage_type = os.getenv("OBJECT_STORAGE_TYPE", "").lower() + if storage_type in ["s3", "azure"]: + self._config.object_storage.storage_type = cast( + Literal["s3", "azure"], storage_type + ) + if 
os.getenv("OBJECT_STORAGE_BUCKET"): self._config.object_storage.bucket = os.getenv("OBJECT_STORAGE_BUCKET") if os.getenv("OBJECT_STORAGE_REGION"): @@ -375,6 +422,39 @@ def _apply_env_overrides(self) -> None: ) ) + # Azure Blob Storage overrides + if os.getenv("AZURE_CONNECTION_STRING"): + self._config.object_storage.azure_connection_string = os.getenv( + "AZURE_CONNECTION_STRING" + ) + if os.getenv("AZURE_STORAGE_ACCOUNT_NAME"): + self._config.object_storage.azure_account_name = os.getenv( + "AZURE_STORAGE_ACCOUNT_NAME" + ) + if os.getenv("AZURE_STORAGE_ACCOUNT_KEY"): + self._config.object_storage.azure_account_key = os.getenv( + "AZURE_STORAGE_ACCOUNT_KEY" + ) + if os.getenv("AZURE_CONTAINER_NAME"): + self._config.object_storage.azure_container_name = os.getenv( + "AZURE_CONTAINER_NAME" + ) + # Support AZURE_ENDPOINT_URL for managed identity scenarios + if os.getenv("AZURE_ENDPOINT_URL"): + self._config.object_storage.endpoint_url = os.getenv("AZURE_ENDPOINT_URL") + if os.getenv("AZURE_GEOCATALOG_URL"): + self._config.object_storage.azure_geocatalog_url = os.getenv( + "AZURE_GEOCATALOG_URL" + ) + + # Workflow exposure overrides + if os.getenv("EXPOSED_WORKFLOWS"): + # Parse comma-separated list of workflow names + exposed_workflows_str = os.getenv("EXPOSED_WORKFLOWS", "") + self._config.workflow_exposure.exposed_workflows = [ + w.strip() for w in exposed_workflows_str.split(",") if w.strip() + ] + logger.debug("Environment variable overrides applied") def _ensure_paths_exist(self) -> None: diff --git a/earth2studio/serve/server/cpu_worker.py b/earth2studio/serve/server/cpu_worker.py index 8a5ea0f27..0b2297081 100644 --- a/earth2studio/serve/server/cpu_worker.py +++ b/earth2studio/serve/server/cpu_worker.py @@ -16,6 +16,7 @@ import json import logging +import os import zipfile from dataclasses import dataclass, field from datetime import datetime, timezone @@ -539,7 +540,11 @@ def process_object_storage_upload( # Upload to object storage if enabled - if 
config.object_storage.enabled and config.object_storage.bucket: + # Check if object storage is enabled and properly configured + if config.object_storage.enabled and ( + config.object_storage.bucket + or config.object_storage.storage_type == "azure" + ): from earth2studio.serve.server.object_storage import ( MSCObjectStorage, ObjectStorageError, @@ -553,44 +558,90 @@ def process_object_storage_upload( f"Output path does not exist: {output_path}", ) - # Create S3 storage instance + # Validate Azure container name is configured + if config.object_storage.storage_type == "azure": + if ( + not config.object_storage.azure_container_name + and not config.object_storage.bucket + ): + return fail_workflow( + workflow_name, + execution_id, + "Azure storage is enabled but neither 'azure_container_name' nor 'bucket' is configured", + ) + + # Create storage instance storage_kwargs: dict[str, Any] = { - "bucket": config.object_storage.bucket, - "region": config.object_storage.region, - "use_transfer_acceleration": config.object_storage.use_transfer_acceleration, + "bucket": config.object_storage.bucket + or config.object_storage.azure_container_name + or "", + "storage_type": config.object_storage.storage_type, "max_concurrency": config.object_storage.max_concurrency, "multipart_chunksize": config.object_storage.multipart_chunksize, "use_rust_client": config.object_storage.use_rust_client, } - # Add optional credentials - if ( - config.object_storage.access_key_id - and config.object_storage.secret_access_key - ): - storage_kwargs["access_key_id"] = config.object_storage.access_key_id - storage_kwargs["secret_access_key"] = ( - config.object_storage.secret_access_key - ) - if config.object_storage.session_token: - storage_kwargs["session_token"] = config.object_storage.session_token - if config.object_storage.endpoint_url: - storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url - - # Add CloudFront configuration for signed URLs - if 
config.object_storage.cloudfront_domain: - storage_kwargs["cloudfront_domain"] = ( - config.object_storage.cloudfront_domain - ) - if config.object_storage.cloudfront_key_pair_id: - storage_kwargs["cloudfront_key_pair_id"] = ( - config.object_storage.cloudfront_key_pair_id - ) - if config.object_storage.cloudfront_private_key: - storage_kwargs["cloudfront_private_key"] = ( - config.object_storage.cloudfront_private_key + # Add storage-type-specific configuration + if config.object_storage.storage_type == "s3": + # S3-specific parameters + storage_kwargs["region"] = config.object_storage.region + storage_kwargs["use_transfer_acceleration"] = ( + config.object_storage.use_transfer_acceleration ) + # Add optional S3 credentials + if ( + config.object_storage.access_key_id + and config.object_storage.secret_access_key + ): + storage_kwargs["access_key_id"] = ( + config.object_storage.access_key_id + ) + storage_kwargs["secret_access_key"] = ( + config.object_storage.secret_access_key + ) + if config.object_storage.session_token: + storage_kwargs["session_token"] = ( + config.object_storage.session_token + ) + if config.object_storage.endpoint_url: + storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url + + # Add CloudFront configuration for signed URLs + if config.object_storage.cloudfront_domain: + storage_kwargs["cloudfront_domain"] = ( + config.object_storage.cloudfront_domain + ) + if config.object_storage.cloudfront_key_pair_id: + storage_kwargs["cloudfront_key_pair_id"] = ( + config.object_storage.cloudfront_key_pair_id + ) + if config.object_storage.cloudfront_private_key: + storage_kwargs["cloudfront_private_key"] = ( + config.object_storage.cloudfront_private_key + ) + elif config.object_storage.storage_type == "azure": + # Azure-specific parameters + if config.object_storage.azure_connection_string: + storage_kwargs["azure_connection_string"] = ( + config.object_storage.azure_connection_string + ) + if config.object_storage.azure_account_name: + 
storage_kwargs["azure_account_name"] = ( + config.object_storage.azure_account_name + ) + if config.object_storage.azure_account_key: + storage_kwargs["azure_account_key"] = ( + config.object_storage.azure_account_key + ) + if config.object_storage.azure_container_name: + storage_kwargs["azure_container_name"] = ( + config.object_storage.azure_container_name + ) + # Support endpoint_url for Azure (useful for managed identity) + if config.object_storage.endpoint_url: + storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url + try: storage = MSCObjectStorage(**storage_kwargs) except Exception as e: @@ -606,8 +657,13 @@ def process_object_storage_upload( ) # Upload the output directory + storage_location = ( + f"s3://{config.object_storage.bucket}" + if config.object_storage.storage_type == "s3" + else f"azure://{config.object_storage.azure_container_name or config.object_storage.bucket}" + ) logger.info( - f"Uploading {output_path} to s3://{config.object_storage.bucket}/{remote_prefix}" + f"Uploading {output_path} to {storage_location}/{remote_prefix}" ) try: @@ -631,23 +687,36 @@ def process_object_storage_upload( f"Failed to upload to object storage: {upload_result.errors}", ) - storage_type = "s3" + storage_type = config.object_storage.storage_type logger.info( f"Successfully uploaded {upload_result.files_uploaded} files " f"({upload_result.total_bytes} bytes) to {upload_result.destination}" ) - # Generate signed URL if CloudFront is configured - cloudfront_configured = all( - [ - config.object_storage.cloudfront_domain, - config.object_storage.cloudfront_key_pair_id, - config.object_storage.cloudfront_private_key, - ] - ) + # Generate signed URL if configured + # For S3: requires CloudFront configuration + # For Azure: requires account_name and account_key + can_generate_signed_url = False + if storage_type == "s3": + can_generate_signed_url = all( + [ + config.object_storage.cloudfront_domain, + config.object_storage.cloudfront_key_pair_id, + 
config.object_storage.cloudfront_private_key, + ] + ) + elif storage_type == "azure": + can_generate_signed_url = all( + [ + config.object_storage.azure_account_name, + config.object_storage.azure_account_key, + ] + ) - if not cloudfront_configured: - logger.info("CloudFront not configured, skipping signed URL generation") + if not can_generate_signed_url: + logger.info( + f"Signed URL generation not configured for {storage_type}, skipping" + ) else: try: signed_url_path = f"{remote_prefix}/*" @@ -677,10 +746,36 @@ def process_object_storage_upload( storage_info = { "storage_type": storage_type, } - if storage_type == "s3" and remote_prefix: - storage_info["remote_path"] = ( - f"s3://{config.object_storage.bucket}/{remote_prefix}" - ) + if remote_prefix: + if storage_type == "s3": + storage_info["remote_path"] = ( + f"s3://{config.object_storage.bucket}/{remote_prefix}" + ) + elif storage_type == "azure": + container_name = ( + config.object_storage.azure_container_name + or config.object_storage.bucket + ) + storage_info["remote_path"] = ( + f"azure://{container_name}/{remote_prefix}" + ) + # Build HTTPS blob URL for primary netcdf file (for GeoCatalog ingestion) + if ( + config.object_storage.azure_account_name + and config.object_storage.azure_geocatalog_url + ): + primary_nc = None + if output_path.is_file() and output_path.suffix.lower() == ".nc": + primary_nc = output_path.name + elif output_path.is_dir(): + nc_files = sorted(output_path.rglob("*.nc")) + if nc_files: + primary_nc = nc_files[0].relative_to(output_path).as_posix() + if primary_nc: + storage_info["blob_url"] = ( + f"https://{config.object_storage.azure_account_name}.blob.core.windows.net/" + f"{container_name}/{remote_prefix}/{primary_nc}" + ) if signed_url: storage_info["signed_url"] = signed_url @@ -718,7 +813,7 @@ def process_object_storage_upload( "signed_url": signed_url, } - if upload_result and storage_type == "s3": + if upload_result: result["files_uploaded"] = 
upload_result.files_uploaded result["total_bytes"] = upload_result.total_bytes result["destination"] = upload_result.destination @@ -736,6 +831,189 @@ def process_object_storage_upload( ) +def process_geocatalog_ingestion( + workflow_name: str, + execution_id: str, +) -> dict[str, Any] | None: + """ + RQ Worker function to trigger ingestion of uploaded inference results into + Azure Planetary Computer / GeoCatalog when AZURE_GEOCATALOG_URL is configured. + + This function is intended to be executed by the CPU worker from the + geocatalog_ingestion_queue, after process_object_storage_upload. It reads + storage info and parameters from Redis and calls the Planetary Computer + client to create a STAC feature for the uploaded netcdf blob. + + Args: + workflow_name: Name of the workflow + execution_id: Execution ID of the workflow + + Returns: + Dict containing result info, None on critical failure + """ + request_id = f"{workflow_name}:{execution_id}" + logger.info(f"Processing geocatalog ingestion for {request_id}") + + try: + geocatalog_url = config.object_storage.azure_geocatalog_url + if not geocatalog_url: + logger.warning( + f"AZURE_GEOCATALOG_URL not set, skipping geocatalog ingestion for {request_id}" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + "success": True, + "skipped": True, + "reason": "AZURE_GEOCATALOG_URL not set", + } + + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = get_inference_request_metadata_key(request_id) + storage_info_json = redis_client.get(storage_info_key) + pending_metadata_json = redis_client.get(metadata_key) + + if not storage_info_json or not pending_metadata_json: + logger.warning( + f"Storage info or pending 
metadata missing for {request_id}, skipping geocatalog ingestion" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + "success": True, + "skipped": True, + "reason": "missing storage/metadata", + } + + storage_info = json.loads(storage_info_json) + metadata_dict = json.loads(pending_metadata_json) + blob_url = storage_info.get("blob_url") + parameters = metadata_dict.get("parameters") or {} + + if not blob_url: + logger.warning( + f"No blob_url in storage info for {request_id} (e.g. not Azure or no .nc file), skipping geocatalog ingestion" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return {"success": True, "skipped": True, "reason": "no blob_url"} + + logger.info(f"Blob URL: {blob_url}") + # Only trigger PC ingestion for workflows supported by GeoCatalogClient + _PC_WORKFLOW_SUFFIX: dict[str, str] = { + "foundry_fcn3_workflow": "fcn3", + "foundry_fcn3_stormscope_goes_workflow": "fcn3-stormscope-goes", + } + if workflow_name not in _PC_WORKFLOW_SUFFIX: + logger.info( + f"Workflow {workflow_name} not supported by Planetary Computer client, skipping ingestion for {request_id}" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + 
"success": True, + "skipped": True, + "reason": "workflow not supported", + } + + pc_config_dir = os.environ.get( + "EARTH2STUDIO_PC_TEMPLATE_DIR", + "/workspace/earth2studio-project/serve/server/planetary_computer", + ) + pc_workflow_name = _PC_WORKFLOW_SUFFIX[workflow_name] + + try: + from earth2studio.data.planetary_computer import GeoCatalogClient + + pc_client = GeoCatalogClient( + workflow_name=pc_workflow_name, + config_dir=pc_config_dir, + ) + collection_id = parameters.get("collection_id") + pc_client.create_feature( + geocatalog_url=geocatalog_url, + collection_id=collection_id, + parameters=parameters, + blob_url=blob_url, + ) + logger.info(f"GeoCatalog ingestion completed for {request_id}") + except Exception as e: + # Log but do not fail the pipeline; finalize_metadata should still run + logger.exception( + f"GeoCatalog ingestion failed for {request_id}: {e}. Queuing finalize_metadata anyway." + ) + + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue next pipeline stage for {request_id}", + ) + logger.info( + f"Queued finalize_metadata for {workflow_name}:{execution_id} with RQ job ID: {job_id}" + ) + return {"success": True} + + except Exception as e: + logger.exception(f"Failed in geocatalog ingestion for {request_id}") + return fail_workflow( + workflow_name, + execution_id, + f"Geocatalog ingestion failed for {request_id}: {str(e)}", + ) + + def process_finalize_metadata( workflow_name: str, execution_id: str, @@ -768,6 +1046,7 @@ def process_finalize_metadata( storage_info_json = redis_client.get(storage_info_key) if not pending_metadata_json or not results_zip_dir_str: + logger.error(f"Pending metadata not found in Redis for {request_id}") return fail_workflow( workflow_name, execution_id, diff --git 
a/earth2studio/serve/server/e2workflow.py b/earth2studio/serve/server/e2workflow.py index 26e00cb9b..e5de472f2 100644 --- a/earth2studio/serve/server/e2workflow.py +++ b/earth2studio/serve/server/e2workflow.py @@ -102,6 +102,7 @@ class Earth2Workflow(Workflow, metaclass=AutoParameters): def __init__(self) -> None: super().__init__() + self.execution_id: str | None = None @abstractmethod def __call__(self, io: IOBackend) -> None: @@ -126,6 +127,9 @@ def run( ) -> dict[str, Any]: """Run custom workflow""" + # Store execution_id for use in update_progress + self.execution_id = execution_id + # Validate and convert parameters parameters = self.validate_parameters(parameters) @@ -188,6 +192,36 @@ def run( self.update_execution_data(execution_id, progress) raise + def update_progress(self, progress: WorkflowProgress) -> None: + """ + Update workflow execution progress. + + This method is intended for child workflows to update progress + information during execution. It uses the execution_id stored + during the run() method. + + If execution_id is not set (e.g., when running outside the API server), + this method is a no-op and silently returns without updating progress. + + Parameters + ---------- + progress : WorkflowProgress + WorkflowProgress object containing progress information to update. + + Examples + -------- + >>> progress = WorkflowProgress( + ... progress="Processing data...", + ... current_step=5, + ... total_steps=10 + ... 
) + >>> self.update_progress(progress) + """ + if self.execution_id is None: + # No-op when running outside API server context + return + self.update_execution_data(self.execution_id, progress) + logger = logging.getLogger(__name__) @@ -198,7 +232,7 @@ class BackendProgress: def __init__( self, io: IOBackend, - workflow: Workflow, + workflow: Earth2Workflow, execution_id: str, progress_dim: str = "lead_time", ) -> None: @@ -229,7 +263,7 @@ def add_array( progress = WorkflowProgress( current_step=0, total_steps=len(self.progress_coords) ) - self.workflow.update_execution_data(self.execution_id, progress) + self.workflow.update_progress(progress) def write( self, @@ -246,8 +280,12 @@ def write( step_index = self.progress_coords.index(current_coord) # Update progress using WorkflowProgress progress = WorkflowProgress(current_step=step_index + 1) - self.workflow.update_execution_data(self.execution_id, progress) + self.workflow.update_progress(progress) def __getattr__(self, name: str) -> Any: """Allow passthrough of unwrapped attributes.""" return getattr(self.io, name) + + def __getitem__(self, key: str) -> Any: + """Allow subscripting to access underlying io object.""" + return self.io[key] diff --git a/earth2studio/serve/server/main.py b/earth2studio/serve/server/main.py index e436601bd..0e21ecffa 100644 --- a/earth2studio/serve/server/main.py +++ b/earth2studio/serve/server/main.py @@ -37,7 +37,7 @@ import redis as redis_sync # type: ignore[import-untyped] # For RQ (synchronous) import redis.asyncio as redis # type: ignore[import-untyped] import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse from prometheus_client import CONTENT_TYPE_LATEST, generate_latest @@ -47,8 +47,10 @@ # Import configuration from earth2studio.serve.server.config import get_config, get_config_manager from 
earth2studio.serve.server.utils import ( + create_file_stream, get_inference_request_output_path_key, get_inference_request_zip_key, + parse_range_header, ) # Import workflow registry @@ -95,6 +97,7 @@ def check_admission_control() -> None: config.queue.name, config.queue.result_zip_queue_name, config.queue.object_storage_queue_name, + config.queue.geocatalog_ingestion_queue_name, config.queue.finalize_metadata_queue_name, ] for queue_name in queue_names: @@ -385,7 +388,7 @@ async def list_workflows() -> dict[str, dict[str, str]]: dict Single key ``workflows`` mapping workflow name to description. """ - workflows = workflow_registry.list_workflows() + workflows = workflow_registry.list_workflows(exposed_only=True) return {"workflows": workflows} @@ -414,12 +417,16 @@ async def get_workflow_schema(workflow_name: str) -> dict[str, Any]: HTTPException 404 if workflow not found; 500 if schema generation fails. """ - # Check if workflow exists + # Check if workflow exists and is exposed workflow_class = workflow_registry.get_workflow_class(workflow_name) if not workflow_class: raise HTTPException( status_code=404, detail=f"Workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) try: # Get the Parameters class from the workflow @@ -522,12 +529,16 @@ async def execute_workflow( 404 if workflow not found; 422 if parameters invalid; 429 if queues full; 503 if Redis/queue not initialized; 500 on enqueue failure. 
""" - # Check if workflow exists and get the workflow class for validation + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Validate parameters early to provide immediate feedback using classmethod try: @@ -650,12 +661,16 @@ async def get_workflow_status(workflow_name: str, execution_id: str) -> Workflow # Create logger adapter with execution_id log = logging.LoggerAdapter(logger, {"execution_id": execution_id}) - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) try: result = custom_workflow_class._get_execution_data( @@ -718,12 +733,16 @@ async def get_workflow_results( # Create logger adapter with execution_id log = logging.LoggerAdapter(logger, {"execution_id": execution_id}) - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Check workflow status first try: @@ -815,7 +834,10 @@ async def get_workflow_results( 
@app.get("/v1/infer/{workflow_name}/{execution_id}/results/{filepath:path}") async def get_workflow_result_file( - workflow_name: str, execution_id: str, filepath: str + workflow_name: str, + execution_id: str, + filepath: str, + request: Request, ) -> StreamingResponse: """ Stream a specific file from the workflow execution results. @@ -844,12 +866,16 @@ async def get_workflow_result_file( 403 on path traversal attempt; 404 if workflow, execution, file, or zip not found or results not completed; 503 if Redis not initialized; 500 on error. """ - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Check workflow status first try: @@ -902,29 +928,32 @@ async def get_workflow_result_file( }, ) - # Stream the zip file - async def stream_zip_file() -> AsyncGenerator[bytes, None]: - """Stream the zip file contents""" - try: - chunk_size = 8192 - async with aiofiles.open(zip_file_path, "rb") as f: - while True: - chunk = await f.read(chunk_size) - if not chunk: - break - yield chunk - await asyncio.sleep(0) - except Exception: - logger.exception(f"Error streaming zip file {zip_file_path}") - raise + # Get file size and parse range header + zip_file_size = zip_file_path.stat().st_size + range_header = request.headers.get("Range") + start, end, content_length, status_code = parse_range_header( + range_header, zip_file_size + ) + + # Create streaming response + stream_generator = create_file_stream( + zip_file_path, start, content_length, "zip file" + ) headers = { "Content-Disposition": f'attachment; filename="{zip_filename}"', - "Content-Length": str(zip_file_path.stat().st_size), + 
"Content-Length": str(content_length), + "Accept-Ranges": "bytes", } + if range_header: + headers["Content-Range"] = f"bytes {start}-{end}/{zip_file_size}" + return StreamingResponse( - stream_zip_file(), media_type="application/zip", headers=headers + stream_generator, + media_type="application/zip", + headers=headers, + status_code=status_code, ) # Regular case: get file from output directory @@ -1005,29 +1034,28 @@ async def stream_zip_file() -> AsyncGenerator[bytes, None]: if media_type is None: media_type = "application/octet-stream" - # Stream the file - async def stream_file() -> AsyncGenerator[bytes, None]: - """Stream the file contents""" - try: - chunk_size = 8192 - async with aiofiles.open(requested_path, "rb") as f: - while True: - chunk = await f.read(chunk_size) - if not chunk: - break - yield chunk - await asyncio.sleep(0) - except Exception: - logger.exception(f"Error streaming file {requested_path}") - raise - - # Set appropriate headers + # Get file size and parse range header + file_size = requested_path.stat().st_size + range_header = request.headers.get("Range") + start, end, content_length, status_code = parse_range_header( + range_header, file_size + ) + stream_generator = create_file_stream( + requested_path, start, content_length, "file" + ) headers = { "Content-Disposition": f'attachment; filename="{requested_path.name}"', - "Content-Length": str(requested_path.stat().st_size), + "Content-Length": str(content_length), + "Accept-Ranges": "bytes", } - - return StreamingResponse(stream_file(), media_type=media_type, headers=headers) + if range_header: + headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + return StreamingResponse( + stream_generator, + media_type=media_type, + headers=headers, + status_code=status_code, + ) except HTTPException: raise diff --git a/earth2studio/serve/server/object_storage.py b/earth2studio/serve/server/object_storage.py index 8b318e372..260863cea 100644 --- 
a/earth2studio/serve/server/object_storage.py +++ b/earth2studio/serve/server/object_storage.py @@ -22,7 +22,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Literal logger = logging.getLogger(__name__) @@ -198,25 +198,31 @@ class ObjectStorageError(Exception): class MSCObjectStorage(ObjectStorage): """ - Object storage using NVIDIA Multi-Storage Client (MSC) with Rust backend. + Object storage using NVIDIA Multi-Storage Client (MSC) with Rust backend for AWS S3 and Azure Blob Storage. MSC provides optimized parallel transfers; the Rust client bypasses Python's GIL for improved I/O performance (up to 12x faster). Uses sync_from for efficient directory uploads with parallel transfers. - Credentials are read from environment variables: AWS_ACCESS_KEY_ID, + Supports both AWS S3 and Azure Blob Storage. + + For S3, credentials are read from environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN (optional), AWS_DEFAULT_REGION. + For Azure, credentials can be provided via: + - Connection string (AzureCredentials): Provide azure_connection_string + - Managed Identity (DefaultAzureCredentials): Omit azure_connection_string and provide endpoint_url or azure_account_name + References ---------- https://nvidia.github.io/multi-storage-client/user_guide/rust.html - Initialize MSCObjectStorage with AWS credentials and configuration. - Parameters ---------- bucket : str - S3 bucket name. + S3 bucket name or Azure container name. + storage_type : str, optional + Storage provider type, either "s3" or "azure". Default is "s3". region : str, optional AWS region (e.g. 'us-east-1'). access_key_id : str, optional @@ -226,7 +232,7 @@ class MSCObjectStorage(ObjectStorage): session_token : str, optional AWS session token for temporary credentials. endpoint_url : str, optional - Custom endpoint URL for S3-compatible services. 
+ Custom endpoint URL for S3-compatible services or Azure Blob Storage. use_transfer_acceleration : bool, optional Enable S3 Transfer Acceleration (bucket must support it). Default is False. max_concurrency : int, optional @@ -236,18 +242,27 @@ class MSCObjectStorage(ObjectStorage): use_rust_client : bool, optional Use the high-performance Rust client. Default is False. profile_name : str, optional - Name for the MSC profile. Default is 'e2studio-s3'. + Name for the MSC profile. Default is 'e2studio-s3' for S3, 'e2studio-azure' for Azure. cloudfront_domain : str, optional CloudFront distribution domain for signed URLs. cloudfront_key_pair_id : str, optional CloudFront key pair ID for signed URLs. cloudfront_private_key : str, optional PEM private key content as string for signed URLs. + azure_connection_string : str, optional + Azure connection string (optional if using managed identity). + azure_account_name : str, optional + Azure storage account name (required if using managed identity without endpoint_url). + azure_account_key : str, optional + Azure storage account key (optional, for SAS token generation). + azure_container_name : str, optional + Azure container name. 
""" def __init__( self, bucket: str, + storage_type: Literal["s3", "azure"] = "s3", region: str | None = None, access_key_id: str | None = None, secret_access_key: str | None = None, @@ -257,40 +272,100 @@ def __init__( max_concurrency: int = 16, multipart_chunksize: int = 8 * 1024 * 1024, # 8 MB use_rust_client: bool = False, - profile_name: str = "e2studio-s3", + profile_name: str | None = None, cloudfront_domain: str | None = None, cloudfront_key_pair_id: str | None = None, cloudfront_private_key: str | None = None, + # Azure-specific parameters + azure_connection_string: str | None = None, + azure_account_name: str | None = None, + azure_account_key: str | None = None, + azure_container_name: str | None = None, ): + self.storage_type = storage_type self.bucket = bucket - self.region = region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") self.max_concurrency = max_concurrency self.multipart_chunksize = multipart_chunksize self.use_rust_client = use_rust_client - self.profile_name = profile_name - self.use_transfer_acceleration = use_transfer_acceleration - - # Use S3 Transfer Acceleration endpoint if enabled (and no custom endpoint provided) - if use_transfer_acceleration and not endpoint_url: - self.endpoint_url = f"https://{bucket}.s3-accelerate.amazonaws.com" - logger.info(f"S3 Transfer Acceleration enabled: {self.endpoint_url}") + self.profile_name = profile_name or ( + "e2studio-s3" if storage_type == "s3" else "e2studio-azure" + ) + # Initialize endpoint_url as None to allow str | None type + self.endpoint_url: str | None = None + + # S3-specific configuration + if storage_type == "s3": + self.region = region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") + self.use_transfer_acceleration = use_transfer_acceleration + + # Use S3 Transfer Acceleration endpoint if enabled (and no custom endpoint provided) + if use_transfer_acceleration and not endpoint_url: + self.endpoint_url = f"https://{bucket}.s3-accelerate.amazonaws.com" + logger.info(f"S3 
Transfer Acceleration enabled: {self.endpoint_url}") + else: + self.endpoint_url = endpoint_url + + # CloudFront configuration for signed URLs + self.cloudfront_domain = cloudfront_domain + self.cloudfront_key_pair_id = cloudfront_key_pair_id + self.cloudfront_private_key = cloudfront_private_key + + # Set credentials as environment variables - MSC picks these up automatically + if access_key_id: + os.environ["AWS_ACCESS_KEY_ID"] = access_key_id + if secret_access_key: + os.environ["AWS_SECRET_ACCESS_KEY"] = secret_access_key + if session_token: + os.environ["AWS_SESSION_TOKEN"] = session_token + if region: + os.environ["AWS_DEFAULT_REGION"] = region + + # Azure-specific configuration + elif storage_type == "azure": + self.azure_container_name = azure_container_name or bucket + self.azure_account_key = azure_account_key + + # Determine if using managed identity (DefaultAzureCredentials) + # Managed identity is used when connection string is not provided + self.use_managed_identity = azure_connection_string is None + + if self.use_managed_identity: + # Using managed identity - connection string not needed + logger.info( + "Using Azure Managed Identity (DefaultAzureCredentials). " + f"Account: {azure_account_name or 'will be determined from endpoint'}, " + f"Container: {self.azure_container_name}" + ) + self.azure_account_name = azure_account_name + else: + # Using connection string authentication — azure_connection_string is + # guaranteed to be not None here (see use_managed_identity above). + if azure_connection_string is None: # pragma: no cover + raise ObjectStorageError( + "Connection string required for non-managed identity mode" + ) + # Set Azure connection string + os.environ["AZURE_CONNECTION_STRING"] = azure_connection_string + + # Extract account name from connection string if not provided directly + if not azure_account_name: + self.azure_account_name = None + # Connection string format: DefaultEndpointsProtocol=https;AccountName=...;AccountKey=... 
+ for part in azure_connection_string.split(";"): + if part.startswith("AccountName="): + self.azure_account_name = part.split("=", 1)[1] + break + if not self.azure_account_name: + raise ObjectStorageError( + "Could not extract account name from connection string. " + "Please provide azure_account_name directly or ensure connection string contains AccountName." + ) + else: + self.azure_account_name = azure_account_name else: - self.endpoint_url = endpoint_url or "" - - # CloudFront configuration for signed URLs - self.cloudfront_domain = cloudfront_domain - self.cloudfront_key_pair_id = cloudfront_key_pair_id - self.cloudfront_private_key = cloudfront_private_key - - # Set credentials as environment variables - MSC picks these up automatically - if access_key_id: - os.environ["AWS_ACCESS_KEY_ID"] = access_key_id - if secret_access_key: - os.environ["AWS_SECRET_ACCESS_KEY"] = secret_access_key - if session_token: - os.environ["AWS_SESSION_TOKEN"] = session_token - if region: - os.environ["AWS_DEFAULT_REGION"] = region + raise ValueError( + f"Unsupported storage_type: {storage_type}. Must be 's3' or 'azure'." 
+ ) # Import multi-storage-client try: @@ -303,44 +378,124 @@ def __init__( self._msc = msc - # Build the S3 storage provider options - s3_storage_provider_options: dict[str, Any] = { - "base_path": bucket, - "region_name": self.region, - "multipart_threshold": multipart_chunksize, - "multipart_chunksize": multipart_chunksize, - "max_concurrency": max_concurrency, - } - - # Add endpoint URL if provided (for S3-compatible services) - if endpoint_url: - s3_storage_provider_options["endpoint_url"] = endpoint_url - - # Enable Rust client for high-performance I/O - if use_rust_client: - s3_storage_provider_options["rust_client"] = { + # Build storage provider profile config based on storage_type + if storage_type == "s3": + # Build the S3 storage provider options + s3_storage_provider_options: dict[str, Any] = { + "base_path": bucket, + "region_name": self.region, + "multipart_threshold": multipart_chunksize, "multipart_chunksize": multipart_chunksize, "max_concurrency": max_concurrency, } - # Build the S3 profile config - s3_profile_config = { - "profiles": { - profile_name: { - "storage_provider": { - "type": "s3", - "options": s3_storage_provider_options, + # Add endpoint URL if provided (for S3-compatible services) + if self.endpoint_url: + s3_storage_provider_options["endpoint_url"] = self.endpoint_url + + # Enable Rust client for high-performance I/O + if use_rust_client: + s3_storage_provider_options["rust_client"] = { + "multipart_chunksize": multipart_chunksize, + "max_concurrency": max_concurrency, + } + + # Build the S3 profile config + profile_config = { + "profiles": { + self.profile_name: { + "storage_provider": { + "type": "s3", + "options": s3_storage_provider_options, + } } } } - } + elif storage_type == "azure": + # Build the Azure storage provider options + # Derive endpoint URL from endpoint_url parameter, connection string, or account name + azure_endpoint_url = None + + # First, check if endpoint_url was provided directly (for managed identity or 
custom endpoints) + if endpoint_url: + azure_endpoint_url = endpoint_url.rstrip("/") + # Then, try to extract BlobEndpoint directly from connection string + elif azure_connection_string: + for part in azure_connection_string.split(";"): + if part.startswith("BlobEndpoint="): + azure_endpoint_url = part.split("=", 1)[1].rstrip("/") + break + + # If not found, construct from AccountName and EndpointSuffix + if not azure_endpoint_url: + account_name = None + endpoint_suffix = "core.windows.net" # Default suffix + + # Extract from connection string + if azure_connection_string: + for part in azure_connection_string.split(";"): + if part.startswith("AccountName="): + account_name = part.split("=", 1)[1] + elif part.startswith("EndpointSuffix="): + endpoint_suffix = part.split("=", 1)[1] + + # Fall back to provided account_name if not in connection string + if not account_name: + account_name = self.azure_account_name + + if not account_name: + raise ObjectStorageError( + "Azure endpoint_url cannot be determined. " + "Please provide endpoint_url, azure_connection_string (with AccountName or BlobEndpoint), " + "or azure_account_name." 
+ ) + + azure_endpoint_url = f"https://{account_name}.blob.{endpoint_suffix}" + logger.info( + f"Constructed Azure endpoint URL from account name: {azure_endpoint_url}" + ) + + azure_storage_provider_options = { + "base_path": self.azure_container_name, + "endpoint_url": azure_endpoint_url, + } + + # Build the Azure profile config with credentials provider + profile_config = { + "profiles": { + self.profile_name: { + "storage_provider": { + "type": "azure", + "options": azure_storage_provider_options, + } + } + } + } + + # Add Azure credentials provider + if self.use_managed_identity: + # Use DefaultAzureCredentials for managed identity + profile_config["profiles"][self.profile_name][ + "credentials_provider" + ] = { + "type": "DefaultAzureCredentials", + "options": {}, + } + elif azure_connection_string: + # Use AzureCredentials with connection string + profile_config["profiles"][self.profile_name][ + "credentials_provider" + ] = { + "type": "AzureCredentials", + "options": {"connection": "${AZURE_CONNECTION_STRING}"}, + } - # Initialize the S3 StorageClient (target for uploads) - s3_client_config = msc.StorageClientConfig.from_dict( - config_dict=s3_profile_config, - profile=profile_name, + # Initialize the StorageClient (target for uploads) + storage_client_config = msc.StorageClientConfig.from_dict( + config_dict=profile_config, + profile=self.profile_name, ) - self._s3_client = msc.StorageClient(config=s3_client_config) + self._storage_client = msc.StorageClient(config=storage_client_config) # Initialize the local filesystem StorageClient (source for uploads) local_profile_config = { @@ -362,12 +517,18 @@ def __init__( self._local_client = msc.StorageClient(config=local_client_config) rust_status = "enabled" if use_rust_client else "disabled" - accel_status = "enabled" if use_transfer_acceleration else "disabled" - logger.info( - f"MSCObjectStorage initialized: bucket={bucket}, region={self.region}, " - f"max_concurrency={max_concurrency}, 
rust_client={rust_status}, " - f"transfer_acceleration={accel_status}" - ) + if storage_type == "s3": + accel_status = "enabled" if use_transfer_acceleration else "disabled" + logger.info( + f"MSCObjectStorage initialized (S3): bucket={bucket}, region={self.region}, " + f"max_concurrency={max_concurrency}, rust_client={rust_status}, " + f"transfer_acceleration={accel_status}" + ) + else: + logger.info( + f"MSCObjectStorage initialized (Azure): container={self.azure_container_name}, " + f"max_concurrency={max_concurrency}, rust_client={rust_status}" + ) def upload_directory( self, @@ -416,9 +577,14 @@ def upload_directory( total_bytes = sum(f.stat().st_size for f in files) + storage_prefix = ( + f"s3://{self.bucket}" + if self.storage_type == "s3" + else f"azure://{self.azure_container_name if self.storage_type == 'azure' else self.bucket}" + ) logger.info( f"[MSC] Syncing {len(files)} files ({total_bytes / (1024 * 1024):.2f} MB) " - f"from {local_directory} to s3://{self.bucket}/{remote_prefix}" + f"from {local_directory} to {storage_prefix}/{remote_prefix}" ) errors: list[str] = [] @@ -426,7 +592,7 @@ def upload_directory( try: # Use sync_from for efficient parallel directory upload - result = self._s3_client.sync_from( + result = self._storage_client.sync_from( source_client=self._local_client, source_path=source_path, target_path=f"/{remote_prefix}" if remote_prefix else "/", @@ -443,7 +609,15 @@ def upload_directory( elapsed_time = time.time() - start_time success = len(errors) == 0 - destination = f"s3://{self.bucket}/{remote_prefix}" + if self.storage_type == "s3": + destination = f"s3://{self.bucket}/{remote_prefix}" + else: + container = ( + self.azure_container_name + if self.storage_type == "azure" + else self.bucket + ) + destination = f"azure://{container}/{remote_prefix}" result = UploadResult( success=success, @@ -500,7 +674,7 @@ def upload_file( try: remote_key = f"/{remote_key.lstrip('/')}" - self._s3_client.upload_file(remote_key, str(local_path)) 
+ self._storage_client.upload_file(remote_key, str(local_path)) return True except Exception as e: @@ -523,7 +697,7 @@ def file_exists(self, remote_key: str) -> bool: """ try: remote_path = f"/{remote_key.lstrip('/')}" - self._s3_client.info(remote_path) + self._storage_client.info(remote_path) return True except FileNotFoundError: return False @@ -544,7 +718,7 @@ def delete_file(self, remote_key: str) -> bool: """ try: remote_path = f"/{remote_key.lstrip('/')}" - self._s3_client.delete(remote_path) + self._storage_client.delete(remote_path) return True except FileNotFoundError: logger.warning(f"File not found for deletion: {remote_key}") @@ -610,25 +784,37 @@ def _url_safe_b64(data: bytes) -> str: def generate_signed_url(self, remote_key: str, expires_in: int = 86400) -> str: """ - Generate a CloudFront signed URL for accessing a file. + Generate a signed URL for accessing a file. + + For S3, generates a CloudFront signed URL. + For Azure, generates a SAS (Shared Access Signature) token URL. Parameters ---------- remote_key : str - S3 key/path to the file. Can include wildcards. + Storage key/path to the file. Can include wildcards for S3. expires_in : int, optional Number of seconds until the URL expires. Default is 86400. Returns ------- str - Signed CloudFront URL string. + Signed URL string. Raises ------ ObjectStorageError - If CloudFront configuration is missing. + If required configuration is missing. 
""" + if self.storage_type == "s3": + return self._generate_cloudfront_signed_url(remote_key, expires_in) + elif self.storage_type == "azure": + return self._generate_azure_sas_url(remote_key, expires_in) + else: + raise ObjectStorageError(f"Unsupported storage_type: {self.storage_type}") + + def _generate_cloudfront_signed_url(self, remote_key: str, expires_in: int) -> str: + """Generate a CloudFront signed URL for S3.""" if not all( [ self.cloudfront_domain, @@ -678,5 +864,58 @@ def generate_signed_url(self, remote_key: str, expires_in: int = 86400) -> str: f"&Key-Pair-Id={self.cloudfront_key_pair_id}" ) - logger.debug(f"Generated signed URL for {remote_key}, expires in {expires_in}s") + logger.debug( + f"Generated CloudFront signed URL for {remote_key}, expires in {expires_in}s" + ) + return signed_url + + def _generate_azure_sas_url(self, remote_key: str, expires_in: int) -> str: + """Generate an Azure SAS (Shared Access Signature) URL.""" + if not self.azure_account_name or not self.azure_account_key: + raise ObjectStorageError( + "Azure account name and account key are required for signed URLs. " + "Please provide azure_account_name and azure_account_key." + ) + + try: + from azure.storage.blob import ( + ContainerSasPermissions, + generate_container_sas, + ) + except ImportError as e: + raise ImportError( + "azure-storage-blob is required for Azure signed URLs. 
" + "Install with: pip install azure-storage-blob" + ) from e + + container_name = ( + self.azure_container_name if self.storage_type == "azure" else self.bucket + ) + + # Define permissions (Read + List) + permissions = ContainerSasPermissions(read=True, list=True) + + # Set the duration + start_time = datetime.datetime.now(datetime.timezone.utc) + expiry_time = start_time + datetime.timedelta(seconds=expires_in) + + # Generate the SAS token + sas_token = generate_container_sas( + account_name=self.azure_account_name, + account_key=self.azure_account_key, + container_name=container_name, + permission=permissions, + expiry=expiry_time, + start=start_time, + ) + + # Construct the full URL with the prefix + # Remove leading slash from remote_key if present + prefix = remote_key.lstrip("/") + base_url = f"https://{self.azure_account_name}.blob.core.windows.net/{container_name}/{prefix}" + signed_url = f"{base_url}?{sas_token}" + + logger.debug( + f"Generated Azure SAS URL for {remote_key}, expires in {expires_in}s" + ) return signed_url diff --git a/earth2studio/serve/server/utils.py b/earth2studio/serve/server/utils.py index 619cfe7c2..922f9310a 100644 --- a/earth2studio/serve/server/utils.py +++ b/earth2studio/serve/server/utils.py @@ -14,10 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging +from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any, Literal +import aiofiles # type: ignore[import-untyped] import redis # type: ignore[import-untyped] +from fastapi import HTTPException from rq import Queue from earth2studio.serve.server.config import get_config @@ -55,14 +60,151 @@ def get_signed_url_key(request_id: str) -> str: return f"inference_request:{request_id}:signed_url" +# ============================================================================= +# File Streaming Utilities +# ============================================================================= + + +def parse_range_header( + range_header: str | None, file_size: int +) -> tuple[int, int, int, int]: + """ + Parse Range header and return start, end, content_length, and status_code. + + Args: + range_header: The Range header value from the request, or None + file_size: Total size of the file in bytes + + Returns: + Tuple of (start, end, content_length, status_code) + + Raises: + HTTPException: If the range is invalid (416 status) + """ + start = 0 + end = file_size - 1 + status_code = 200 + + if not range_header: + return (start, end, file_size, status_code) + + # Parse Range header: "bytes=start-end" or "bytes=start-" or "bytes=-suffix" + if not range_header.startswith("bytes="): + raise HTTPException( + status_code=416, + detail={ + "error": "Range Not Satisfiable", + "details": "Only byte ranges are supported", + }, + ) + + range_spec = range_header[6:] # Remove "bytes=" prefix + ranges = range_spec.split(",") + + # For now, only handle single range (first range) + if len(ranges) > 1: + logger.warning(f"Multiple ranges requested, using first: {range_header}") + + range_part = ranges[0].strip() + if "-" not in range_part: + raise HTTPException( + status_code=416, + detail={ + "error": "Range Not Satisfiable", + "details": "Invalid range format", + }, + ) + + start_str, end_str = range_part.split("-", 1) + + try: + if 
start_str: + start = int(start_str) + if end_str: + end = int(end_str) + else: + end = file_size - 1 + else: + # Suffix range: "-suffix" means last N bytes + suffix = int(end_str) + start = max(0, file_size - suffix) + end = file_size - 1 + except ValueError: + raise HTTPException( + status_code=416, + detail={ + "error": "Range Not Satisfiable", + "details": f"Invalid range values in: {range_header}", + }, + ) + + # Validate range - clamp end to file_size - 1 per RFC 9110 §14.1.2 + if start < 0 or start >= file_size or end < start: + raise HTTPException( + status_code=416, + headers={ + "Content-Range": f"bytes */{file_size}", + }, + detail={ + "error": "Range Not Satisfiable", + "details": f"Requested range {start}-{end} is invalid for file size {file_size}", + }, + ) + # Clamp end to the last valid byte + end = min(end, file_size - 1) + + status_code = 206 # Partial Content + content_length = end - start + 1 + return (start, end, content_length, status_code) + + +async def create_file_stream( + file_path: Path, start: int, content_length: int, file_description: str = "file" +) -> AsyncGenerator[bytes, None]: + """ + Create an async generator that streams a file with optional range support. 
+ + Args: + file_path: Path to the file to stream + start: Starting byte position (0 for full file, or range start) + content_length: Number of bytes to stream + file_description: Description for error logging + + Yields: + Bytes chunks from the file + """ + try: + chunk_size = 1048576 # 1MB chunks for better performance + async with aiofiles.open(file_path, "rb") as f: + # Seek to start position if range request + if start > 0: + await f.seek(start) + + remaining = content_length + while remaining > 0: + read_size = min(chunk_size, remaining) + chunk = await f.read(read_size) + if not chunk: + break + yield chunk + remaining -= len(chunk) + await asyncio.sleep(0) + except Exception: + logger.exception(f"Error streaming {file_description} {file_path}") + raise + + # ============================================================================= # Pipeline Stage Utilities # ============================================================================= +Stage = Literal["inference", "result_zip", "object_storage", "geocatalog_ingestion"] + + def queue_next_stage( redis_client: redis.Redis, - current_stage: Literal["inference", "result_zip", "object_storage"], + current_stage: Stage, workflow_name: str, execution_id: str, output_path_str: str, @@ -72,12 +214,12 @@ def queue_next_stage( Queue the next pipeline stage based on configuration. 
Pipeline flow: - - If result_zip_enabled: inference -> result_zip -> object_storage (if enabled) -> finalize - - If not result_zip_enabled: inference -> object_storage (if enabled) -> finalize + - If result_zip_enabled: inference -> result_zip -> object_storage (if enabled) -> [geocatalog_ingestion (if AZURE_GEOCATALOG_URL)] -> finalize + - If not result_zip_enabled: inference -> object_storage (if enabled) -> [geocatalog_ingestion (if AZURE_GEOCATALOG_URL)] -> finalize Args: redis_client: Redis client for queue connection - current_stage: The stage that just completed ("inference", "result_zip", "object_storage") + current_stage: The stage that just completed ("inference", "result_zip", "object_storage", "geocatalog_ingestion") workflow_name: Name of the workflow execution_id: Execution ID of the workflow output_path_str: Path to the output files @@ -120,6 +262,18 @@ def queue_next_stage( args = (workflow_name, execution_id) elif current_stage == "object_storage": + if config.object_storage.azure_geocatalog_url: + next_queue = "geocatalog_ingestion" + next_func = ( + "earth2studio.serve.server.cpu_worker.process_geocatalog_ingestion" + ) + args = (workflow_name, execution_id) + else: + next_queue = "finalize_metadata" + next_func = "earth2studio.serve.server.cpu_worker.process_finalize_metadata" + args = (workflow_name, execution_id) + + elif current_stage == "geocatalog_ingestion": next_queue = "finalize_metadata" next_func = "earth2studio.serve.server.cpu_worker.process_finalize_metadata" args = (workflow_name, execution_id) diff --git a/earth2studio/serve/server/workflow.py b/earth2studio/serve/server/workflow.py index e1191bff2..a613f80b7 100644 --- a/earth2studio/serve/server/workflow.py +++ b/earth2studio/serve/server/workflow.py @@ -738,12 +738,79 @@ def get( return instance - def list_workflows(self) -> dict[str, str]: - """List all registered workflows.""" - return { - name: workflow_class.description - for name, workflow_class in self._workflows.items() 
- } + def is_workflow_exposed(self, workflow_name: str) -> bool: + """ + Check if a workflow is exposed via API. + + A workflow is exposed if: + - The exposed_workflows list is empty (all workflows exposed by default), OR + - The workflow name is in the exposed_workflows list, OR + - The workflow name is in the warmup_workflows list (accessible for warmup) + + Parameters + ---------- + workflow_name : str + Name of the workflow to check + + Returns + ------- + bool + True if workflow should be exposed, False otherwise + """ + from earth2studio.serve.server.config import get_config + + config = get_config() + exposed_workflows = config.workflow_exposure.exposed_workflows + warmup_workflows = config.workflow_exposure.warmup_workflows + + # Empty list means all workflows are exposed + if not exposed_workflows: + return True + + # Check if in exposed list or warmup list + return workflow_name in exposed_workflows or workflow_name in warmup_workflows + + def list_workflows(self, exposed_only: bool = True) -> dict[str, str]: + """ + List registered workflows. + + Parameters + ---------- + exposed_only : bool, optional + If True, only return workflows that are in exposed_workflows + (warmup-only workflows are excluded from public listing). + If False, return all registered workflows. 
+ + Returns + ------- + dict + Dictionary mapping workflow names to descriptions + """ + if exposed_only: + from earth2studio.serve.server.config import get_config + + config = get_config() + exposed_workflows = config.workflow_exposure.exposed_workflows + + # Empty list means all workflows are exposed (including warmup) + if not exposed_workflows: + return { + name: workflow_class.description + for name, workflow_class in self._workflows.items() + } + + # Only return workflows in the exposed_workflows list + # (warmup-only workflows are excluded from public listing) + return { + name: workflow_class.description + for name, workflow_class in self._workflows.items() + if name in exposed_workflows + } + else: + return { + name: workflow_class.description + for name, workflow_class in self._workflows.items() + } def discover_and_register_from_directories( self, workflow_dirs: list, include_builtin: bool = True diff --git a/pyproject.toml b/pyproject.toml index cda6434b8..40e8d7db1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,11 @@ serve = [ "python-multipart>=0.0.6", "requests>=2.25.0", "urllib3>=1.26.0", + # Object storage related + "cryptography>=41.0.0", + "multi-storage-client>=0.44.0", + "azure-storage-blob>=12.19.0", + "azure-identity>=1.15.0", ] # PX Models ace2 = [ diff --git a/serve/server/README_object_storage.md b/serve/server/README_object_storage.md index b455137d1..acf41b088 100644 --- a/serve/server/README_object_storage.md +++ b/serve/server/README_object_storage.md @@ -1,18 +1,25 @@ # Object Storage Support -This document describes how to configure and use object storage (AWS S3 with CloudFront) for storing -workflow results in the Earth2Studio Inference Server. +This document describes how to configure and use object storage (AWS S3 with CloudFront or +Azure Blob Storage) for storing workflow results in the Earth2Studio Inference Server. ## Overview By default, workflow results are stored locally on the inference server. 
When object storage is -enabled, results are automatically uploaded to S3 and served via CloudFront signed URLs. This -provides: +enabled, results are automatically uploaded to your chosen cloud storage provider (AWS S3 or +Azure Blob Storage) and served via signed URLs. This provides: - **Scalability**: Offload storage from the inference server -- **Performance**: CloudFront CDN for fast global access +- **Performance**: Fast global access via CDN (CloudFront for S3) or direct Azure Blob Storage access - **Security**: Time-limited signed URLs for secure access -- **Seamless Client Experience**: The Python client SDK automatically handles both storage types +- **Seamless Client Experience**: The Python client SDK automatically handles all storage types + +## Storage Provider Options + +The inference server supports two storage providers: + +- **AWS S3**: With optional CloudFront CDN for enhanced performance +- **Azure Blob Storage**: Direct access with SAS (Shared Access Signature) URLs ## AWS Prerequisites @@ -20,70 +27,54 @@ Before enabling object storage, you need to set up the following AWS resources: ### 1. S3 Bucket -Create an S3 bucket to store workflow results: +Create an S3 bucket to store workflow results. +**Must for performance**: Enable S3 Transfer Acceleration for faster uploads: -```bash -aws s3 mb s3://your-bucket-name --region us-east-1 -``` +### 2. CloudFront Distribution -**Must for performance**: Enable S3 Transfer Acceleration for faster uploads: +Create a CloudFront distribution to serve content from your S3 bucket. -```bash -aws s3api put-bucket-accelerate-configuration \ - --bucket your-bucket-name \ - --accelerate-configuration Status=Enabled -``` +### 3. CloudFront Key Pair for Signed URLs -### 2. CloudFront Distribution +To generate signed URLs, you need a CloudFront key pair. -Create a CloudFront distribution to serve content from your S3 bucket: +### 4. IAM Credentials -1. Go to AWS CloudFront Console → Create Distribution -2. 
Set Origin Domain to your S3 bucket (`your-bucket-name.s3.amazonaws.com`) -3. Set Origin Access to "Origin access control settings (recommended)" -4. Create a new Origin Access Control (OAC) -5. Update the S3 bucket policy to allow CloudFront access (AWS will provide the policy) +Create IAM credentials with permissions to upload to S3. -### 3. CloudFront Key Pair for Signed URLs +## Azure Prerequisites -To generate signed URLs, you need a CloudFront key pair: +Before enabling Azure Blob Storage, you need to set up the following Azure resources: -1. Go to AWS CloudFront Console → Key Management → Public Keys -2. Create a new public key by uploading a public key you generated: +### 1. Azure Storage Account -```bash -# Generate a private key -openssl genrsa -out cloudfront-private-key.pem 2048 +Create an Azure Storage Account. -# Extract the public key -openssl rsa -in cloudfront-private-key.pem -pubout -out cloudfront-public-key.pem -``` +### 2. Storage Container -Then: +Create a blob container in your storage account. -1. Upload `cloudfront-public-key.pem` to CloudFront -2. Create a Key Group containing your public key -3. Associate the Key Group with your CloudFront distribution's behavior settings (Restrict Viewer -Access → Yes, Trusted Key Groups) -4. Note the **Key Pair ID** (e.g., `KUCQGLNFR6UH1`) -5. Keep `cloudfront-private-key.pem` secure - this is used by the server to sign URLs +### 3. Connection String, Account Key -### 4. IAM Credentials +You will need a connection string to write into the container. +The account key is required for generating SAS (Shared Access Signature) signed URLs. + +### 4. Permissions -Create IAM credentials with permissions to upload to S3. See [Creating IAM -users](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_users_create.html) for -detailed instructions. The credentials need `s3:PutObject`, `s3:GetObject`, -`s3:DeleteObject`, and `s3:ListBucket` permissions on your bucket. 
+Ensure your Azure credentials have the permissions to write/read into the container. ## Server Configuration ### Environment Variables -Configure object storage using environment variables: +Configure object storage using environment variables. Choose either AWS S3 or Azure Blob Storage: + +#### AWS S3 Configuration ```bash # Enable object storage export OBJECT_STORAGE_ENABLED=true +export OBJECT_STORAGE_TYPE=s3 # S3 Configuration export OBJECT_STORAGE_BUCKET=your-bucket-name @@ -111,13 +102,46 @@ export CLOUDFRONT_PRIVATE_KEY_PATH=/path/to/cloudfront-private-key.pem export OBJECT_STORAGE_SIGNED_URL_EXPIRES_IN=3600 # URL expiration in seconds ``` +#### Azure Blob Storage Configuration + +```bash +# Enable object storage +export OBJECT_STORAGE_ENABLED=true +export OBJECT_STORAGE_TYPE=azure + +# Azure Configuration +export OBJECT_STORAGE_BUCKET=your-container-name # Container name (used as bucket equivalent) +export OBJECT_STORAGE_PREFIX=outputs # Optional: prefix for uploaded files + +# Azure Credentials (Connection String - Recommended) +export AZURE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=mystorageaccount;AccountKey=...;EndpointSuffix=core.windows.net" + +# OR Azure Credentials (Account Name and Key - Alternative) +export AZURE_STORAGE_ACCOUNT_NAME=mystorageaccount +export AZURE_STORAGE_ACCOUNT_KEY=... 
# Required for SAS URL generation + +# Optional: Container name (defaults to OBJECT_STORAGE_BUCKET if not set) +export AZURE_CONTAINER_NAME=workflow-results + +# Transfer Configuration +export OBJECT_STORAGE_MAX_CONCURRENCY=16 # Concurrent upload threads +export OBJECT_STORAGE_MULTIPART_CHUNKSIZE=8388608 # 8MB chunk size +export OBJECT_STORAGE_USE_RUST_CLIENT=true # High-performance Rust client + +# Signed URL Configuration +export OBJECT_STORAGE_SIGNED_URL_EXPIRES_IN=3600 # SAS URL expiration in seconds +``` + ### YAML Configuration Alternatively, configure via `config.yaml`: +#### AWS S3 YAML Configuration + ```yaml object_storage: enabled: true + storage_type: s3 bucket: your-bucket-name region: us-east-1 prefix: outputs @@ -131,26 +155,49 @@ object_storage: signed_url_expires_in: 3600 ``` +#### Azure Blob Storage YAML Configuration + +```yaml +object_storage: + enabled: true + storage_type: azure + bucket: your-container-name # Container name (used as bucket equivalent) + prefix: outputs + max_concurrency: 16 + multipart_chunksize: 8388608 + use_rust_client: true + azure_connection_string: "DefaultEndpointsProtocol=https;AccountName=mystorageaccount;AccountKey=...;EndpointSuffix=core.windows.net" + azure_account_name: mystorageaccount # Optional if in connection string + azure_account_key: ... 
# Required for SAS URL generation + azure_container_name: workflow-results # Optional, defaults to bucket + signed_url_expires_in: 3600 +``` + ### Configuration Parameters Reference | Parameter | Environment Variable | Default | Description | |-----------|---------------------|---------|-------------| | `enabled` | `OBJECT_STORAGE_ENABLED` | `false` | Enable object storage | -| `bucket` | `OBJECT_STORAGE_BUCKET` | `null` | S3 bucket name | -| `region` | `OBJECT_STORAGE_REGION` | `us-east-1` | AWS region | +| `storage_type` | `OBJECT_STORAGE_TYPE` | `s3` | Storage provider: `s3` or `azure` | +| `bucket` | `OBJECT_STORAGE_BUCKET` | `null` | S3 bucket name or Azure container name | +| `region` | `OBJECT_STORAGE_REGION` | `us-east-1` | AWS region (S3 only) | | `prefix` | `OBJECT_STORAGE_PREFIX` | `outputs` | Remote prefix for files | -| `access_key_id` | `OBJECT_STORAGE_ACCESS_KEY_ID` | `null` | AWS access key ID | -| `secret_access_key` | `OBJECT_STORAGE_SECRET_ACCESS_KEY` | `null` | AWS secret access key | -| `session_token` | `OBJECT_STORAGE_SESSION_TOKEN` | `null` | AWS session token | +| `access_key_id` | `OBJECT_STORAGE_ACCESS_KEY_ID` | `null` | AWS access key ID (S3 only) | +| `secret_access_key` | `OBJECT_STORAGE_SECRET_ACCESS_KEY` | `null` | AWS secret access key (S3 only) | +| `session_token` | `OBJECT_STORAGE_SESSION_TOKEN` | `null` | AWS session token (S3 only) | | `endpoint_url` | `OBJECT_STORAGE_ENDPOINT_URL` | `null` | Custom endpoint (S3-compatible) | -| `use_transfer_acceleration` | `OBJECT_STORAGE_TRANSFER_ACCELERATION` | `true` | Enable S3 Transfer Acceleration | +| `use_transfer_acceleration` | `OBJECT_STORAGE_TRANSFER_ACCELERATION` | `true` | Enable S3 Transfer Acceleration (S3 only) | | `max_concurrency` | `OBJECT_STORAGE_MAX_CONCURRENCY` | `16` | Max concurrent transfers | | `multipart_chunksize` | `OBJECT_STORAGE_MULTIPART_CHUNKSIZE` | `8388608` | Multipart chunk size (bytes) | | `use_rust_client` | `OBJECT_STORAGE_USE_RUST_CLIENT` | `true` | 
Use high-performance Rust client | -| `cloudfront_domain` | `CLOUDFRONT_DOMAIN` | `null` | CloudFront distribution domain | -| `cloudfront_key_pair_id` | `CLOUDFRONT_KEY_PAIR_ID` | `null` | CloudFront key pair ID | -| `cloudfront_private_key_path` | `CLOUDFRONT_PRIVATE_KEY_PATH` | `null` | Path to private key PEM file | +| `cloudfront_domain` | `CLOUDFRONT_DOMAIN` | `null` | CloudFront distribution domain (S3 only) | +| `cloudfront_key_pair_id` | `CLOUDFRONT_KEY_PAIR_ID` | `null` | CloudFront key pair ID (S3 only) | +| `cloudfront_private_key_path` | `CLOUDFRONT_PRIVATE_KEY_PATH` | `null` | Path to private key PEM file (S3 only) | +| `azure_connection_string` | `AZURE_CONNECTION_STRING` | `null` | Azure connection string (Azure only, recommended) | +| `azure_account_name` | `AZURE_STORAGE_ACCOUNT_NAME` | `null` | Azure storage account name (Azure only) | +| `azure_account_key` | `AZURE_STORAGE_ACCOUNT_KEY` | `null` | Azure account key for SAS URLs (Azure only) | +| `azure_container_name` | `AZURE_CONTAINER_NAME` | `null` | Azure container name (Azure only, defaults to bucket) | | `signed_url_expires_in` | `OBJECT_STORAGE_SIGNED_URL_EXPIRES_IN` | `3600` | Signed URL expiration (seconds) | @@ -158,6 +205,8 @@ object_storage: When object storage is enabled, the workflow result metadata includes additional fields: +### AWS S3 Example + ```json { "request_id": "exec_1769560728_10ed9d3c", @@ -174,12 +223,30 @@ When object storage is enabled, the workflow result metadata includes additional } ``` +### Azure Blob Storage Example + +```json +{ + "request_id": "exec_1769560728_10ed9d3c", + "status": "completed", + "storage_type": "azure", + "signed_url": + "https://mystorageaccount.blob.core.windows.net/workflow-results/outputs/exec_1769560728_10ed9d3c?sv=2021-06-08&ss=b&srt=co&sp=rl&se=2025-01-01T00:00:00Z&st=2024-12-31T00:00:00Z&spr=https&sig=...", + "remote_path": "outputs/exec_1769560728_10ed9d3c", + "output_files": [ + {"path": 
"exec_1769560728_10ed9d3c/results.zarr/.zarray", "size": 123}, + {"path": "exec_1769560728_10ed9d3c/results.zarr/t2m/0.0.0", "size": 4567890} + ] +} +``` + ### Storage Type Values | Value | Description | |-------|-------------| | `server` | Results stored locally on the inference server | | `s3` | Results stored in S3, accessible via CloudFront signed URL | +| `azure` | Results stored in Azure Blob Storage, accessible via SAS signed URL | ## Client Usage @@ -209,7 +276,7 @@ request_result = client.run_inference_sync( InferenceRequest(parameters={"start_time": [datetime(2025, 8, 21, 6)]}) ) -# Automatically downloads from S3 if storage_type is "s3" +# Automatically downloads from S3 or Azure if storage_type is "s3" or "azure" for file in request_result.output_files[:5]: content = client.download_result(request_result, file.path) print(f"Downloaded {file.path}: {len(content.getvalue())} bytes") @@ -217,7 +284,10 @@ for file in request_result.output_files[:5]: ### Using Signed URLs Directly -For advanced use cases, you can use the signed URL directly: +For advanced use cases, you can use the signed URL directly. The format differs between +S3/CloudFront and Azure: + +#### Using CloudFront Signed URLs ```python import requests @@ -237,6 +307,22 @@ file_url = f"{base_url}/{file_path}?{query_params}" response = requests.get(file_url) ``` +#### Using Azure SAS URLs + +```python +import requests + +# Get the signed URL from the result +signed_url = request_result.signed_url +# Example: +# https://mystorageaccount.blob.core.windows.net/workflow-results/outputs/exec_123?sv=...&sig=... 
+ +# Azure SAS URLs already include the full path - append the file path +file_path = "results.zarr/.zarray" +file_url = f"{signed_url}/{file_path}" if not signed_url.endswith("/") else f"{signed_url}{file_path}" +response = requests.get(file_url) +``` + ### Using with Xarray and Zarr The client provides an fsspec mapper for opening Zarr stores directly: @@ -252,63 +338,3 @@ mapper = create_cloudfront_mapper(request_result.signed_url, zarr_path="results. ds = xr.open_zarr(mapper, consolidated=True) print(ds) ``` - -## Signed URL Format - -CloudFront signed URLs contain three query parameters: - -```text -https://d30anq61ot046p.cloudfront.net/outputs/exec_123/*?Policy=eyJTdGF0ZW1lbnQiOl...\ -&Signature=ABC123...&Key-Pair-Id=KUCQGLNFR6UH1 -``` - -| Parameter | Description | -|-----------|-------------| -| `Policy` | Base64-encoded JSON policy specifying resource and expiration | -| `Signature` | RSA signature of the policy using the private key | -| `Key-Pair-Id` | CloudFront key pair ID used to verify the signature | - -The wildcard (`*`) in the URL path allows access to all files under that prefix. - -## Testing - -Run object storage integration tests: - -```bash -# Set required environment variables -export TEST_S3_BUCKET=my-test-bucket -export AWS_ACCESS_KEY_ID=AKIA... -export AWS_SECRET_ACCESS_KEY=... - -# Run S3 upload tests -pytest test/integration/test_object_storage.py -v - -# Run CloudFront signed URL tests (requires additional config) -export TEST_CLOUDFRONT_DOMAIN=https://d30anq61ot046p.cloudfront.net -export TEST_CLOUDFRONT_KEY_PAIR_ID=KUCQGLNFR6UH1 -export TEST_CLOUDFRONT_PRIVATE_KEY_PATH=/path/to/private.pem -pytest test/integration/test_object_storage.py::TestCloudFrontSignedUrl -v -``` - -## Troubleshooting - -### Common Issues - -1. 
**403 Forbidden from CloudFront** - - Verify the S3 bucket policy allows CloudFront OAC access - - Check that the CloudFront distribution is configured with the correct origin - - Ensure the key pair is in a Key Group associated with the distribution - -2. **Signed URL expired** - - Increase `signed_url_expires_in` configuration - - Request fresh results from the API (URLs are regenerated) - -3. **Upload failures** - - Verify IAM credentials have `s3:PutObject` permission - - Check bucket name and region are correct - - If using Transfer Acceleration, ensure it's enabled on the bucket - -4. **Slow uploads** - - Enable `use_rust_client` for better performance - - Increase `max_concurrency` for more parallel uploads - - Enable `use_transfer_acceleration` if uploading from distant regions diff --git a/serve/server/conf/config.yaml b/serve/server/conf/config.yaml index f654305da..bab0788dc 100644 --- a/serve/server/conf/config.yaml +++ b/serve/server/conf/config.yaml @@ -36,11 +36,13 @@ worker: num_workers: 1 # The number of RQ inference workers to create by default zip_num_workers: 1 # The number of RQ workers for result_zip queue objstore_num_workers: 1 # The number of RQ workers for object_storage queue + geocatalog_num_workers: 1 # The number of RQ workers for geocatalog_ingestion queue (used when AZURE_GEOCATALOG_URL is set) finalize_num_workers: 1 # The number of RQ workers for finalize_metadata queue paths: default_output_dir: /outputs results_zip_dir: /workspace/earth2studio-project/examples/outputs + output_format: zarr # Output format: "zarr" or "netcdf4" result_zip_enabled: false logging: @@ -88,3 +90,18 @@ object_storage: cloudfront_private_key: null # PEM private key content # Signed URL settings signed_url_expires_in: 86400 # 24 hours + # Azure Blob Storage configuration + azure_connection_string: null + azure_account_name: null + azure_account_key: null + azure_container_name: null + azure_geocatalog_url: null # When set, triggers Planetary Computer 
ingestion after Azure upload + +workflow_exposure: + # List of workflow names to expose via API endpoints + # Empty list means all workflows are exposed + exposed_workflows: [] + # Workflows accessible for warmup even if not in exposed_workflows + # These workflows can be called via API for warmup purposes + warmup_workflows: + - example_user_workflow diff --git a/serve/server/example_workflows/foundry_fcn3.py b/serve/server/example_workflows/foundry_fcn3.py new file mode 100644 index 000000000..52e6af331 --- /dev/null +++ b/serve/server/example_workflows/foundry_fcn3.py @@ -0,0 +1,203 @@ +import logging +from collections.abc import Sequence +from datetime import datetime + +import numpy as np +import torch +import zarr + +from earth2studio.data import PlanetaryComputerECMWFOpenDataIFS, fetch_data +from earth2studio.io import IOBackend, NetCDF4Backend, ZarrBackend +from earth2studio.models.px import FCN3 +from earth2studio.serve.server import ( + Earth2Workflow, + WorkflowProgress, + workflow_registry, +) +from earth2studio.utils.coords import CoordSystem, map_coords, split_coords +from earth2studio.utils.time import to_time_array + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("foundry_fcn3_workflow") + + +@workflow_registry.register +class FoundryFCN3Workflow(Earth2Workflow): + name = "foundry_fcn3_workflow" + description = "FCN3 ensemble workflow for Foundry" + + def __init__( + self, + device: str = "cuda", + init_seed: int | None = None, + ): + super().__init__() + + self.device = torch.device(device) + + self.fcn3 = self.load_fcn3() + self.rng = np.random.default_rng(init_seed) + + self.data = PlanetaryComputerECMWFOpenDataIFS(verbose=False, cache=False) + + def load_fcn3(self) -> FCN3: + logger.info("Loading FCN3") + package = FCN3.load_default_package() + fcn3 = FCN3.load_model(package) + fcn3.to(self.device) + fcn3.eval() + return fcn3 + + def get_seeds(self, n_seeds: int) -> list[int]: + seeds = self.rng.choice(2**32, size=n_seeds, 
replace=False) + return [int(s) for s in seeds] + + def validate_start_time(self, start_time: datetime) -> None: + if (start_time - datetime(1900, 1, 1)).total_seconds() % 21600 != 0: + raise ValueError(f"Start time needs to be 6-hour interval: {start_time}") + + def validate_samples( + self, n_samples: int, seeds: Sequence[int] | None + ) -> list[int]: + if seeds is None: + seeds = self.get_seeds(n_samples) + elif len(seeds) != n_samples: + logger.warning( + "Ignoring requested number of samples because it does not match number of seeds" + ) + return list(seeds) + + def validate_variables(self, variables: Sequence[str] | None) -> np.ndarray: + if variables is None: + variables = self.fcn3.variables + else: + unknown_variables = set(variables) - set(self.fcn3.variables) + if len(unknown_variables): + raise ValueError(f"Unknown variable(s) {', '.join(unknown_variables)}") + variables = np.array(variables) + return variables + + def setup_io( + self, io: IOBackend, output_coords: CoordSystem, seeds: Sequence[int] + ) -> None: + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, + output_coords["variable"], + ) + + # Storing seeds separately makes it easier to filter with Titiler + e_coords = {"ensemble": output_coords["ensemble"]} + io.add_array(e_coords, "seed", data=torch.tensor(seeds)) + + # Add CRS definition + io.add_array({}, "crs") + io.root["crs"].grid_mapping_name = "latitude_longitude" + io.root["crs"].longitude_of_prime_meridian = 0.0 + io.root["crs"].semi_major_axis = 6378137.0 + io.root["crs"].inverse_flattening = 298.257223563 + + for var in output_coords["variable"]: + io.root[var].grid_mapping = "crs" + + # Set attributes for automatic parsing of dimensions + io.root["ensemble"].standard_name = "realization" + io.root["time"].standard_name = "time" + io.root["time"].axis = "T" + io.root["lat"].standard_name = "latitude" + io.root["lat"].units = "degrees_north" + io.root["lat"].axis = "Y" + io.root["lon"].standard_name = 
"longitude" + io.root["lon"].units = "degrees_east" + io.root["lon"].axis = "X" + + if isinstance(io, ZarrBackend): + zarr.consolidate_metadata(io.store) + + if isinstance(io, NetCDF4Backend): + # Planetary Computer does not like the original time format + ref_time = np.datetime_as_string(output_coords["time"][0], unit="s") + io["time"].units = f"hours since {ref_time.replace('T', ' ')}" + io["time"][:] = np.arange(len(io["time"])) * 6 + + return io + + def get_fcn3_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + x, coords = fetch_data( + self.data, + time=to_time_array([time]), + variable=self.fcn3.input_coords()["variable"], + device=self.device, + ) + return x, coords + + def __call__( + self, + io: IOBackend, + start_time: datetime = datetime(2025, 1, 1), + n_steps: int = 20, + n_samples: int = 16, + seeds: Sequence[int] | None = None, + variables: Sequence[str] | None = None, + collection_id: str | None = None, + ) -> None: + self.validate_start_time(start_time) + lead_times = np.array([np.timedelta64(i * 6, "h") for i in range(n_steps + 1)]) + seeds = self.validate_samples(n_samples, seeds) + variables = self.validate_variables(variables) + + x_ori, coords_ori = self.get_fcn3_input(start_time) + + output_coords = CoordSystem( + { + "ensemble": np.arange(len(seeds)), + # Combine 'time' and 'lead_time' into single dimension + "time": to_time_array([start_time]) + lead_times, + "variable": variables, + "lat": np.linspace(90.0, -90.0, 721), + "lon": np.linspace(-180, 180, 1440, endpoint=False), + } + ) + self.setup_io(io, output_coords, seeds) + + logger.info("Starting inference") + total_samples = len(seeds) + n_steps += 1 # add 1 for step 0 (initial conditions) + for sample, seed in enumerate(seeds): + + self.fcn3.set_rng(seed=seed) + iterator = self.fcn3.create_iterator(x_ori.clone(), coords_ori.copy()) + for step, (x, coords) in enumerate(iterator): + # Update progress for step within sample + msg = ( + f"Processing sample {sample + 
1}/{total_samples} " + f"(seed={seed}), step {step + 1}/{len(lead_times)}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + # Select variables + x_out, coords_out = map_coords( + x, coords, CoordSystem({"variable": output_coords["variable"]}) + ) + # Roll longitudes (for raster visualization) + x_out = torch.roll(x_out, 720, dims=-1) + coords_out["lon"] = np.linspace(-180, 180, 1440, endpoint=False) + # Add ensemble dimension + x_out = x_out.unsqueeze(0) + coords_out["ensemble"] = np.array([sample]) + coords_out.move_to_end("ensemble", last=False) + # Combine time and lead_time + lead_time_dim = list(coords_out).index("lead_time") + x_out = x_out.squeeze(lead_time_dim) + coords_out["time"] = coords_out["time"] + coords_out["lead_time"] + del coords_out["lead_time"] + # Write to disk + io.write(*split_coords(x_out, coords_out)) + + if step == (n_steps - 1): + break diff --git a/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py b/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py new file mode 100644 index 000000000..793ad016d --- /dev/null +++ b/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py @@ -0,0 +1,456 @@ +import logging +from collections import OrderedDict +from collections.abc import Sequence +from datetime import datetime, timedelta + +import numpy as np +import torch +import xarray as xr +import zarr + +from earth2studio.data import ( + GOES, + InferenceOutputSource, + PlanetaryComputerECMWFOpenDataIFS, + PlanetaryComputerGOES, + fetch_data, +) +from earth2studio.io import IOBackend, NetCDF4Backend, XarrayBackend, ZarrBackend +from earth2studio.models.dx import DerivedSurfacePressure +from earth2studio.models.px import FCN3, DiagnosticWrapper, InterpModAFNO +from earth2studio.models.px.stormscope import ( + StormScopeBase, + StormScopeGOES, +) +from earth2studio.serve.server import ( + Earth2Workflow, + 
WorkflowProgress, + workflow_registry, +) +from earth2studio.utils.coords import CoordSystem, map_coords, split_coords +from earth2studio.utils.time import to_time_array + +logger = logging.getLogger("foundry_fcn3_stormscope_goes_workflow") +logger.setLevel(logging.INFO) + +GOES_MODEL_NAME = "6km_60min_natten_cos_zenith_input_eoe_v2" + + +@workflow_registry.register +class FoundryFCN3StormScopeGOESWorkflow(Earth2Workflow): + name = "foundry_fcn3_stormscope_goes_workflow" + description = "FCN3+StormScopeGOES ensemble workflow for Foundry" + + def __init__( + self, + device: str = "cuda", + init_seed: int = 1234, + ): + super().__init__() + + self.device = torch.device(device) + + self.fcn3_interp = self.load_fcn3_interp() + self.stormscope = self.load_stormscope() + self.rng = np.random.default_rng(init_seed) + + self.data_fcn3 = PlanetaryComputerECMWFOpenDataIFS(verbose=False, cache=False) + + scan_mode = "C" + self.data_stormscope = { + satellite: PlanetaryComputerGOES( + satellite=satellite, scan_mode=scan_mode, verbose=False, cache=False + ) + for satellite in ["goes16", "goes19"] + } + + # GOES-16 and GOES19 have the same grid + goes_lat, goes_lon = GOES.grid(satellite="goes16", scan_mode=scan_mode) + coords_out = self.fcn3_interp.output_coords(self.fcn3_interp.input_coords()) + self.stormscope.build_input_interpolator(goes_lat, goes_lon) + self.stormscope.build_conditioning_interpolator( + coords_out["lat"], coords_out["lon"] + ) + + def load_fcn3_interp(self) -> InterpModAFNO: + logger.info("Loading FCN3") + package = FCN3.load_default_package() + fcn3 = FCN3.load_model(package) + + # Surface pressure interpolation + orography_fn = package.resolve("orography.nc") + with xr.open_dataset(orography_fn) as ds: + z_surface = torch.as_tensor(ds["Z"][0].values) + z_surf_coords = OrderedDict({d: fcn3.input_coords()[d] for d in ["lat", "lon"]}) + sp_model = DerivedSurfacePressure( + p_levels=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000], + 
surface_geopotential=z_surface, + surface_geopotential_coords=z_surf_coords, + ) + + # Bundle surface pressure with FCN3 + fcn3_sp = DiagnosticWrapper(px_model=fcn3, dx_model=sp_model) + + # Add temporal interpolation to 1 hour + fcn3_interp = InterpModAFNO.from_pretrained() + fcn3_interp.px_model = fcn3_sp + fcn3_interp.to(device=self.device) + fcn3_interp.eval() + return fcn3_interp + + def load_stormscope(self) -> StormScopeGOES: + logger.info("Loading StormScope") + package = StormScopeBase.load_default_package() + stormscope = StormScopeGOES.load_model( + package=package, + conditioning_data_source=None, # set later + model_name=GOES_MODEL_NAME, + ) + stormscope.to(self.device) + stormscope.eval() + return stormscope + + def get_seeds(self, n_seeds: int) -> list[int]: + seeds = self.rng.choice(2**32, size=n_seeds, replace=False) + return [int(s) for s in seeds] + + def validate_start_times( + self, time_stormscope: datetime, time_fcn3: datetime + ) -> None: + ref = datetime(1900, 1, 1) + if (time_stormscope - ref).total_seconds() % (1 * 60 * 60) != 0: + raise ValueError( + f"Start time for StormScope must be 1-hour interval: {time_stormscope}" + ) + if (time_fcn3 - ref).total_seconds() % (6 * 60 * 60) != 0: + raise ValueError( + f"Start time for FCN3 must be 6-hour interval: {time_fcn3}" + ) + if time_stormscope < time_fcn3: + raise ValueError( + "Start time for StormScope cannot precede start time for FCN3" + ) + if time_stormscope - time_fcn3 > timedelta(hours=12): + logger.warning( + "Start times for StormScope and FCN3 should not be more than 12 hours apart but got '%s' and '%s'", + time_stormscope, + time_fcn3, + ) + + def validate_samples( + self, n_samples: int, seeds: Sequence[int] | None + ) -> list[int]: + if not seeds: + return self.get_seeds(n_samples) + if len(seeds) != n_samples: + logger.warning( + "Ignoring requested number of samples because it does not match number of seeds" + ) + return list(seeds) + + def validate_variables(self, variables: 
Sequence[str] | None) -> np.ndarray: + if variables is None: + variables = self.stormscope.variables + else: + unknown_variables = set(variables) - set(self.stormscope.variables) + if len(unknown_variables): + raise ValueError(f"Unknown variable(s) {', '.join(unknown_variables)}") + variables = np.array(variables) + return variables + + def setup_io( + self, + io: IOBackend, + output_coords: CoordSystem, + seeds_fcn3: Sequence[int], + seeds_stormscope: Sequence[int], + ) -> None: + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, + output_coords["variable"], + ) + + # Storing seeds separately makes it easier to filter with Titiler + e_coords = {"ensemble": output_coords["ensemble"]} + n_stormscope_per_fcn3 = len(seeds_stormscope) // len(seeds_fcn3) + tiled_seeds_fcn3 = np.repeat(seeds_fcn3, n_stormscope_per_fcn3) + io.add_array(e_coords, "seed_fcn3", data=torch.tensor(tiled_seeds_fcn3)) + io.add_array(e_coords, "seed_stormscope", data=torch.tensor(seeds_stormscope)) + + # Add CRS definition + io.add_array({}, "crs") + io.root["crs"].grid_mapping_name = "lambert_conformal_conic" + io.root["crs"].standard_parallel = 38.5 + io.root["crs"].longitude_of_central_meridian = 262.5 + io.root["crs"].latitude_of_projection_origin = 38.5 + io.root["crs"].semi_major_axis = 6371229 + io.root["crs"].semi_minor_axis = 6371229 + + for var in output_coords["variable"]: + io.root[var].grid_mapping = "crs" + + # Set attributes for automatic parsing of dimensions + io.root["ensemble"].standard_name = "realization" + io.root["time"].standard_name = "time" + io.root["time"].axis = "T" + io.root["y"].standard_name = "projection_y_coordinate" + io.root["y"].units = "m" + io.root["y"].axis = "Y" + io.root["x"].standard_name = "projection_x_coordinate" + io.root["x"].units = "m" + io.root["x"].axis = "X" + + if isinstance(io, ZarrBackend): + zarr.consolidate_metadata(io.store) + + if isinstance(io, NetCDF4Backend): + # Planetary Computer does not like the original 
time format + ref_time = np.datetime_as_string(output_coords["time"][0], unit="s") + io["time"].units = f"hours since {ref_time.replace('T', ' ')}" + io["time"][:] = np.arange(len(io["time"])) + + return io + + def get_fcn3_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + x, coords = fetch_data( + self.data_fcn3, + time=to_time_array([time]), + variable=self.fcn3_interp.input_coords()["variable"], + device=self.device, + ) + return x, coords + + def get_stormscope_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + coords_in = self.stormscope.input_coords() + if time < datetime(2025, 4, 7): + data = self.data_stormscope["goes16"] + else: + data = self.data_stormscope["goes19"] + x, coords = fetch_data( + data, + time=to_time_array([time]), + variable=coords_in["variable"], + lead_time=coords_in["lead_time"], + device=self.device, + ) + + batch_size = 1 + if x.dim() == 5: + x = x.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1, 1) + coords["batch"] = np.arange(batch_size) + coords.move_to_end("batch", last=False) + + x, coords = self.stormscope.prep_input(x, coords) + x = torch.where(self.stormscope.valid_mask, x, torch.nan) + + return x, coords + + def run_fcn3( + self, + io: IOBackend, + x: torch.Tensor, + coords_x: CoordSystem, + seed_fcn3: int, + start_time_stormscope: datetime, + lead_times: np.ndarray, + sample: int, + total_samples: int, + ) -> None: + # Create z500 conditioning with FCN3 + coords_in = self.stormscope.input_coords() + start_time_stormscope = to_time_array([start_time_stormscope]) + variables = self.stormscope.conditioning_variables + # Start time and lead times are shifted to StormScope start time + output_coords = { + "time": start_time_stormscope, + "lead_time": lead_times, + "variable": variables, + "y": coords_in["y"], + "x": coords_in["x"], + } + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, variables + ) + + model_gap = int( + (start_time_stormscope - coords_x["time"]) / 
np.timedelta64(1, "h") + ) + + self.fcn3_interp.px_model.px_model.set_rng(seed=seed_fcn3) + iterator = self.fcn3_interp.create_iterator(x.clone(), coords_x.copy()) + + n_steps = model_gap + len(lead_times) + for step, (x, coords_x) in enumerate(iterator): + # Update progress for FCN3 step + msg = ( + f"Processing FCN3 for sample {sample + 1}/{total_samples} " + f"(seed_fcn3={seed_fcn3}) " + f"step {step + 1}/{n_steps}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + if step < model_gap: + # Skip initial steps leading up to StormScope start time + continue + + x, coords_x = map_coords(x, coords_x, OrderedDict({"variable": variables})) + x, coords_x = self.stormscope.prep_input(x, coords_x, conditioning=True) + coords_x["time"] = start_time_stormscope + coords_x["lead_time"] = coords_x["lead_time"] - np.timedelta64( + model_gap, "h" + ) + io.write(*split_coords(x, coords_x)) + + if step == (n_steps - 1): + break + + def run_stormscope( + self, + io: IOBackend, + y: torch.Tensor, + coords_y: CoordSystem, + seed_fcn3: int, + seed_stormscope: int, + lead_times: np.ndarray, + variables: np.ndarray, + sample: int, + total_samples: int, + ) -> None: + n_steps = len(lead_times) + + def log_progress(step: int) -> None: + msg = ( + f"Processing sample {sample + 1}/{total_samples} " + f"(seed_fcn3={seed_fcn3}, seed_stormscope={seed_stormscope}), " + f"step {step + 1}/{n_steps}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + def prep_output( + y_pred: torch.Tensor, coords_pred: CoordSystem + ) -> tuple[torch.Tensor, CoordSystem]: + y_out, coords_out = map_coords( + y_pred, coords_pred, CoordSystem({"variable": variables}) + ) + del coords_out["batch"] + # Reuse batch dimension as ensemble dimension (squeeze/unsqueeze) + coords_out["ensemble"] = 
np.array([sample]) + coords_out.move_to_end("ensemble", last=False) + # Combine time and lead_time + lead_time_dim = list(coords_out).index("lead_time") + y_out = y_out.squeeze(lead_time_dim) + coords_out["time"] = coords_out["time"] + coords_out["lead_time"] + del coords_out["lead_time"] + return y_out, coords_out + + # Update progress for step within sample + log_progress(0) + + # Store initial GOES data (identical across seeds) + y_out, coords_out = prep_output(y, coords_y) + io.write(*split_coords(y_out, coords_out)) + + # Cannot use seeded Generator before torch==2.10 + # Use self.stormscope.sampler_args["randn_like"] once updated + torch.manual_seed(seed_stormscope) + + for step in range(1, n_steps): + y_pred, coords_pred = self.stormscope(y, coords_y) + + # Update progress for step within sample + log_progress(step) + + y_out, coords_out = prep_output(y_pred, coords_pred) + io.write(*split_coords(y_out, coords_out)) + + if step == (n_steps - 1): + break + + y, coords_y = self.stormscope.next_input(y_pred, coords_pred, y, coords_y) + + def __call__( + self, + io: IOBackend, + start_time_fcn3: datetime = datetime(2025, 1, 1, 18), + start_time_stormscope: datetime = datetime(2025, 1, 1, 18), + n_steps: int = 12, + n_samples_fcn3: int = 4, + n_samples_stormscope: int = 16, + seeds_fcn3: Sequence[int] | None = None, + seeds_stormscope: Sequence[int] | None = None, + variables: Sequence[str] | None = None, + collection_id: str | None = None, + ) -> None: + self.validate_start_times(start_time_stormscope, start_time_fcn3) + lead_times = np.array([np.timedelta64(i, "h") for i in range(n_steps + 1)]) + # Different StormScope seed for every trajectory + if n_samples_stormscope % n_samples_fcn3 != 0: + raise ValueError( + "'n_samples_stormscope' must be divisible by 'n_samples_fcn3'" + ) + seeds_fcn3 = self.validate_samples(n_samples_fcn3, seeds_fcn3) + seeds_stormscope = self.validate_samples(n_samples_stormscope, seeds_stormscope) + n_stormscope_per_fcn3 = 
len(seeds_stormscope) // len(seeds_fcn3) + variables = self.validate_variables(variables) + + x_ori, coords_x_ori = self.get_fcn3_input(start_time_fcn3) + y_ori, coords_y_ori = self.get_stormscope_input(start_time_stormscope) + + coords_out = self.stormscope.output_coords(self.stormscope.input_coords()) + output_coords = { + "ensemble": np.arange(len(seeds_stormscope)), + # Planetary Computer does not like separate 'lead_time' + "time": to_time_array([start_time_stormscope]) + lead_times, + "variable": variables, + "y": coords_out["y"], + "x": coords_out["x"], + } + self.setup_io(io, output_coords, seeds_fcn3, seeds_stormscope) + + total_samples = len(seeds_stormscope) + sample = 0 + for seed_fcn3 in seeds_fcn3: + # Generate FCN3 conditioning (z500) + logger.info("Starting FCN3 inference") + io_fcn3 = XarrayBackend() + self.run_fcn3( + io=io_fcn3, + x=x_ori.clone(), + coords_x=coords_x_ori.copy(), + seed_fcn3=seed_fcn3, + start_time_stormscope=start_time_stormscope, + lead_times=lead_times, + sample=sample, + total_samples=total_samples, + ) + self.stormscope.conditioning_data_source = InferenceOutputSource( + io_fcn3.root + ) + + # Run StormScope forecast conditioned on FCN3 + logger.info("Starting StormScope inference") + for _ in range(n_stormscope_per_fcn3): + self.run_stormscope( + io=io, + y=y_ori.clone(), + coords_y=coords_y_ori.copy(), + seed_fcn3=seed_fcn3, + seed_stormscope=seeds_stormscope[sample], + lead_times=lead_times, + variables=variables, + sample=sample, + total_samples=total_samples, + ) + sample += 1 diff --git a/serve/server/planetary_computer/parameters-fcn3-stormscope-goes.json b/serve/server/planetary_computer/parameters-fcn3-stormscope-goes.json new file mode 100644 index 000000000..041a37947 --- /dev/null +++ b/serve/server/planetary_computer/parameters-fcn3-stormscope-goes.json @@ -0,0 +1,4 @@ +{ + "step_size_hours": 1, + "start_time_parameter_key": "start_time_stormscope" +} diff --git 
a/serve/server/planetary_computer/parameters-fcn3.json b/serve/server/planetary_computer/parameters-fcn3.json new file mode 100644 index 000000000..00d7cb222 --- /dev/null +++ b/serve/server/planetary_computer/parameters-fcn3.json @@ -0,0 +1,4 @@ +{ + "step_size_hours": 6, + "start_time_parameter_key": "start_time" +} diff --git a/serve/server/planetary_computer/render-options-fcn3-stormscope-goes.json b/serve/server/planetary_computer/render-options-fcn3-stormscope-goes.json new file mode 100644 index 000000000..550ecbad5 --- /dev/null +++ b/serve/server/planetary_computer/render-options-fcn3-stormscope-goes.json @@ -0,0 +1,10 @@ +[ + {"id": "abi01c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi02c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi03c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi07c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi08c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi09c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi10c", "scale": [0, 1], "cmap": "plasma"}, + {"id": "abi13c", "scale": [0, 1], "cmap": "plasma"} +] diff --git a/serve/server/planetary_computer/render-options-fcn3.json b/serve/server/planetary_computer/render-options-fcn3.json new file mode 100644 index 000000000..0e2dee421 --- /dev/null +++ b/serve/server/planetary_computer/render-options-fcn3.json @@ -0,0 +1,7 @@ +[ + {"id": "t2m", "scale": [263, 313], "cmap": "balance"}, + {"id": "t850", "scale": [263, 313], "cmap": "balance"}, + {"id": "u10m", "scale": [-20, 20], "cmap": "prgn"}, + {"id": "v10m", "scale": [-20, 20], "cmap": "prgn"}, + {"id": "z500", "scale": [45000, 60000], "cmap": "viridis"} +] diff --git a/serve/server/planetary_computer/template-collection-fcn3-stormscope-goes.json b/serve/server/planetary_computer/template-collection-fcn3-stormscope-goes.json new file mode 100644 index 000000000..62d5205b3 --- /dev/null +++ b/serve/server/planetary_computer/template-collection-fcn3-stormscope-goes.json @@ -0,0 +1,42 @@ +{ + "type": 
"Collection", + "stac_version": "1.0.0", + "id": "earth-2-fcn3-stormscope-goes-{uuid}", + "title": "Earth-2 StormScope-GOES conditioned on FourCastNet3", + "description": "Forecasts generated with the FourCastNet3-conditioned StormScope-GOES workflow on Microsoft Foundry.", + "links": [], + "stac_extensions": [], + "item_assets": { + "data": { + "type": "application/x-netcdf", + "roles": [] + } + }, + "extent": { + "spatial": { + "bbox": [ + [ + -135, + 20, + -60, + 53 + ] + ] + }, + "temporal": { + "interval": [ + [ + null, + null + ] + ] + } + }, + "license": "other", + "keywords": [ + "CONUS", + "Forecast", + "Earth-2" + ], + "providers": [] +} diff --git a/serve/server/planetary_computer/template-collection-fcn3.json b/serve/server/planetary_computer/template-collection-fcn3.json new file mode 100644 index 000000000..21fc63b56 --- /dev/null +++ b/serve/server/planetary_computer/template-collection-fcn3.json @@ -0,0 +1,42 @@ +{ + "type": "Collection", + "stac_version": "1.0.0", + "id": "earth-2-fcn3-{uuid}", + "title": "Earth-2 FourCastNet3", + "description": "Forecasts generated with the FourCastNet3 workflow on Microsoft Foundry.", + "links": [], + "stac_extensions": [], + "item_assets": { + "data": { + "type": "application/x-netcdf", + "roles": [] + } + }, + "extent": { + "spatial": { + "bbox": [ + [ + -180, + -90, + 180, + 90 + ] + ] + }, + "temporal": { + "interval": [ + [ + null, + null + ] + ] + } + }, + "license": "other", + "keywords": [ + "Global", + "Forecast", + "Earth-2" + ], + "providers": [] +} diff --git a/serve/server/planetary_computer/template-feature-fcn3-stormscope-goes.json b/serve/server/planetary_computer/template-feature-fcn3-stormscope-goes.json new file mode 100644 index 000000000..3a8d389f9 --- /dev/null +++ b/serve/server/planetary_computer/template-feature-fcn3-stormscope-goes.json @@ -0,0 +1,54 @@ +{ + "type": "Feature", + "stac_version": "1.1.0", + "id": "fcn3-stormscope-goes-{start_time}-{uuid}", + "stac_extensions": [], + 
"geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -60, + 20 + ], + [ + -60, + 53 + ], + [ + -135, + 53 + ], + [ + -135, + 20 + ], + [ + -60, + 20 + ] + ] + ] + }, + "bbox": [ + -135, + 20, + -60, + 53 + ], + "properties": { + "datetime": "{start_time}", + "start_datetime": "{start_time}", + "end_datetime": "{end_time}" + }, + "links": [], + "assets": { + "data": { + "href": "{blob_url}", + "type": "application/x-netcdf", + "title": "StormScope-GOES forecast conditioned on FourCastNet3", + "description": "FourCastNet3-conditioned StormScope-GOES forecast from {start_time} to {end_time}.", + "roles": [] + } + } +} diff --git a/serve/server/planetary_computer/template-feature-fcn3.json b/serve/server/planetary_computer/template-feature-fcn3.json new file mode 100644 index 000000000..38fb9a801 --- /dev/null +++ b/serve/server/planetary_computer/template-feature-fcn3.json @@ -0,0 +1,54 @@ +{ + "type": "Feature", + "stac_version": "1.1.0", + "id": "fcn3-{start_time}-{uuid}", + "stac_extensions": [], + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 180, + -90 + ], + [ + 180, + 90 + ], + [ + -180, + 90 + ], + [ + -180, + -90 + ], + [ + 180, + -90 + ] + ] + ] + }, + "bbox": [ + -180, + -90, + 180, + 90 + ], + "properties": { + "datetime": "{start_time}", + "start_datetime": "{start_time}", + "end_datetime": "{end_time}" + }, + "links": [], + "assets": { + "data": { + "href": "{blob_url}", + "type": "application/x-netcdf", + "title": "FourCastNet3 forecast", + "description": "FourCastNet3 forecast from {start_time} to {end_time}.", + "roles": [] + } + } +} diff --git a/serve/server/planetary_computer/tile-settings-fcn3-stormscope-goes.json b/serve/server/planetary_computer/tile-settings-fcn3-stormscope-goes.json new file mode 100644 index 000000000..b68afa14f --- /dev/null +++ b/serve/server/planetary_computer/tile-settings-fcn3-stormscope-goes.json @@ -0,0 +1,4 @@ +{ + "minZoom": 0, + "maxItemsPerTile": 35 +} diff --git 
a/serve/server/planetary_computer/tile-settings-fcn3.json b/serve/server/planetary_computer/tile-settings-fcn3.json new file mode 100644 index 000000000..b68afa14f --- /dev/null +++ b/serve/server/planetary_computer/tile-settings-fcn3.json @@ -0,0 +1,4 @@ +{ + "minZoom": 0, + "maxItemsPerTile": 35 +} diff --git a/serve/server/requirements.txt b/serve/server/requirements.txt index 9003236f8..69b3a0404 100644 --- a/serve/server/requirements.txt +++ b/serve/server/requirements.txt @@ -32,6 +32,13 @@ hydra-core>=1.3.0 # Cryptography for CloudFront signed URLs cryptography>=41.0.0 +# Multi-Storage Client with Azure support +multi-storage-client>=0.44.0 +# Azure Blob Storage for SAS token generation +azure-storage-blob>=12.19.0 +# Azure Identity for managed identity authentication +azure-identity>=1.15.0 + # Development dependencies pytest>=7.0.0 pytest-asyncio>=0.21.0 diff --git a/serve/server/scripts/start_api_server.sh b/serve/server/scripts/start_api_server.sh index 7bcf20fcf..67533721a 100755 --- a/serve/server/scripts/start_api_server.sh +++ b/serve/server/scripts/start_api_server.sh @@ -56,6 +56,7 @@ if [ -f "$CONFIG_FILE" ]; then CONFIG_RQ_NUM_WORKERS=$(read_config "worker.num_workers") CONFIG_ZIP_NUM_WORKERS=$(read_config "worker.zip_num_workers") CONFIG_OBJSTORE_NUM_WORKERS=$(read_config "worker.objstore_num_workers") + CONFIG_GEOCATALOG_NUM_WORKERS=$(read_config "worker.geocatalog_num_workers") CONFIG_FINALIZE_NUM_WORKERS=$(read_config "worker.finalize_num_workers") CONFIG_PERSISTENT_WORKER=$(read_config "worker.persistent") fi @@ -66,10 +67,11 @@ REDIS_HOST=${3:-${CONFIG_REDIS_HOST:-localhost}} # Default Redis host NUM_RQ_WORKERS=${4:-${CONFIG_RQ_NUM_WORKERS:-1}} # Default to 1 RQ workers NUM_ZIP_WORKERS=${5:-${CONFIG_ZIP_NUM_WORKERS:-1}} # Default to 1 workers for result_zip queue NUM_OBJSTORE_WORKERS=${CONFIG_OBJSTORE_NUM_WORKERS:-1} # Default to 1 object storage worker +NUM_GEOCATALOG_WORKERS=${CONFIG_GEOCATALOG_NUM_WORKERS:-1} # Default to 1 geocatalog 
ingestion worker NUM_FINALIZE_WORKERS=${CONFIG_FINALIZE_NUM_WORKERS:-1} # Default to 1 finalize metadata worker PERSISTENT_WORKER=${CONFIG_PERSISTENT_WORKER:-false} -echo "Starting Earth2Studio with $NUM_WORKERS API workers, $NUM_RQ_WORKERS RQ workers, $NUM_ZIP_WORKERS zip workers, $NUM_OBJSTORE_WORKERS object storage workers, and $NUM_FINALIZE_WORKERS finalize workers on port $API_PORT..." +echo "Starting Earth2Studio with $NUM_WORKERS API workers, $NUM_RQ_WORKERS RQ workers, $NUM_ZIP_WORKERS zip workers, $NUM_OBJSTORE_WORKERS object storage workers, $NUM_GEOCATALOG_WORKERS geocatalog workers, and $NUM_FINALIZE_WORKERS finalize workers on port $API_PORT..." echo "Configuration: Redis=$REDIS_HOST, Persistent Worker=$PERSISTENT_WORKER" # Function to cleanup on exit @@ -106,6 +108,12 @@ cleanup() { pkill -f "rq.*worker.*object_storage" fi + # Stop all geocatalog ingestion workers + if pgrep -f "rq.*worker.*geocatalog_ingestion" > /dev/null; then + echo "Stopping geocatalog ingestion workers..." + pkill -f "rq.*worker.*geocatalog_ingestion" + fi + # Stop all finalize metadata workers if pgrep -f "rq.*worker.*finalize_metadata" > /dev/null; then echo "Stopping finalize metadata workers..." @@ -122,7 +130,7 @@ trap cleanup SIGINT SIGTERM export EARTH2STUDIO_API_ACTIVE=1 # Start multiple workers using uvicorn with extended timeouts for large file downloads -uvicorn earth2studio.serve.server.main:app --host 0.0.0.0 --port $API_PORT --workers $NUM_WORKERS --loop asyncio --timeout-keep-alive 300 --timeout-graceful-shutdown 30 & +CUDA_VISIBLE_DEVICES="" uvicorn earth2studio.serve.server.main:app --host 0.0.0.0 --port $API_PORT --workers $NUM_WORKERS --loop asyncio --timeout-keep-alive 300 --timeout-graceful-shutdown 30 & UVICORN_PID=$! # Start RQ workers @@ -146,7 +154,7 @@ done echo "Starting $NUM_ZIP_WORKERS zip workers for result_zip queue..." 
ZIP_WORKER_PIDS=() for i in $(seq 1 $NUM_ZIP_WORKERS); do - rq worker -w rq.worker.SimpleWorker result_zip & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker result_zip & ZIP_WORKER_PIDS+=($!) echo "Started zip worker $i for result_zip queue (PID: $!)" done @@ -155,22 +163,31 @@ done echo "Starting $NUM_OBJSTORE_WORKERS object storage workers..." OBJSTORE_WORKER_PIDS=() for i in $(seq 1 $NUM_OBJSTORE_WORKERS); do - rq worker -w rq.worker.SimpleWorker object_storage & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker object_storage & OBJSTORE_WORKER_PIDS+=($!) echo "Started object storage worker $i (PID: $!)" done +# Start geocatalog ingestion workers (used when AZURE_GEOCATALOG_URL is set) +echo "Starting $NUM_GEOCATALOG_WORKERS geocatalog ingestion workers..." +GEOCATALOG_WORKER_PIDS=() +for i in $(seq 1 $NUM_GEOCATALOG_WORKERS); do + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker geocatalog_ingestion & + GEOCATALOG_WORKER_PIDS+=($!) + echo "Started geocatalog ingestion worker $i (PID: $!)" +done + # Start finalize metadata workers echo "Starting $NUM_FINALIZE_WORKERS finalize metadata workers..." FINALIZE_WORKER_PIDS=() for i in $(seq 1 $NUM_FINALIZE_WORKERS); do - rq worker -w rq.worker.SimpleWorker finalize_metadata & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker finalize_metadata & FINALIZE_WORKER_PIDS+=($!) echo "Started finalize metadata worker $i (PID: $!)" done # Start cleanup daemon -python -m earth2studio.serve.server.cleanup_daemon & +CUDA_VISIBLE_DEVICES="" python -m earth2studio.serve.server.cleanup_daemon & CLEANUP_DAEMON_PID=$! echo "Started cleanup daemon (PID: $CLEANUP_DAEMON_PID)" @@ -205,6 +222,13 @@ if [ "$OBJSTORE_WORKER_COUNT" -eq 0 ]; then exit 1 fi +# Check if geocatalog ingestion workers are running +GEOCATALOG_WORKER_COUNT=$(pgrep -f "rq.*worker.*geocatalog_ingestion" | wc -l) +if [ "$GEOCATALOG_WORKER_COUNT" -eq 0 ]; then + echo "Failed to start geocatalog ingestion workers..." 
+ exit 1 +fi + # Check if finalize metadata workers are running FINALIZE_WORKER_COUNT=$(pgrep -f "rq.*worker.*finalize_metadata" | wc -l) if [ "$FINALIZE_WORKER_COUNT" -eq 0 ]; then @@ -224,12 +248,14 @@ echo "Uvicorn PID: $UVICORN_PID" echo "RQ Worker PIDs: ${RQ_WORKER_PIDS[*]}" echo "Zip Worker PIDs: ${ZIP_WORKER_PIDS[*]}" echo "Object Storage Worker PIDs: ${OBJSTORE_WORKER_PIDS[*]}" +echo "Geocatalog Ingestion Worker PIDs: ${GEOCATALOG_WORKER_PIDS[*]}" echo "Finalize Metadata Worker PIDs: ${FINALIZE_WORKER_PIDS[*]}" echo "Cleanup Daemon PID: $CLEANUP_DAEMON_PID" echo "Active API workers: $API_WORKER_COUNT" echo "Active RQ inference workers: $RQ_WORKER_COUNT" echo "Active zip workers: $ZIP_WORKER_COUNT" echo "Active object storage workers: $OBJSTORE_WORKER_COUNT" +echo "Active geocatalog ingestion workers: $GEOCATALOG_WORKER_COUNT" echo "Active finalize metadata workers: $FINALIZE_WORKER_COUNT" echo "API available at http://localhost:$API_PORT" echo "API docs at http://localhost:$API_PORT/docs" @@ -238,7 +264,7 @@ echo "API docs at http://localhost:$API_PORT/docs" # Wait for health check to pass before invoking warmup workflow echo "" echo "Waiting for health check to pass..." 
-MAX_HEALTH_RETRIES=30 +MAX_HEALTH_RETRIES=60 HEALTH_RETRY_INTERVAL=2 for i in $(seq 1 $MAX_HEALTH_RETRIES); do if curl -s "http://localhost:$API_PORT/health" | grep -q '"status":"healthy"'; then diff --git a/serve/server/scripts/startup.sh b/serve/server/scripts/startup.sh index b5739fe07..a24ab6b56 100755 --- a/serve/server/scripts/startup.sh +++ b/serve/server/scripts/startup.sh @@ -17,6 +17,15 @@ set -euo pipefail +# Set EARTH2STUDIO_MODEL_CACHE to use AZUREML_MODEL_DIR if available +if [ -n "${AZUREML_MODEL_DIR:-}" ]; then + echo "AZUREML_MODEL_DIR: $AZUREML_MODEL_DIR" + export EARTH2STUDIO_MODEL_CACHE="$AZUREML_MODEL_DIR/${EARTH2STUDIO_MODEL_SUBPATH:-e2s_fcn3_stormscope}" + echo "--------------------------------" + echo "EARTH2STUDIO_MODEL_CACHE: $EARTH2STUDIO_MODEL_CACHE" + ls -la $EARTH2STUDIO_MODEL_CACHE && echo "--------------------------------" +fi + # Use CONFIG_DIR/SCRIPT_DIR from env if set (e.g. in Docker); else resolve from script location SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}" SERVE_SERVER_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" diff --git a/serve/server/scripts/status.sh b/serve/server/scripts/status.sh index aadb6b94f..2dcdaadf4 100755 --- a/serve/server/scripts/status.sh +++ b/serve/server/scripts/status.sh @@ -65,7 +65,7 @@ fi echo "" # Check RQ workers status per queue -RQ_QUEUES=("inference" "result_zip" "object_storage" "finalize_metadata") +RQ_QUEUES=("inference" "result_zip" "object_storage" "geocatalog_ingestion" "finalize_metadata") RQ_ALL_OK=1 echo "RQ Workers:" diff --git a/serve/server/scripts/stop_api_server.sh b/serve/server/scripts/stop_api_server.sh index d5769013d..6f00e0830 100755 --- a/serve/server/scripts/stop_api_server.sh +++ b/serve/server/scripts/stop_api_server.sh @@ -33,6 +33,10 @@ pkill -f "rq.*worker.*result_zip" echo "Stopping object storage workers..." pkill -f "rq.*worker.*object_storage" +# Stop geocatalog ingestion workers +echo "Stopping geocatalog ingestion workers..." 
+pkill -f "rq.*worker.*geocatalog_ingestion" + # Stop finalize metadata workers echo "Stopping finalize metadata workers..." pkill -f "rq.*worker.*finalize_metadata" diff --git a/test/data/test_planetary_computer.py b/test/data/test_planetary_computer.py index d29f60d83..3de85deea 100644 --- a/test/data/test_planetary_computer.py +++ b/test/data/test_planetary_computer.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import pathlib import shutil from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -28,7 +30,10 @@ PlanetaryComputerOISST, PlanetaryComputerSentinel3AOD, ) -from earth2studio.data.planetary_computer import _PlanetaryComputerData +from earth2studio.data.planetary_computer import ( + GeoCatalogClient, + _PlanetaryComputerData, +) from earth2studio.lexicon.planetary_computer import PlanetaryComputerOISSTLexicon @@ -75,6 +80,216 @@ class DummyLexicon(PlanetaryComputerOISSTLexicon): ds.extract_variable_numpy(None, None, datetime.now(timezone.utc)) +@pytest.fixture +def geocatalog_config_dir(tmp_path): + workflow = "fcn3" + (tmp_path / f"parameters-{workflow}.json").write_text( + json.dumps({"step_size_hours": 6, "start_time_parameter_key": "start_time"}) + ) + (tmp_path / f"template-collection-{workflow}.json").write_text( + json.dumps({"type": "Collection", "id": "earth-2-fcn3-{uuid}", "title": "Test"}) + ) + (tmp_path / f"template-feature-{workflow}.json").write_text( + json.dumps( + { + "type": "Feature", + "id": "fcn3-{start_time}-{uuid}", + "properties": { + "datetime": "{start_time}", + "start_datetime": "{start_time}", + "end_datetime": "{end_time}", + }, + "assets": { + "data": { + "href": "PLACEHOLDER", + "description": "From {start_time} to {end_time}.", + } + }, + } + ) + ) + (tmp_path / f"tile-settings-{workflow}.json").write_text( + json.dumps({"minZoom": 0, "maxItemsPerTile": 35}) 
+ ) + (tmp_path / f"render-options-{workflow}.json").write_text( + json.dumps([{"id": "t2m", "scale": [263, 313], "cmap": "balance"}]) + ) + return tmp_path + + +@pytest.fixture +def mock_azure_credential(): + with patch("azure.identity.DefaultAzureCredential") as m: + cred = MagicMock() + token = MagicMock() + token.token = "mock-token" # noqa: S105 + cred.get_token.return_value = token + m.return_value = cred + yield m + + +@pytest.fixture +def mock_requests(): + with patch("earth2studio.data.planetary_computer.requests") as m: + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.json.return_value = {"status": "Succeeded"} + get_resp.headers = {} + m.get.return_value = get_resp + + post_resp = MagicMock() + post_resp.status_code = 201 + post_resp.headers = {"location": "https://geocatalog.example/status/123"} + m.post.return_value = post_resp + + put_resp = MagicMock() + put_resp.status_code = 200 + m.put.return_value = put_resp + + yield m + + +def test_geocatalog_client_init_success(geocatalog_config_dir, mock_azure_credential): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + assert client._workflow_name == "fcn3" + assert client._parameters["step_size_hours"] == 6 + assert client._parameters["start_time_parameter_key"] == "start_time" + + +def test_geocatalog_client_init_missing_parameters_file( + tmp_path, mock_azure_credential +): + with pytest.raises(FileNotFoundError, match="parameters-other_workflow.json"): + GeoCatalogClient("other_workflow", tmp_path) + + +def test_geocatalog_client_resolve_start_time_iso_string( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + params = {"start_time": "2025-01-15T12:00:00Z"} + result = client._resolve_start_time(params) + assert result == datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + +def test_geocatalog_client_resolve_start_time_iso_string_with_offset( + geocatalog_config_dir, mock_azure_credential +): + client = 
GeoCatalogClient("fcn3", geocatalog_config_dir) + params = {"start_time": "2025-01-15T12:00:00+00:00"} + result = client._resolve_start_time(params) + assert result == datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + +def test_geocatalog_client_resolve_start_time_datetime( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + dt = datetime(2025, 2, 1, 6, 0, 0, tzinfo=timezone.utc) + params = {"start_time": dt} + result = client._resolve_start_time(params) + assert result == dt + + +def test_geocatalog_client_resolve_start_time_missing_key_raises( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + with pytest.raises(ValueError, match="Missing 'start_time'"): + client._resolve_start_time({}) + + +def test_geocatalog_client_resolve_start_time_invalid_type_raises( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + with pytest.raises(TypeError, match="start time must be str or datetime"): + client._resolve_start_time({"start_time": 12345}) + + +def test_geocatalog_client_get_collection_json_generates_uuid_when_collection_id_none( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + out = client._get_collection_json(None) + assert out["id"].startswith("earth-2-fcn3-") + assert len(out["id"]) > len("earth-2-fcn3-") + + +def test_geocatalog_client_get_collection_json_uses_id_when_collection_id_provided( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + out = client._get_collection_json("my-collection-id") + assert out["id"] == "my-collection-id" + + +def test_geocatalog_client_get_feature_json_formats_times_and_blob_url( + geocatalog_config_dir, mock_azure_credential +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + start = datetime(2025, 1, 1, 0, 0, 0, 
tzinfo=timezone.utc) + end = start + timedelta(hours=6) + out = client._get_feature_json( + start_time=start, + end_time=end, + blob_url="https://storage.example/container/blob.nc", + ) + assert out["properties"]["datetime"] == start.isoformat() + assert out["properties"]["start_datetime"] == start.isoformat() + assert out["properties"]["end_datetime"] == end.isoformat() + assert out["assets"]["data"]["href"] == "https://storage.example/container/blob.nc" + assert "2025-01-01" in out["assets"]["data"]["description"] + assert out["id"].startswith("fcn3-") + + +def test_geocatalog_client_create_feature_returns_collection_and_feature_id( + geocatalog_config_dir, mock_azure_credential, mock_requests +): + mock_requests.get.return_value.json.return_value = {"status": "Succeeded"} + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + collection_id, feature_id = client.create_feature( + geocatalog_url="https://geocatalog.example/", + collection_id=None, + parameters={"start_time": "2025-01-01T00:00:00Z"}, + blob_url="https://storage.example/blob.nc", + ) + assert collection_id is not None + assert feature_id is not None + assert feature_id.startswith("fcn3-") + + +def test_geocatalog_client_create_feature_uses_existing_collection_when_provided( + geocatalog_config_dir, mock_azure_credential, mock_requests +): + mock_requests.get.return_value.status_code = 200 + mock_requests.get.return_value.json.return_value = {"status": "Succeeded"} + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + collection_id, feature_id = client.create_feature( + geocatalog_url="https://geocatalog.example/", + collection_id="existing-collection", + parameters={"start_time": "2025-01-01T00:00:00Z"}, + blob_url="https://storage.example/blob.nc", + ) + assert collection_id == "existing-collection" + assert feature_id is not None + post_urls = [c[0][0] for c in mock_requests.post.call_args_list] + assert not any("/stac/collections" in u and "items" not in u for u in post_urls) + + +def 
test_geocatalog_client_create_feature_missing_start_time_raises( + geocatalog_config_dir, mock_azure_credential, mock_requests +): + client = GeoCatalogClient("fcn3", geocatalog_config_dir) + with pytest.raises(ValueError, match="Missing 'start_time'"): + client.create_feature( + geocatalog_url="https://geocatalog.example/", + collection_id="existing-collection", + parameters={}, + blob_url="https://storage.example/blob.nc", + ) + + @pytest.mark.slow @pytest.mark.xfail() @pytest.mark.timeout(60) diff --git a/test/serve/client/test_e2client.py b/test/serve/client/test_e2client.py index b1a2fc368..efe02c1ca 100644 --- a/test/serve/client/test_e2client.py +++ b/test/serve/client/test_e2client.py @@ -37,6 +37,7 @@ InferenceRequestResults, OutputFile, RequestStatus, + StorageType, ) @@ -244,6 +245,90 @@ def test_as_dataset_formats(self) -> None: mock_client.download_result.assert_called_once() mock_open_dataset.assert_called_once() + def test_as_dataset_azure_zarr(self) -> None: + """Test as_dataset for Azure storage zarr uses get_mapper with 'results.zarr'""" + with patch( + "earth2studio.serve.client.e2client.Earth2StudioClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + workflow = RemoteEarth2Workflow( + base_url="http://localhost:8000", + workflow_name="test_workflow", + ) + + mock_inference_result = InferenceRequestResults( + request_id="exec_azure", + status=RequestStatus.COMPLETED, + output_files=[OutputFile(path="results.zarr", size=2048)], + completion_time=datetime.now(), + storage_type=StorageType.AZURE, + signed_url="https://account.blob.core.windows.net/container/*?sig=abc", + ) + mock_client.wait_for_completion.return_value = mock_inference_result + + mock_ds = MagicMock() + with ( + patch( + "earth2studio.serve.client.e2client.fsspec_utils.get_mapper" + ) as mock_get_mapper, + patch("xarray.open_zarr", return_value=mock_ds) as mock_open_zarr, + ): + mock_mapper = Mock() + 
mock_get_mapper.return_value = mock_mapper + + result = RemoteEarth2WorkflowResult(workflow, "exec_azure") + ds = result.as_dataset() + + # get_mapper should be called with "results.zarr" for Azure + mock_get_mapper.assert_called_once_with( + mock_inference_result, "results.zarr" + ) + mock_open_zarr.assert_called_once_with(mock_mapper, consolidated=True) + assert ds is mock_ds + + def test_as_dataset_s3_zarr(self) -> None: + """Test as_dataset for S3 storage zarr uses get_mapper with execution_id-stripped path""" + with patch( + "earth2studio.serve.client.e2client.Earth2StudioClient" + ) as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + workflow = RemoteEarth2Workflow( + base_url="http://localhost:8000", + workflow_name="test_workflow", + ) + + mock_inference_result = InferenceRequestResults( + request_id="exec_s3", + status=RequestStatus.COMPLETED, + output_files=[OutputFile(path="exec_s3/results.zarr", size=2048)], + completion_time=datetime.now(), + storage_type=StorageType.S3, + signed_url="https://cdn.example.com/bucket?Policy=p&Signature=s&Key-Pair-Id=k", + ) + mock_client.wait_for_completion.return_value = mock_inference_result + + mock_ds = MagicMock() + with ( + patch( + "earth2studio.serve.client.e2client.fsspec_utils.get_mapper" + ) as mock_get_mapper, + patch("xarray.open_zarr", return_value=mock_ds), + ): + mock_get_mapper.return_value = Mock() + + result = RemoteEarth2WorkflowResult(workflow, "exec_s3") + ds = result.as_dataset() + + # S3 strips the first path component (execution_id prefix) + mock_get_mapper.assert_called_once_with( + mock_inference_result, "results.zarr" + ) + assert ds is mock_ds + def test_as_dataset_no_outputs(self) -> None: """Test as_dataset when no output files are available""" with patch( diff --git a/test/serve/client/test_fsspec_utils.py b/test/serve/client/test_fsspec_utils.py index ec0936d86..a1466ec36 100644 --- a/test/serve/client/test_fsspec_utils.py +++ 
b/test/serve/client/test_fsspec_utils.py @@ -21,7 +21,9 @@ import pytest from earth2studio.serve.client.fsspec_utils import ( + AzureSignedURLFileSystem, SignedURLFileSystem, + create_azure_mapper, create_cloudfront_mapper, get_mapper, ) @@ -291,6 +293,336 @@ def test_create_cloudfront_mapper_strips_trailing_wildcard(self) -> None: assert fs_arg._base_url.rstrip("/").endswith("bucket") +class TestAzureSignedURLFileSystemInit: + """Test AzureSignedURLFileSystem initialization.""" + + def test_init_stores_params(self) -> None: + """Init stores base_fs, query_params, base_url and builds query_string.""" + base_fs = Mock() + query_params = {"sv": "2021-08-06", "sig": "abc123", "se": "2024-01-01"} + base_url = "https://account.blob.core.windows.net/container/*" + fs = AzureSignedURLFileSystem(base_fs, query_params, base_url) + assert fs._fs is base_fs + assert fs._query_params == query_params + assert fs._base_url == base_url + assert "sv=2021-08-06" in fs._query_string + assert "sig=abc123" in fs._query_string + + +class TestAzureSignedURLFileSystemMakeSignedPath: + """Test AzureSignedURLFileSystem._make_signed_path.""" + + def test_wildcard_replaced_with_path(self) -> None: + """* in base_url is replaced with the actual path.""" + base_fs = Mock() + fs = AzureSignedURLFileSystem( + base_fs, + {"sig": "abc"}, + "https://account.blob.core.windows.net/container/*", + ) + out = fs._make_signed_path("results.zarr/.zmetadata") + assert "results.zarr/.zmetadata" in out + assert "*" not in out + assert "sig=abc" in out + + def test_no_wildcard_appends_path_to_base_url(self) -> None: + """Without * in base_url, path is appended normally.""" + base_fs = Mock() + fs = AzureSignedURLFileSystem( + base_fs, + {"sig": "abc"}, + "https://account.blob.core.windows.net/container/store", + ) + out = fs._make_signed_path("subdir/file") + assert out.startswith( + "https://account.blob.core.windows.net/container/store/subdir/file" + ) + assert "sig=abc" in out + + def 
test_path_starting_with_http_unchanged_with_query_appended(self) -> None: + """Path starting with http is used as full URL; query params appended.""" + base_fs = Mock() + fs = AzureSignedURLFileSystem( + base_fs, {"sig": "abc"}, "https://account.blob.core.windows.net/*" + ) + out = fs._make_signed_path("https://other.example.com/path") + assert out.startswith("https://other.example.com/path") + assert "sig=abc" in out + + def test_url_with_existing_query_uses_ampersand(self) -> None: + """If the URL already has ?, separator is &.""" + base_fs = Mock() + fs = AzureSignedURLFileSystem( + base_fs, {"sig": "abc"}, "https://account.blob.core.windows.net/*?foo=1" + ) + out = fs._make_signed_path("file") + assert "&sig=abc" in out + + def test_empty_path_uses_base_url(self) -> None: + """Empty path leaves base_url as-is (wildcard stays or replaced with empty).""" + base_fs = Mock() + base_url = "https://account.blob.core.windows.net/container/store" + fs = AzureSignedURLFileSystem(base_fs, {"sig": "x"}, base_url) + out = fs._make_signed_path("") + assert out.startswith(base_url) + assert "sig=x" in out + + +class TestAzureSignedURLFileSystemHandle403: + """Test AzureSignedURLFileSystem._handle_403 via method calls.""" + + def test_open_403_raises_file_not_found(self) -> None: + """Exception containing 403 raised in _open leads to FileNotFoundError.""" + base_fs = Mock() + base_fs._open.side_effect = Exception("HTTP 403 Forbidden") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/*" + ) + with pytest.raises(FileNotFoundError, match="File not found"): + fs._open("results.zarr/.zmetadata") + + def test_open_forbidden_text_raises_file_not_found(self) -> None: + """Exception containing 'forbidden' raised in _open leads to FileNotFoundError.""" + base_fs = Mock() + base_fs._open.side_effect = Exception("Access forbidden") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/*" + ) + with 
pytest.raises(FileNotFoundError): + fs._open("file") + + def test_open_other_exception_reraises(self) -> None: + """Non-403 exception raised in _open is reraised.""" + base_fs = Mock() + base_fs._open.side_effect = ValueError("network error") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/*" + ) + with pytest.raises(ValueError, match="network error"): + fs._open("file") + + +class TestAzureSignedURLFileSystemOpen: + """Test AzureSignedURLFileSystem._open.""" + + def test_open_success_returns_file_like(self) -> None: + """Successful _open returns result from base fs _open.""" + base_fs = Mock() + base_fs._open.return_value = MagicMock() + fs = AzureSignedURLFileSystem( + base_fs, {"sig": "abc"}, "https://account.blob.core.windows.net/c/*" + ) + result = fs._open("results.zarr/.zmetadata", mode="rb") + assert result is base_fs._open.return_value + base_fs._open.assert_called_once() + assert base_fs._open.call_args[1]["mode"] == "rb" + + +class TestAzureSignedURLFileSystemCatFile: + """Test AzureSignedURLFileSystem.cat_file.""" + + def test_cat_file_success(self) -> None: + """cat_file success returns content from base fs.""" + base_fs = Mock() + base_fs.cat_file.return_value = b"content" + fs = AzureSignedURLFileSystem( + base_fs, {"sig": "abc"}, "https://account.blob.core.windows.net/c/*" + ) + result = fs.cat_file("results.zarr/.zmetadata") + assert result == b"content" + + def test_cat_file_with_start_end(self) -> None: + """cat_file passes start/end to base fs.""" + base_fs = Mock() + base_fs.cat_file.return_value = b"xx" + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + fs.cat_file("f", start=0, end=10) + base_fs.cat_file.assert_called_once() + assert base_fs.cat_file.call_args[1]["start"] == 0 + assert base_fs.cat_file.call_args[1]["end"] == 10 + + def test_cat_file_403_raises_file_not_found(self) -> None: + """cat_file 403 raises FileNotFoundError.""" + base_fs = Mock() + 
base_fs.cat_file.side_effect = Exception("403") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + with pytest.raises(FileNotFoundError): + fs.cat_file("x") + + def test_cat_file_other_exception_reraises(self) -> None: + """Non-403 exception from cat_file is reraised.""" + base_fs = Mock() + base_fs.cat_file.side_effect = IOError("read error") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + with pytest.raises(IOError, match="read error"): + fs.cat_file("x") + + +class TestAzureSignedURLFileSystemCatFileAsync: + """Test AzureSignedURLFileSystem._cat_file (delegates to cat_file).""" + + def test_cat_file_delegates_to_cat_file(self) -> None: + """_cat_file calls cat_file with same args.""" + base_fs = Mock() + base_fs.cat_file.return_value = b"data" + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + result = fs._cat_file("a", start=1, end=2) + assert result == b"data" + base_fs.cat_file.assert_called_once() + assert base_fs.cat_file.call_args[1]["start"] == 1 + assert base_fs.cat_file.call_args[1]["end"] == 2 + + +class TestAzureSignedURLFileSystemInfo: + """Test AzureSignedURLFileSystem.info.""" + + def test_info_success(self) -> None: + """info returns metadata from base fs.""" + base_fs = Mock() + base_fs.info.return_value = {"size": 512, "type": "file"} + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + result = fs.info("path") + assert result == {"size": 512, "type": "file"} + + def test_info_403_raises_file_not_found(self) -> None: + """info 403 raises FileNotFoundError.""" + base_fs = Mock() + base_fs.info.side_effect = Exception("Forbidden") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + with pytest.raises(FileNotFoundError): + fs.info("x") + + +class TestAzureSignedURLFileSystemExists: + """Test 
AzureSignedURLFileSystem.exists.""" + + def test_exists_true(self) -> None: + """exists returns True when base fs returns True.""" + base_fs = Mock() + base_fs.exists.return_value = True + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + assert fs.exists("path") is True + + def test_exists_403_returns_false(self) -> None: + """exists returns False when base raises 403.""" + base_fs = Mock() + base_fs.exists.side_effect = Exception("403 Forbidden") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + assert fs.exists("path") is False + + def test_exists_false_when_base_returns_false(self) -> None: + """exists returns False when base fs returns False.""" + base_fs = Mock() + base_fs.exists.return_value = False + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + assert fs.exists("path") is False + + def test_exists_other_exception_reraises(self) -> None: + """exists re-raises non-403 exceptions.""" + base_fs = Mock() + base_fs.exists.side_effect = RuntimeError("network error") + fs = AzureSignedURLFileSystem( + base_fs, {}, "https://account.blob.core.windows.net/c/*" + ) + with pytest.raises(RuntimeError, match="network error"): + fs.exists("path") + + +class TestCreateAzureMapper: + """Test create_azure_mapper.""" + + def test_builds_mapper_with_azure_filesystem(self) -> None: + """create_azure_mapper returns FSMap with AzureSignedURLFileSystem.""" + signed_url = ( + "https://account.blob.core.windows.net/container/*?sv=2021&sig=abc&se=2024" + ) + with patch("earth2studio.serve.client.fsspec_utils.fsspec") as mock_fsspec: + mock_fs = Mock() + mock_fsspec.filesystem.return_value = mock_fs + mock_map = Mock() + mock_fsspec.mapping.FSMap = mock_map + result = create_azure_mapper(signed_url) + mock_fsspec.filesystem.assert_called_once_with("https") + mock_map.assert_called_once() + call_kwargs = mock_map.call_args[1] + assert 
call_kwargs["root"] == "" + assert call_kwargs["check"] is False + assert call_kwargs["create"] is False + assert isinstance(call_kwargs["fs"], AzureSignedURLFileSystem) + assert result is mock_map.return_value + + def test_with_zarr_path_inserts_before_wildcard(self) -> None: + """zarr_path is inserted before the wildcard in base_url.""" + signed_url = "https://account.blob.core.windows.net/container/*?sig=abc" + with patch("earth2studio.serve.client.fsspec_utils.fsspec") as mock_fsspec: + mock_fs = Mock() + mock_fsspec.filesystem.return_value = mock_fs + mock_map = Mock() + mock_fsspec.mapping.FSMap = mock_map + create_azure_mapper(signed_url, zarr_path="results.zarr") + fs_arg = mock_map.call_args[1]["fs"] + assert isinstance(fs_arg, AzureSignedURLFileSystem) + assert "results.zarr" in fs_arg._base_url + assert fs_arg._base_url.endswith("/*") + + def test_with_star_only_suffix_and_zarr_path(self) -> None: + """URL ending with * (no slash) and zarr_path handled correctly.""" + signed_url = "https://account.blob.core.windows.net/container*?sig=abc" + with patch("earth2studio.serve.client.fsspec_utils.fsspec") as mock_fsspec: + mock_fsspec.filesystem.return_value = Mock() + mock_map = Mock() + mock_fsspec.mapping.FSMap = mock_map + create_azure_mapper(signed_url, zarr_path="results.zarr") + fs_arg = mock_map.call_args[1]["fs"] + assert "results.zarr" in fs_arg._base_url + assert fs_arg._base_url.endswith("/*") + + def test_without_wildcard_adds_wildcard_suffix(self) -> None: + """URL without * has /* appended when no zarr_path.""" + signed_url = "https://account.blob.core.windows.net/container/prefix?sig=abc" + with patch("earth2studio.serve.client.fsspec_utils.fsspec") as mock_fsspec: + mock_fsspec.filesystem.return_value = Mock() + mock_map = Mock() + mock_fsspec.mapping.FSMap = mock_map + create_azure_mapper(signed_url) + fs_arg = mock_map.call_args[1]["fs"] + assert fs_arg._base_url.endswith("/*") + + def test_sas_token_params_extracted(self) -> None: + """SAS 
token query params are extracted from URL and stored in filesystem.""" + signed_url = ( + "https://account.blob.core.windows.net/container/*" + "?sv=2021-08-06&sig=mysig&se=2024-01-01T00%3A00%3A00Z" + ) + with patch("earth2studio.serve.client.fsspec_utils.fsspec") as mock_fsspec: + mock_fsspec.filesystem.return_value = Mock() + mock_map = Mock() + mock_fsspec.mapping.FSMap = mock_map + create_azure_mapper(signed_url) + fs_arg = mock_map.call_args[1]["fs"] + assert "sv" in fs_arg._query_params + assert "sig" in fs_arg._query_params + assert fs_arg._query_params["sig"] == "mysig" + + class TestGetMapper: """Test get_mapper.""" @@ -353,26 +685,59 @@ def test_get_mapper_s3_without_signed_url_raises(self) -> None: with pytest.raises(ValueError, match="S3 storage type requires a signed URL"): get_mapper(result) - def test_get_mapper_unsupported_storage_raises(self) -> None: - """get_mapper with unsupported storage type raises ValueError.""" + def test_get_mapper_azure_with_signed_url_returns_mapper(self) -> None: + """get_mapper with AZURE and signed_url returns mapper from create_azure_mapper.""" result = InferenceRequestResults( request_id="r1", status=RequestStatus.COMPLETED, output_files=[], completion_time=datetime.now(), - storage_type=StorageType.SERVER, # will override via monkeypatch for unsupported - ) - # Use a storage type that doesn't exist in enum by creating a mock with invalid type - with patch.object(result, "storage_type", "invalid"): - # StorageType.SERVER is enum; "invalid" would be if we had another enum value. - # Instead simulate by patching the comparison: get_mapper checks - # request_result.storage_type == StorageType.S3 and == StorageType.SERVER. - # So unsupported = something that is neither. We need a value that is not - # StorageType.S3 and not StorageType.SERVER. StorageType only has SERVER and S3. - # So we need to pass something that fails both. E.g. a mock with storage_type - # that returns False for both. Or add a fake enum value. 
Easiest: patch - # StorageType to add a third value temporarily, or pass a mock object. - mock_result = Mock() - mock_result.storage_type = "unsupported_type" - with pytest.raises(ValueError, match="Unsupported storage type"): - get_mapper(mock_result) + storage_type=StorageType.AZURE, + signed_url="https://account.blob.core.windows.net/container/*?sig=abc", + ) + with patch( + "earth2studio.serve.client.fsspec_utils.create_azure_mapper" + ) as mock_create: + mock_create.return_value = Mock() + out = get_mapper(result) + assert out is mock_create.return_value + mock_create.assert_called_once_with(result.signed_url, "") + + def test_get_mapper_azure_with_zarr_path_passes_through(self) -> None: + """get_mapper passes zarr_path to create_azure_mapper for AZURE storage.""" + result = InferenceRequestResults( + request_id="r1", + status=RequestStatus.COMPLETED, + output_files=[], + completion_time=datetime.now(), + storage_type=StorageType.AZURE, + signed_url="https://account.blob.core.windows.net/container/*?sig=abc", + ) + with patch( + "earth2studio.serve.client.fsspec_utils.create_azure_mapper" + ) as mock_create: + mock_create.return_value = Mock() + get_mapper(result, zarr_path="results.zarr") + mock_create.assert_called_once_with(result.signed_url, "results.zarr") + + def test_get_mapper_azure_without_signed_url_raises(self) -> None: + """get_mapper with AZURE and no signed_url raises ValueError.""" + result = InferenceRequestResults( + request_id="r1", + status=RequestStatus.COMPLETED, + output_files=[], + completion_time=datetime.now(), + storage_type=StorageType.AZURE, + signed_url=None, + ) + with pytest.raises( + ValueError, match="Azure storage type requires a signed URL" + ): + get_mapper(result) + + def test_get_mapper_unsupported_storage_raises(self) -> None: + """get_mapper with unsupported storage type raises ValueError.""" + mock_result = Mock() + mock_result.storage_type = "unsupported_type" + with pytest.raises(ValueError, match="Unsupported storage 
type"): + get_mapper(mock_result) diff --git a/test/serve/server/test_config.py b/test/serve/server/test_config.py index 854b1c248..19d9aca50 100644 --- a/test/serve/server/test_config.py +++ b/test/serve/server/test_config.py @@ -584,3 +584,294 @@ def test_dict_to_config_handles_missing_keys(self) -> None: assert config.redis.host == "test_host" # Other configs should have defaults assert config.queue.name == "inference" # Default value + + +class TestObjectStorageEnvOverrides: + """Test object storage and Azure environment variable overrides""" + + def setup_method(self) -> None: + self._vars = [ + "OBJECT_STORAGE_TYPE", + "OBJECT_STORAGE_BUCKET", + "OBJECT_STORAGE_REGION", + "OBJECT_STORAGE_PREFIX", + "OBJECT_STORAGE_ACCESS_KEY_ID", + "OBJECT_STORAGE_SECRET_ACCESS_KEY", + "OBJECT_STORAGE_SESSION_TOKEN", + "OBJECT_STORAGE_ENDPOINT_URL", + "OBJECT_STORAGE_TRANSFER_ACCELERATION", + "OBJECT_STORAGE_MAX_CONCURRENCY", + "OBJECT_STORAGE_MULTIPART_CHUNKSIZE", + "OBJECT_STORAGE_USE_RUST_CLIENT", + "CLOUDFRONT_DOMAIN", + "CLOUDFRONT_KEY_PAIR_ID", + "CLOUDFRONT_PRIVATE_KEY", + "SIGNED_URL_EXPIRES_IN", + "AZURE_CONNECTION_STRING", + "AZURE_STORAGE_ACCOUNT_NAME", + "AZURE_STORAGE_ACCOUNT_KEY", + "AZURE_CONTAINER_NAME", + "AZURE_ENDPOINT_URL", + "AZURE_GEOCATALOG_URL", + "EXPOSED_WORKFLOWS", + "OUTPUT_FORMAT", + "CONFIG_DIR", + ] + for v in self._vars: + os.environ.pop(v, None) + + def teardown_method(self) -> None: + for v in self._vars: + os.environ.pop(v, None) + reset_config() + + def _get_manager(self) -> "ConfigManager": + reset_config() + return ConfigManager() + + def test_object_storage_type_s3(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TYPE"] = "s3" + manager._apply_env_overrides() + assert manager.config.object_storage.storage_type == "s3" + + def test_object_storage_type_azure(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TYPE"] = "azure" + manager._apply_env_overrides() + assert 
manager.config.object_storage.storage_type == "azure" + + def test_object_storage_type_invalid_ignored(self) -> None: + manager = self._get_manager() + original = manager.config.object_storage.storage_type + os.environ["OBJECT_STORAGE_TYPE"] = "gcs" + manager._apply_env_overrides() + assert manager.config.object_storage.storage_type == original + + def test_object_storage_bucket(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_BUCKET"] = "my-bucket" + manager._apply_env_overrides() + assert manager.config.object_storage.bucket == "my-bucket" + + def test_object_storage_region(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_REGION"] = "eu-west-1" + manager._apply_env_overrides() + assert manager.config.object_storage.region == "eu-west-1" + + def test_object_storage_prefix(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_PREFIX"] = "custom/prefix" + manager._apply_env_overrides() + assert manager.config.object_storage.prefix == "custom/prefix" + + def test_object_storage_access_key_id(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_ACCESS_KEY_ID"] = "AKID" + manager._apply_env_overrides() + assert manager.config.object_storage.access_key_id == "AKID" + + def test_object_storage_secret_access_key(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_SECRET_ACCESS_KEY"] = "SECRET" # noqa: S105 + manager._apply_env_overrides() + assert manager.config.object_storage.secret_access_key == "SECRET" # noqa: S105 + + def test_object_storage_session_token(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_SESSION_TOKEN"] = "TOKEN" # noqa: S105 + manager._apply_env_overrides() + assert manager.config.object_storage.session_token == "TOKEN" # noqa: S105 + + def test_object_storage_endpoint_url(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_ENDPOINT_URL"] = "https://s3.local" + 
manager._apply_env_overrides() + assert manager.config.object_storage.endpoint_url == "https://s3.local" + + def test_object_storage_transfer_acceleration_true(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TRANSFER_ACCELERATION"] = "true" + manager._apply_env_overrides() + assert manager.config.object_storage.use_transfer_acceleration is True + + def test_object_storage_transfer_acceleration_false(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TRANSFER_ACCELERATION"] = "false" + manager._apply_env_overrides() + assert manager.config.object_storage.use_transfer_acceleration is False + + def test_object_storage_max_concurrency(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_MAX_CONCURRENCY"] = "32" + manager._apply_env_overrides() + assert manager.config.object_storage.max_concurrency == 32 + + def test_object_storage_multipart_chunksize(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_MULTIPART_CHUNKSIZE"] = "8388608" + manager._apply_env_overrides() + assert manager.config.object_storage.multipart_chunksize == 8388608 + + def test_object_storage_use_rust_client_true(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_USE_RUST_CLIENT"] = "true" + manager._apply_env_overrides() + assert manager.config.object_storage.use_rust_client is True + + def test_cloudfront_domain(self) -> None: + manager = self._get_manager() + os.environ["CLOUDFRONT_DOMAIN"] = "cdn.example.com" + manager._apply_env_overrides() + assert manager.config.object_storage.cloudfront_domain == "cdn.example.com" + + def test_cloudfront_key_pair_id(self) -> None: + manager = self._get_manager() + os.environ["CLOUDFRONT_KEY_PAIR_ID"] = "KID123" + manager._apply_env_overrides() + assert manager.config.object_storage.cloudfront_key_pair_id == "KID123" + + def test_cloudfront_private_key(self) -> None: + manager = self._get_manager() + 
os.environ["CLOUDFRONT_PRIVATE_KEY"] = "-----BEGIN RSA PRIVATE KEY-----" + manager._apply_env_overrides() + assert ( + manager.config.object_storage.cloudfront_private_key + == "-----BEGIN RSA PRIVATE KEY-----" + ) + + def test_signed_url_expires_in(self) -> None: + manager = self._get_manager() + os.environ["SIGNED_URL_EXPIRES_IN"] = "3600" + manager._apply_env_overrides() + assert manager.config.object_storage.signed_url_expires_in == 3600 + + def test_azure_connection_string(self) -> None: + manager = self._get_manager() + os.environ["AZURE_CONNECTION_STRING"] = "DefaultEndpointsProtocol=https;..." + manager._apply_env_overrides() + assert ( + manager.config.object_storage.azure_connection_string + == "DefaultEndpointsProtocol=https;..." + ) + + def test_azure_storage_account_name(self) -> None: + manager = self._get_manager() + os.environ["AZURE_STORAGE_ACCOUNT_NAME"] = "myaccount" + manager._apply_env_overrides() + assert manager.config.object_storage.azure_account_name == "myaccount" + + def test_azure_storage_account_key(self) -> None: + manager = self._get_manager() + os.environ["AZURE_STORAGE_ACCOUNT_KEY"] = "base64key==" # noqa: S105 + manager._apply_env_overrides() + assert manager.config.object_storage.azure_account_key == "base64key==" + + def test_azure_container_name(self) -> None: + manager = self._get_manager() + os.environ["AZURE_CONTAINER_NAME"] = "mycontainer" + manager._apply_env_overrides() + assert manager.config.object_storage.azure_container_name == "mycontainer" + + def test_azure_endpoint_url(self) -> None: + manager = self._get_manager() + os.environ["AZURE_ENDPOINT_URL"] = "https://myaccount.blob.core.windows.net" + manager._apply_env_overrides() + assert ( + manager.config.object_storage.endpoint_url + == "https://myaccount.blob.core.windows.net" + ) + + def test_azure_geocatalog_url(self) -> None: + manager = self._get_manager() + os.environ["AZURE_GEOCATALOG_URL"] = "https://geocatalog.example.com" + manager._apply_env_overrides() + 
assert ( + manager.config.object_storage.azure_geocatalog_url + == "https://geocatalog.example.com" + ) + + def test_exposed_workflows_parses_comma_separated(self) -> None: + manager = self._get_manager() + os.environ["EXPOSED_WORKFLOWS"] = "workflow_a, workflow_b, workflow_c" + manager._apply_env_overrides() + assert manager.config.workflow_exposure.exposed_workflows == [ + "workflow_a", + "workflow_b", + "workflow_c", + ] + + def test_exposed_workflows_single(self) -> None: + manager = self._get_manager() + os.environ["EXPOSED_WORKFLOWS"] = "only_workflow" + manager._apply_env_overrides() + assert manager.config.workflow_exposure.exposed_workflows == ["only_workflow"] + + def test_output_format_zarr(self) -> None: + manager = self._get_manager() + os.environ["OUTPUT_FORMAT"] = "zarr" + manager._apply_env_overrides() + assert manager.config.paths.output_format == "zarr" + + def test_output_format_netcdf4(self) -> None: + manager = self._get_manager() + os.environ["OUTPUT_FORMAT"] = "netcdf4" + manager._apply_env_overrides() + assert manager.config.paths.output_format == "netcdf4" + + def test_output_format_invalid_ignored(self) -> None: + manager = self._get_manager() + original = manager.config.paths.output_format + os.environ["OUTPUT_FORMAT"] = "csv" + manager._apply_env_overrides() + assert manager.config.paths.output_format == original + + def test_apply_env_overrides_no_op_when_config_none(self) -> None: + """_apply_env_overrides returns early when _config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + # Should not raise + manager._apply_env_overrides() + + def test_ensure_paths_exist_no_op_when_config_none(self) -> None: + """_ensure_paths_exist returns early when _config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + # Should not raise + manager._ensure_paths_exist() + + def test_config_property_reinitializes_when_none(self) -> None: + """config property calls _initialize_config when 
_config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + cfg = manager.config + assert isinstance(cfg, AppConfig) + + def test_workflow_config_property_reinitializes_when_none(self) -> None: + """workflow_config property calls _initialize_config when _workflow_config is None""" + reset_config() + manager = ConfigManager() + manager._workflow_config = None + expected = {"wf": {"param": 1}} + with patch.object( + manager, + "_initialize_config", + side_effect=lambda: setattr(manager, "_workflow_config", expected), + ) as mock_init: + wf_cfg = manager.workflow_config + mock_init.assert_called_once() + assert wf_cfg == expected + + def test_initialize_config_uses_config_dir_env_var(self) -> None: + """_initialize_config uses CONFIG_DIR env var when set (covers lines 203-204)""" + reset_config() + manager = ConfigManager() + manager._config = None + manager._workflow_config = None + os.environ["CONFIG_DIR"] = "/custom/conf" + manager._initialize_config() + assert isinstance(manager._config, AppConfig) diff --git a/test/serve/server/test_cpu_worker.py b/test/serve/server/test_cpu_worker.py index 45e2c9553..ab8291543 100644 --- a/test/serve/server/test_cpu_worker.py +++ b/test/serve/server/test_cpu_worker.py @@ -116,6 +116,15 @@ class MockObjectStorageConfig: cloudfront_key_pair_id: str | None = None cloudfront_private_key_path: str | None = None signed_url_expires_in: int = 3600 + azure_geocatalog_url: str | None = None + + +@dataclass +class MockWorkflowExposureConfig: + """Mock workflow exposure configuration""" + + exposed_workflows: list = field(default_factory=list) + warmup_workflows: list = field(default_factory=lambda: ["example_user_workflow"]) @dataclass @@ -131,6 +140,9 @@ class MockAppConfig: object_storage: MockObjectStorageConfig = field( default_factory=MockObjectStorageConfig ) + workflow_exposure: MockWorkflowExposureConfig = field( + default_factory=MockWorkflowExposureConfig + ) # Create a mock config module @@ -152,6 
+164,7 @@ class MockAppConfig: create_results_zip, fail_workflow, process_finalize_metadata, + process_geocatalog_ingestion, process_object_storage_upload, process_result_zip, ) @@ -1399,5 +1412,1040 @@ def test_process_object_storage_upload_output_path_missing_returns_failure(self) assert "does not exist" in mock_fail.call_args[0][2] +class TestProcessGeocatalogIngestion: + """Tests for process_geocatalog_ingestion RQ worker function.""" + + def test_process_geocatalog_ingestion_url_not_set_skips_and_queues(self): + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch("earth2studio.serve.server.cpu_worker.redis_client") as _: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + mock_config.object_storage.azure_geocatalog_url = None + mock_queue_next.return_value = "job_1" + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + assert result.get("skipped") is True + assert "AZURE_GEOCATALOG_URL" in result.get("reason", "") + mock_queue_next.assert_called_once() + assert mock_queue_next.call_args[1]["current_stage"] == "geocatalog_ingestion" + + def test_process_geocatalog_ingestion_url_not_set_queue_fails_returns_fail_workflow( + self, + ): + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch("earth2studio.serve.server.cpu_worker.redis_client") as _: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + with patch( + "earth2studio.serve.server.cpu_worker.fail_workflow" + ) as mock_fail: + mock_config.object_storage.azure_geocatalog_url = None + mock_queue_next.return_value = None + mock_fail.return_value = {"success": False} + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None and 
result.get("success") is False + mock_fail.assert_called_once() + assert "finalize_metadata" in mock_fail.call_args[0][2].lower() + + def test_process_geocatalog_ingestion_storage_or_metadata_missing_skips_and_queues( + self, + ): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: None, + metadata_key: None, + }.get(k) + mock_queue_next.return_value = "job_1" + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + assert result.get("skipped") is True + assert ( + "missing" in result.get("reason", "").lower() + or "storage" in result.get("reason", "").lower() + ) + mock_queue_next.assert_called_once() + + def test_process_geocatalog_ingestion_no_blob_url_skips_and_queues(self): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + storage_info = {} + pending_metadata = {"parameters": {"start_time": "2025-01-01T00:00:00Z"}} + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + 
mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = "job_1" + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + assert result.get("skipped") is True + assert "blob" in result.get("reason", "").lower() + mock_queue_next.assert_called_once() + + def test_process_geocatalog_ingestion_workflow_not_supported_skips_and_queues(self): + request_id = "other_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + storage_info = {"blob_url": "https://storage.example/blob.nc"} + pending_metadata = {"parameters": {"start_time": "2025-01-01T00:00:00Z"}} + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = "job_1" + + result = process_geocatalog_ingestion( + workflow_name="other_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + assert result.get("skipped") is True + assert "not supported" in result.get("reason", "").lower() + mock_queue_next.assert_called_once() + + def test_process_geocatalog_ingestion_success_calls_create_feature_and_queues(self): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = 
f"inference_request:{request_id}:pending_metadata" + storage_info = {"blob_url": "https://storage.example/container/out.nc"} + pending_metadata = { + "parameters": {"start_time": "2025-01-01T00:00:00Z"}, + } + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + with patch( + "earth2studio.data.planetary_computer.GeoCatalogClient" + ) as mock_client_class: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = "job_1" + mock_client = Mock() + mock_client_class.return_value = mock_client + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + assert result.get("skipped") is not True + mock_client.create_feature.assert_called_once() + call_kw = mock_client.create_feature.call_args[1] + assert call_kw["geocatalog_url"] == "https://geocatalog.example/" + assert call_kw["blob_url"] == "https://storage.example/container/out.nc" + assert call_kw["parameters"] == {"start_time": "2025-01-01T00:00:00Z"} + mock_queue_next.assert_called_once() + + def test_process_geocatalog_ingestion_passes_collection_id_when_in_parameters(self): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + storage_info = {"blob_url": "https://storage.example/out.nc"} + pending_metadata = { + "parameters": { + "start_time": "2025-01-01T00:00:00Z", + "collection_id": "my-custom-collection", + }, + } + + with 
patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + with patch( + "earth2studio.data.planetary_computer.GeoCatalogClient" + ) as mock_client_class: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = "job_1" + mock_client = Mock() + mock_client_class.return_value = mock_client + + process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + call_kw = mock_client_class.return_value.create_feature.call_args[1] + assert call_kw["collection_id"] == "my-custom-collection" + + def test_process_geocatalog_ingestion_create_feature_exception_still_queues(self): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + storage_info = {"blob_url": "https://storage.example/out.nc"} + pending_metadata = {"parameters": {"start_time": "2025-01-01T00:00:00Z"}} + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + with patch( + "earth2studio.data.planetary_computer.GeoCatalogClient" + ) as mock_client_class: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = "job_1" + 
mock_client_class.return_value.create_feature.side_effect = ( + ValueError("create_feature failed") + ) + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None + assert result["success"] is True + mock_queue_next.assert_called_once() + + def test_process_geocatalog_ingestion_queue_next_returns_none_returns_fail_workflow( + self, + ): + request_id = "foundry_fcn3_workflow:exec_1" + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + storage_info = {"blob_url": "https://storage.example/out.nc"} + pending_metadata = {"parameters": {"start_time": "2025-01-01T00:00:00Z"}} + + with patch("earth2studio.serve.server.cpu_worker.config") as mock_config: + with patch( + "earth2studio.serve.server.cpu_worker.redis_client" + ) as mock_redis: + with patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue_next: + with patch( + "earth2studio.data.planetary_computer.GeoCatalogClient" + ) as _: + with patch( + "earth2studio.serve.server.cpu_worker.fail_workflow" + ) as mock_fail: + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = lambda k: { + storage_info_key: json.dumps(storage_info), + metadata_key: json.dumps(pending_metadata), + }.get(k) + mock_queue_next.return_value = None + mock_fail.return_value = {"success": False} + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", + execution_id="exec_1", + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "next pipeline stage" in mock_fail.call_args[0][2].lower() + + +class TestProcessObjectStorageUploadEnabled: + """Tests for process_object_storage_upload when storage is enabled.""" + + def _make_mock_config(self, storage_type="s3", **kwargs): + """Return a Mock config with 
object_storage defaults for enabled storage.""" + mock_config = Mock() + os_cfg = mock_config.object_storage + os_cfg.enabled = True + os_cfg.storage_type = storage_type + os_cfg.bucket = "my-bucket" + os_cfg.prefix = "outputs" + os_cfg.region = "us-east-1" + os_cfg.max_concurrency = 10 + os_cfg.multipart_chunksize = 8388608 + os_cfg.use_rust_client = False + os_cfg.use_transfer_acceleration = False + os_cfg.access_key_id = None + os_cfg.secret_access_key = None + os_cfg.session_token = None + os_cfg.endpoint_url = None + os_cfg.cloudfront_domain = None + os_cfg.cloudfront_key_pair_id = None + os_cfg.cloudfront_private_key = None + os_cfg.azure_container_name = None + os_cfg.azure_connection_string = None + os_cfg.azure_account_name = None + os_cfg.azure_account_key = None + os_cfg.azure_geocatalog_url = None + os_cfg.signed_url_expires_in = 3600 + mock_config.redis.retention_ttl = 604800 + for k, v in kwargs.items(): + setattr(os_cfg, k, v) + return mock_config + + def _patch_all( + self, mock_config, mock_redis, mock_queue, mock_storage_cls, tmp_path + ): + """Context manager helper that patches everything for upload tests.""" + return ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ) + + def test_s3_upload_success_returns_result_with_files(self, tmp_path): + """S3 upload success path returns dict with files_uploaded, destination, remote_prefix.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "result.nc").write_text("data") + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True 
+ mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 4 + mock_upload.destination = "s3://my-bucket/outputs/wf/exec_1" + mock_storage.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result["success"] is True + assert result["storage_type"] == "s3" + assert result["files_uploaded"] == 1 + assert result["total_bytes"] == 4 + assert result["remote_prefix"] == "outputs/wf/exec_1" + mock_queue.assert_called_once() + + def test_azure_missing_container_and_bucket_returns_failure(self, tmp_path): + """Azure storage enabled but no container name and no bucket returns fail_workflow.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="azure", bucket=None, azure_container_name=None + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + ): + mock_fail.return_value = {"success": False, "error": "no container"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert ( + "azure_container_name" in mock_fail.call_args[0][2].lower() + or "container" in mock_fail.call_args[0][2].lower() + ) + + def test_msc_storage_creation_fails_returns_failure(self, tmp_path): + """When MSCObjectStorage constructor raises, returns fail_workflow dict.""" + 
output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock(side_effect=RuntimeError("cannot connect")) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "msc failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "msc storage client" in mock_fail.call_args[0][2].lower() + + def test_upload_directory_exception_returns_failure(self, tmp_path): + """When upload_directory raises, returns fail_workflow dict.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock() + mock_storage_cls.return_value.upload_directory.side_effect = IOError( + "network error" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "upload failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "upload failed" in mock_fail.call_args[0][2].lower() + + def 
test_upload_result_not_success_returns_failure(self, tmp_path): + """When upload_result.success is False, returns fail_workflow dict.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = False + mock_upload.errors = ["checksum mismatch"] + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "upload result false"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "object storage" in mock_fail.call_args[0][2].lower() + + def test_s3_with_cloudfront_generates_and_stores_signed_url(self, tmp_path): + """When CloudFront is configured, signed URL is generated and stored in storage_info.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="s3", + cloudfront_domain="d123.cloudfront.net", + cloudfront_key_pair_id="KP123", + cloudfront_private_key="-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 10 + mock_upload.destination = "s3://my-bucket/outputs/wf/exec_1" + 
mock_storage.upload_directory.return_value = mock_upload + mock_storage.generate_signed_url.return_value = ( + "https://d123.cloudfront.net/signed" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result["success"] is True + assert result["signed_url"] == "https://d123.cloudfront.net/signed" + mock_storage.generate_signed_url.assert_called_once() + # Verify signed URL is stored in Redis (setex called with signed_url_key) + setex_keys = [call[0][0] for call in mock_redis.setex.call_args_list] + assert any("signed_url" in k for k in setex_keys) + + def test_s3_signed_url_objectstorageerror_returns_failure(self, tmp_path): + """When generate_signed_url raises ObjectStorageError, returns fail_workflow.""" + from earth2studio.serve.server.object_storage import ObjectStorageError + + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="s3", + cloudfront_domain="d123.cloudfront.net", + cloudfront_key_pair_id="KP123", + cloudfront_private_key="fake", + ) + mock_redis = Mock() + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 10 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage.upload_directory.return_value = mock_upload + mock_storage.generate_signed_url.side_effect = ObjectStorageError( + "signing failed" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + 
patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "signed url failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "signed url" in mock_fail.call_args[0][2].lower() + + def test_azure_upload_success_sets_remote_path_and_blob_url(self, tmp_path): + """Azure upload success: storage_info has azure remote_path and blob_url for .nc file.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "result.nc").write_bytes(b"netcdf") + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="my-container", + azure_account_name="myaccount", + azure_geocatalog_url="https://geocatalog.example/", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 6 + mock_upload.destination = "azure://my-container/outputs/wf/exec_1" + mock_storage.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert 
result["success"] is True + assert result["storage_type"] == "azure" + # Verify storage_info was written to Redis with azure remote_path and blob_url + storage_info_calls = [ + c for c in mock_redis.setex.call_args_list if "storage_info" in c[0][0] + ] + assert len(storage_info_calls) == 1 + stored_info = json.loads(storage_info_calls[0][0][2]) + assert stored_info["remote_path"].startswith("azure://my-container/") + assert "blob_url" in stored_info + assert "result.nc" in stored_info["blob_url"] + + def test_azure_upload_blob_url_from_nc_file_in_directory(self, tmp_path): + """Azure upload: blob_url is built from the first .nc file found in a directory.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "forecast.nc").write_bytes(b"data") + (output_dir / "other.txt").write_text("text") + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="container", + azure_account_name="myaccount", + azure_geocatalog_url="https://geocatalog.example/", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 2 + mock_upload.total_bytes = 100 + mock_upload.destination = "azure://container/outputs/wf/exec_1" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result["success"] is True + storage_info_calls = [ + c for c in mock_redis.setex.call_args_list if "storage_info" in c[0][0] + ] + stored_info = 
json.loads(storage_info_calls[0][0][2]) + assert "forecast.nc" in stored_info.get("blob_url", "") + + def test_queue_next_returns_none_after_upload_returns_failure(self, tmp_path): + """When queue_next_stage returns None after a successful upload, returns fail_workflow.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_queue = Mock(return_value=None) + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "no job"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "pipeline stage" in mock_fail.call_args[0][2].lower() + + def test_unexpected_exception_returns_fail_workflow(self, tmp_path): + """An unexpected exception in the try block returns fail_workflow dict (lines 825-827).""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + # Make redis_client.setex raise to trigger the outer except after upload + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + 
mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 1 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + mock_redis.setex.side_effect = RuntimeError("redis crashed") + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result.get("success") is False + + def test_azure_credentials_added_to_storage_kwargs(self, tmp_path): + """Azure-specific kwargs (connection_string, account_name, key, container) are passed.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="my-container", + azure_connection_string="DefaultEndpointsProtocol=https;...", + azure_account_name="myaccount", + azure_account_key="mykey", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "azure://my-container/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + process_object_storage_upload( + workflow_name="wf", + 
execution_id="exec_1", + output_path_str=str(output_dir), + ) + + call_kwargs = mock_storage_cls.call_args[1] + assert ( + call_kwargs.get("azure_connection_string") + == "DefaultEndpointsProtocol=https;..." + ) + assert call_kwargs.get("azure_account_name") == "myaccount" + assert call_kwargs.get("azure_account_key") == "mykey" + assert call_kwargs.get("azure_container_name") == "my-container" + + def test_s3_optional_credentials_added_when_set(self, tmp_path): + """S3 optional kwargs (access_key_id, session_token, endpoint_url) are passed when set.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( # noqa: S106 + storage_type="s3", + access_key_id="AK123", + secret_access_key="SK456", # noqa: S106 + session_token="ST789", # noqa: S106 + endpoint_url="https://s3.custom.example.com", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + call_kwargs = mock_storage_cls.call_args[1] + assert call_kwargs.get("access_key_id") == "AK123" + assert call_kwargs.get("secret_access_key") == "SK456" + assert call_kwargs.get("session_token") == "ST789" + assert call_kwargs.get("endpoint_url") == "https://s3.custom.example.com" + + +class TestProcessGeocatalogIngestionEdgeCases: + """Additional geocatalog 
tests for remaining uncovered branches.""" + + def _redis_side_effect(self, request_id, storage_info, pending_metadata): + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = f"inference_request:{request_id}:pending_metadata" + return lambda k: { + storage_info_key: ( + json.dumps(storage_info) if storage_info is not None else None + ), + metadata_key: ( + json.dumps(pending_metadata) if pending_metadata is not None else None + ), + }.get(k) + + def test_missing_storage_queue_returns_none_triggers_fail_workflow(self): + """When storage/metadata missing and queue_next returns None, calls fail_workflow (line 899).""" + request_id = "foundry_fcn3_workflow:exec_1" + with ( + patch("earth2studio.serve.server.cpu_worker.config") as mock_config, + patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis, + patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue, + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + ): + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = self._redis_side_effect( + request_id, storage_info=None, pending_metadata=None + ) + mock_queue.return_value = None + mock_fail.return_value = {"success": False} + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", execution_id="exec_1" + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "finalize_metadata" in mock_fail.call_args[0][2].lower() + + def test_no_blob_url_queue_returns_none_triggers_fail_workflow(self): + """When no blob_url and queue_next returns None, calls fail_workflow (line 927).""" + request_id = "foundry_fcn3_workflow:exec_1" + with ( + patch("earth2studio.serve.server.cpu_worker.config") as mock_config, + patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis, + patch( + 
"earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue, + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + ): + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = self._redis_side_effect( + request_id, + storage_info={}, # no blob_url key + pending_metadata={"parameters": {}}, + ) + mock_queue.return_value = None + mock_fail.return_value = {"success": False} + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", execution_id="exec_1" + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "finalize_metadata" in mock_fail.call_args[0][2].lower() + + def test_unsupported_workflow_queue_returns_none_triggers_fail_workflow(self): + """When workflow unsupported and queue_next returns None, calls fail_workflow (line 952).""" + request_id = "other_workflow:exec_1" + with ( + patch("earth2studio.serve.server.cpu_worker.config") as mock_config, + patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis, + patch( + "earth2studio.serve.server.cpu_worker.queue_next_stage" + ) as mock_queue, + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + ): + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = self._redis_side_effect( + request_id, + storage_info={"blob_url": "https://storage.example/blob.nc"}, + pending_metadata={"parameters": {}}, + ) + mock_queue.return_value = None + mock_fail.return_value = {"success": False} + + result = process_geocatalog_ingestion( + workflow_name="other_workflow", execution_id="exec_1" + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "finalize_metadata" in mock_fail.call_args[0][2].lower() + + def test_outer_exception_returns_fail_workflow(self): + """An unexpected exception in the 
outer try returns fail_workflow (lines 1008-1010).""" + with ( + patch("earth2studio.serve.server.cpu_worker.config") as mock_config, + patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis, + ): + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example/" + ) + mock_redis.get.side_effect = RuntimeError("redis unavailable") + + result = process_geocatalog_ingestion( + workflow_name="foundry_fcn3_workflow", execution_id="exec_1" + ) + + assert result is not None + assert result.get("success") is False + assert "geocatalog" in result.get("error", "").lower() + + +class TestProcessFinalizeMetadataEdgeCases: + """Tests for the exception handler in process_finalize_metadata (lines 1114-1116).""" + + def test_exception_in_try_block_returns_fail_workflow(self, tmp_path): + """When json.loads raises (corrupt metadata), the except block returns fail_workflow.""" + results_zip_dir = tmp_path / "results" + results_zip_dir.mkdir() + request_id = "my_wf:exec_1" + metadata_key = f"inference_request:{request_id}:pending_metadata" + results_zip_dir_key = f"inference_request:{request_id}:results_zip_dir" + + with patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis: + # Return non-JSON for metadata to trigger json.loads to raise + mock_redis.get.side_effect = lambda k: { + metadata_key: "NOT_VALID_JSON{{{", + results_zip_dir_key: str(results_zip_dir), + }.get(k) + + result = process_finalize_metadata( + workflow_name="my_wf", execution_id="exec_1" + ) + + assert result is not None + assert result.get("success") is False + assert "metadata finalization" in result.get("error", "").lower() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/serve/server/test_object_storage.py b/test/serve/server/test_object_storage.py index 907441cc2..1947c9731 100644 --- a/test/serve/server/test_object_storage.py +++ b/test/serve/server/test_object_storage.py @@ -140,7 +140,7 @@ def 
test_upload_directory_success_returns_upload_result(self, mock_msc): with tempfile.TemporaryDirectory() as tmpdir: (Path(tmpdir) / "f1.txt").write_text("hello") storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.sync_from.return_value = None + storage._storage_client.sync_from.return_value = None result = storage.upload_directory( local_directory=tmpdir, @@ -159,7 +159,7 @@ def test_upload_directory_failure_appends_errors(self, mock_msc): with tempfile.TemporaryDirectory() as tmpdir: (Path(tmpdir) / "f1.txt").write_text("x") storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.sync_from.side_effect = Exception("sync failed") + storage._storage_client.sync_from.side_effect = Exception("sync failed") result = storage.upload_directory( local_directory=tmpdir, @@ -196,7 +196,7 @@ def test_upload_file_returns_true_on_success(self, mock_msc): remote_key="key.txt", ) assert result is True - storage._s3_client.upload_file.assert_called_once() + storage._storage_client.upload_file.assert_called_once() finally: Path(path).unlink(missing_ok=True) @@ -206,7 +206,7 @@ def test_upload_file_returns_false_on_exception(self, mock_msc): path = f.name try: storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.upload_file.side_effect = Exception("upload failed") + storage._storage_client.upload_file.side_effect = Exception("upload failed") result = storage.upload_file( local_file=path, @@ -219,14 +219,14 @@ def test_upload_file_returns_false_on_exception(self, mock_msc): def test_file_exists_returns_true_when_info_succeeds(self, mock_msc): """file_exists returns True when info() does not raise.""" storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.info.return_value = None + storage._storage_client.info.return_value = None assert storage.file_exists("my/key") is True def test_file_exists_returns_false_when_file_not_found(self, mock_msc): """file_exists returns False when 
info() raises FileNotFoundError.""" storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.info.side_effect = FileNotFoundError() + storage._storage_client.info.side_effect = FileNotFoundError() assert storage.file_exists("my/key") is False @@ -235,12 +235,12 @@ def test_delete_file_returns_true_on_success(self, mock_msc): storage = MSCObjectStorage(bucket="b", region="us-east-1") assert storage.delete_file("my/key") is True - storage._s3_client.delete.assert_called_once() + storage._storage_client.delete.assert_called_once() def test_delete_file_returns_false_when_file_not_found(self, mock_msc): """delete_file returns False when delete raises FileNotFoundError.""" storage = MSCObjectStorage(bucket="b", region="us-east-1") - storage._s3_client.delete.side_effect = FileNotFoundError() + storage._storage_client.delete.side_effect = FileNotFoundError() assert storage.delete_file("my/key") is False @@ -292,3 +292,330 @@ def test_generate_signed_url_returns_url_when_configured(self, mock_msc): assert "Policy=" in url assert "Signature=" in url assert "Key-Pair-Id=KP123" in url + + +class TestMSCObjectStorageS3Additional: + """Additional tests to cover S3 init branches and other uncovered S3 paths.""" + + @pytest.fixture + def mock_msc(self): + mock_module = MagicMock() + mock_module.StorageClientConfig.from_dict.side_effect = [ + MagicMock(), + MagicMock(), + ] + mock_module.StorageClient.return_value = MagicMock() + with patch.dict(sys.modules, {"multistorageclient": mock_module}): + yield mock_module + + def test_init_s3_transfer_acceleration(self, mock_msc): + """S3 init with use_transfer_acceleration sets accelerate endpoint URL.""" + storage = MSCObjectStorage(bucket="my-bucket", use_transfer_acceleration=True) + assert storage.endpoint_url == "https://my-bucket.s3-accelerate.amazonaws.com" + + def test_init_s3_with_credentials(self, mock_msc): + """S3 init with credentials sets AWS environment variables.""" + import os + + 
MSCObjectStorage( + bucket="b", + access_key_id="AKID", + secret_access_key="SECRET", # noqa: S106 + session_token="TOKEN", # noqa: S106 + ) + assert os.environ.get("AWS_ACCESS_KEY_ID") == "AKID" + assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "SECRET" + assert os.environ.get("AWS_SESSION_TOKEN") == "TOKEN" + + def test_init_s3_with_endpoint_url(self, mock_msc): + """S3 init with endpoint_url stores it and includes it in provider options.""" + storage = MSCObjectStorage(bucket="b", endpoint_url="http://minio:9000") + assert storage.endpoint_url == "http://minio:9000" + + def test_init_s3_with_rust_client(self, mock_msc): + """S3 init with use_rust_client=True adds rust_client section to config.""" + storage = MSCObjectStorage(bucket="b", use_rust_client=True) + assert storage.use_rust_client is True + + def test_init_unsupported_storage_type(self, mock_msc): + """__init__ raises ValueError for unsupported storage_type.""" + with pytest.raises(ValueError, match="Unsupported storage_type"): + MSCObjectStorage(bucket="b", storage_type="gcs") + + def test_upload_directory_non_recursive(self, mock_msc): + """upload_directory with recursive=False only counts top-level files.""" + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "top.txt").write_text("hello") + subdir = Path(tmpdir) / "sub" + subdir.mkdir() + (subdir / "deep.txt").write_text("world") + + storage = MSCObjectStorage(bucket="b", region="us-east-1") + storage._storage_client.sync_from.return_value = None + + result = storage.upload_directory( + local_directory=tmpdir, + remote_prefix="prefix", + recursive=False, + ) + assert result.success is True + assert result.files_uploaded == 1 # only top-level file + + def test_upload_file_path_is_directory(self, mock_msc): + """upload_file returns False when local_file path is a directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + storage = MSCObjectStorage(bucket="b", region="us-east-1") + result = storage.upload_file(local_file=tmpdir, 
remote_key="key.txt") + assert result is False + + def test_delete_file_generic_exception(self, mock_msc): + """delete_file returns False on an unexpected (non-FileNotFoundError) exception.""" + storage = MSCObjectStorage(bucket="b", region="us-east-1") + storage._storage_client.delete.side_effect = RuntimeError("unexpected") + assert storage.delete_file("my/key") is False + + def test_rsa_signer_no_key_raises(self, mock_msc): + """_rsa_signer raises ObjectStorageError when cloudfront_private_key is None.""" + storage = MSCObjectStorage(bucket="b", region="us-east-1") + storage.cloudfront_private_key = None + with pytest.raises(ObjectStorageError, match="No CloudFront private key"): + storage._rsa_signer(b"message") + + def test_rsa_signer_with_mocked_key(self, mock_msc): + """_rsa_signer signs message using the configured private key.""" + storage = MSCObjectStorage(bucket="b", region="us-east-1") + storage.cloudfront_private_key = ( + "-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----" + ) + + mock_private_key = MagicMock() + mock_private_key.sign.return_value = b"fake_signature" + mock_serialization = MagicMock() + mock_serialization.load_pem_private_key.return_value = mock_private_key + + crypto_mocks = { + "cryptography": MagicMock(), + "cryptography.hazmat": MagicMock(), + "cryptography.hazmat.primitives": MagicMock(), + "cryptography.hazmat.primitives.hashes": MagicMock(), + "cryptography.hazmat.primitives.asymmetric": MagicMock(), + "cryptography.hazmat.primitives.asymmetric.padding": MagicMock(), + "cryptography.hazmat.primitives.serialization": mock_serialization, + } + with patch.dict(sys.modules, crypto_mocks): + result = storage._rsa_signer(b"message") + + assert result == b"fake_signature" + + def test_generate_signed_url_unsupported_type(self, mock_msc): + """generate_signed_url raises ObjectStorageError for unsupported storage_type.""" + storage = MSCObjectStorage(bucket="b", region="us-east-1") + storage.storage_type = 
"unsupported" + with pytest.raises(ObjectStorageError, match="Unsupported storage_type"): + storage.generate_signed_url("key.txt") + + +class TestMSCObjectStorageAzure: + """Tests for MSCObjectStorage with Azure storage type.""" + + @pytest.fixture + def mock_msc(self): + mock_module = MagicMock() + mock_module.StorageClientConfig.from_dict.side_effect = [ + MagicMock(), + MagicMock(), + ] + mock_module.StorageClient.return_value = MagicMock() + with patch.dict(sys.modules, {"multistorageclient": mock_module}): + yield mock_module + + def _make_azure_storage(self, mock_msc, **kwargs): + """Helper to reset mock side_effect for multiple instantiations.""" + mock_msc.StorageClientConfig.from_dict.side_effect = [ + MagicMock(), + MagicMock(), + ] + return MSCObjectStorage(**kwargs) + + def test_init_azure_managed_identity_with_account_name(self, mock_msc): + """Azure init with managed identity uses DefaultAzureCredentials.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="myaccount", + ) + assert storage.use_managed_identity is True + assert storage.azure_account_name == "myaccount" + assert storage.storage_type == "azure" + + def test_init_azure_managed_identity_with_endpoint_url(self, mock_msc): + """Azure init with managed identity and explicit endpoint_url succeeds.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + endpoint_url="https://myaccount.blob.core.windows.net", + ) + assert storage.use_managed_identity is True + + def test_init_azure_no_account_name_no_endpoint_raises(self, mock_msc): + """Azure managed identity raises ObjectStorageError when neither account name nor endpoint_url is given.""" + with pytest.raises( + ObjectStorageError, match="Azure endpoint_url cannot be determined" + ): + MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + # no azure_account_name, no endpoint_url, no connection_string + ) + + def 
test_init_azure_connection_string_extracts_account_name(self, mock_msc): + """Azure init with connection string extracts AccountName from it.""" + conn_str = "DefaultEndpointsProtocol=https;AccountName=myaccount;AccountKey=key123;EndpointSuffix=core.windows.net" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_connection_string=conn_str, + ) + assert storage.use_managed_identity is False + assert storage.azure_account_name == "myaccount" + + def test_init_azure_connection_string_explicit_account_name(self, mock_msc): + """Azure init with connection string uses explicitly provided azure_account_name.""" + conn_str = "DefaultEndpointsProtocol=https;AccountKey=somekey" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_connection_string=conn_str, + azure_account_name="explicitaccount", + ) + assert storage.azure_account_name == "explicitaccount" + + def test_init_azure_connection_string_no_account_name_raises(self, mock_msc): + """Azure init raises ObjectStorageError when connection string has no AccountName and none provided.""" + conn_str = "DefaultEndpointsProtocol=https;AccountKey=somekey" + with pytest.raises(ObjectStorageError, match="Could not extract account name"): + MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_connection_string=conn_str, + ) + + def test_init_azure_connection_string_with_blob_endpoint(self, mock_msc): + """Azure init uses BlobEndpoint directly from connection string.""" + conn_str = "BlobEndpoint=https://myaccount.blob.core.windows.net;AccountName=myaccount;AccountKey=key" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_connection_string=conn_str, + ) + assert storage.azure_account_name == "myaccount" + + def test_init_azure_connection_string_with_endpoint_suffix(self, mock_msc): + """Azure init constructs endpoint URL using EndpointSuffix from connection string.""" + conn_str = ( + 
"AccountName=myaccount;AccountKey=key;EndpointSuffix=core.chinacloudapi.cn" + ) + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_connection_string=conn_str, + ) + assert storage.azure_account_name == "myaccount" + + def test_upload_directory_azure_destination(self, mock_msc): + """upload_directory for azure uses azure:// in the destination.""" + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "f.txt").write_text("hi") + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="acct", + endpoint_url="https://acct.blob.core.windows.net", + ) + storage._storage_client.sync_from.return_value = None + + result = storage.upload_directory( + local_directory=tmpdir, + remote_prefix="prefix", + ) + assert result.success is True + assert "azure://" in result.destination + assert "mycontainer" in result.destination + + def test_generate_signed_url_azure_delegates_to_sas(self, mock_msc): + """generate_signed_url for azure delegates to _generate_azure_sas_url.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="acct", + endpoint_url="https://acct.blob.core.windows.net", + ) + storage._generate_azure_sas_url = MagicMock(return_value="https://sas_url") + + result = storage.generate_signed_url("key.txt") + + assert result == "https://sas_url" + storage._generate_azure_sas_url.assert_called_once_with("key.txt", 86400) + + def test_generate_azure_sas_url_missing_credentials_raises(self, mock_msc): + """_generate_azure_sas_url raises ObjectStorageError when account key is missing.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="acct", + endpoint_url="https://acct.blob.core.windows.net", + # azure_account_key not provided → None + ) + with pytest.raises( + ObjectStorageError, match="Azure account name and account key" + ): + storage._generate_azure_sas_url("key.txt", 3600) + + def 
test_generate_azure_sas_url_import_error(self, mock_msc): + """_generate_azure_sas_url raises ImportError when azure-storage-blob is not installed.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="acct", + azure_account_key="key123", + endpoint_url="https://acct.blob.core.windows.net", + ) + with patch.dict( + sys.modules, + { + "azure": MagicMock(), + "azure.storage": MagicMock(), + "azure.storage.blob": None, + }, + ): + with pytest.raises(ImportError, match="azure-storage-blob"): + storage._generate_azure_sas_url("key.txt", 3600) + + def test_generate_azure_sas_url_success(self, mock_msc): + """_generate_azure_sas_url returns a properly constructed SAS URL.""" + storage = MSCObjectStorage( + bucket="mycontainer", + storage_type="azure", + azure_account_name="myaccount", + azure_account_key="mykey", + endpoint_url="https://myaccount.blob.core.windows.net", + ) + + mock_azure_blob = MagicMock() + mock_azure_blob.ContainerSasPermissions.return_value = MagicMock() + mock_azure_blob.generate_container_sas.return_value = "sv=2020&sig=abc" + + with patch.dict( + sys.modules, + { + "azure": MagicMock(), + "azure.storage": MagicMock(), + "azure.storage.blob": mock_azure_blob, + }, + ): + url = storage._generate_azure_sas_url("path/to/file", 3600) + + assert "myaccount" in url + assert "mycontainer" in url + assert "path/to/file" in url + assert "sv=2020" in url diff --git a/test/serve/server/test_server_main.py b/test/serve/server/test_server_main.py index 6c9894ff0..f2ec4f294 100644 --- a/test/serve/server/test_server_main.py +++ b/test/serve/server/test_server_main.py @@ -1541,8 +1541,8 @@ def test_execute_workflow_inference_queue_none_503(self, client_exec): def test_execute_workflow_llen_raises_after_enqueue_500(self, client_exec): """When llen raises after enqueue (queue position lookup), returns 500.""" client, mock_redis, mock_queue = client_exec - # First 4 calls: admission (4 queues); 5th: position lookup. 
Make 5th raise. - mock_redis.llen.side_effect = [0, 0, 0, 0, RuntimeError("redis error")] + # 5 queue checks in admission control + 1 for queue position lookup + mock_redis.llen.side_effect = [0, 0, 0, 0, 0, RuntimeError("redis error")] with patch("earth2studio.serve.server.main.inference_queue", mock_queue): response = client.post( "/v1/infer/exec_wf", @@ -2113,6 +2113,537 @@ def test_get_workflow_result_file_stream_file_200(self, client_file): assert response.text == "hello world" +class TestLifespanBranches: + """Tests for lifespan startup exception branches (lines 236-242).""" + + def test_lifespan_workflow_registration_generic_exception_continues(self): + """When register_all_workflows raises non-ImportError, app still starts.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch( + "earth2studio.serve.server.workflow.register_all_workflows", + side_effect=RuntimeError("unexpected registration error"), + ), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with TestClient(app, raise_server_exceptions=False) as c: + response = c.get("/liveness") + assert response.status_code == 200 + + def test_lifespan_redis_ping_failure_raises(self): + """When Redis ping fails, lifespan raises and app fails to start.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock( + 
side_effect=ConnectionError("Redis unavailable") + ) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with pytest.raises(Exception): + with TestClient(app): + pass + + +class TestHealthCheckWithoutScriptDir: + """Tests health endpoint when SCRIPT_DIR env var is absent (lines 304-305).""" + + @pytest.fixture + def client_probes(self): + """Client for probe endpoints.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + def test_health_follows_repo_relative_path_when_script_dir_empty( + self, client_probes + ): + """Health check uses repo-relative script path when SCRIPT_DIR is empty/unset.""" + with patch.dict(os.environ, {"SCRIPT_DIR": ""}): + with patch( + "earth2studio.serve.server.main.asyncio.create_subprocess_exec" + ) as mock_exec: + mock_proc = MagicMock() + mock_proc.returncode = 0 + mock_proc.communicate = AsyncMock(return_value=(b"", b"")) + mock_exec.return_value = mock_proc + + response = client_probes.get("/health") + assert response.status_code == 200 + assert response.json()["status"] == "healthy" + + +class TestNotExposedWorkflowEndpoints: + """Tests for 404 when 
workflow is registered but not exposed (lines 427, 539, 671, 743, 876).""" + + @pytest.fixture + def client_with_workflow(self): + """Standard client with a registered workflow.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + from earth2studio.serve.server.workflow import ( + Workflow, + WorkflowParameters, + workflow_registry, + ) + + class TestEndpointsParams(WorkflowParameters): + x: int = Field(default=1) + + class TestEndpointsWf(Workflow): + name = "test_endpoints_wf" + description = "Test" + Parameters = TestEndpointsParams + + @classmethod + def validate_parameters(cls, parameters): + return TestEndpointsParams.validate(parameters) + + def run(self, parameters, execution_id): + return {"status": "ok"} + + workflow_registry._workflows["test_endpoints_wf"] = TestEndpointsWf + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + if "test_endpoints_wf" in workflow_registry._workflows: + del workflow_registry._workflows["test_endpoints_wf"] + + def test_schema_not_exposed_404(self, client_with_workflow): + """get_workflow_schema returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/workflows/test_endpoints_wf/schema" + ) + assert response.status_code == 404 + assert "not exposed" in response.json().get("detail", 
"").lower() + + def test_execute_not_exposed_404(self, client_with_workflow): + """execute_workflow returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.post( + "/v1/infer/test_endpoints_wf", json={"parameters": {}} + ) + assert response.status_code == 404 + + def test_get_status_workflow_not_found_404(self, client_with_workflow): + """get_workflow_status returns 404 when workflow not found.""" + response = client_with_workflow.get("/v1/infer/nonexistent_wf/exec_1/status") + assert response.status_code == 404 + assert "not found" in response.json().get("detail", "").lower() + + def test_get_status_not_exposed_404(self, client_with_workflow): + """get_workflow_status returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/status" + ) + assert response.status_code == 404 + + def test_get_results_workflow_not_found_404(self, client_with_workflow): + """get_workflow_results returns 404 when workflow not found.""" + response = client_with_workflow.get("/v1/infer/nonexistent_wf/exec_1/results") + assert response.status_code == 404 + + def test_get_results_not_exposed_404(self, client_with_workflow): + """get_workflow_results returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/results" + ) + assert response.status_code == 404 + + def test_get_result_file_not_exposed_404(self, client_with_workflow): + """get_workflow_result_file returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + 
return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/results/file.nc" + ) + assert response.status_code == 404 + + +class TestExecuteWorkflowAdditionalBranches: + """Additional coverage for execute_workflow (lines 598, 605-609).""" + + @pytest.fixture + def client_exec2(self): + """Client with workflow for additional execute branch tests.""" + from earth2studio.serve.server.workflow import Workflow, WorkflowParameters + + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue") as mock_queue_class, + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_instance.get = AsyncMock(return_value=None) + mock_async_redis.return_value = mock_async_instance + + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_instance.setex = MagicMock() + mock_sync_instance.llen = MagicMock(return_value=0) + mock_sync_redis.return_value = mock_sync_instance + + mock_queue = MagicMock() + mock_job = MagicMock() + mock_job.id = "exec2_wf_exec_123" + mock_queue.enqueue = MagicMock(return_value=mock_job) + mock_queue_class.return_value = mock_queue + + class Exec2Params(WorkflowParameters): + x: int = Field(default=1) + + class Exec2Workflow(Workflow): + name = "exec2_wf" + description = "For additional execute tests" + Parameters = Exec2Params + + @classmethod + def validate_parameters(cls, parameters): + return Exec2Params.validate(parameters) + + def run(self, parameters, execution_id): + return {"status": "ok"} + + from earth2studio.serve.server.main import app + from earth2studio.serve.server.workflow import workflow_registry + + workflow_registry._workflows["exec2_wf"] = Exec2Workflow + with TestClient(app, 
raise_server_exceptions=False) as c: + yield c, mock_sync_instance, mock_queue, Exec2Workflow + if "exec2_wf" in workflow_registry._workflows: + del workflow_registry._workflows["exec2_wf"] + + def test_execute_llen_raises_after_enqueue_500(self, client_exec2): + """When llen raises during queue position lookup, returns 500.""" + client, mock_redis, mock_queue, _ = client_exec2 + # 5 queue checks in admission control + 1 for queue position lookup + mock_redis.llen.side_effect = [0, 0, 0, 0, 0, RuntimeError("redis error")] + with patch("earth2studio.serve.server.main.inference_queue", mock_queue): + response = client.post( + "/v1/infer/exec2_wf", + json={"parameters": {}}, + ) + assert response.status_code == 500 + + def test_execute_redis_none_during_queue_position_503(self, client_exec2): + """When redis_sync_client is None at queue position check, returns 503.""" + client, mock_redis, mock_queue, wf_class = client_exec2 + with patch("earth2studio.serve.server.main.check_admission_control"): + with patch.object(wf_class, "_save_execution_data"): + with patch("earth2studio.serve.server.main.redis_sync_client", None): + with patch( + "earth2studio.serve.server.main.inference_queue", mock_queue + ): + response = client.post( + "/v1/infer/exec2_wf", + json={"parameters": {}}, + ) + assert response.status_code == 503 + assert "Redis" in response.json().get("detail", "") + + +class TestGetWorkflowResultFileAdditionalBranches: + """Additional tests for get_workflow_result_file (lines 895, 903, 919-952, 986, 1035, 1052, 1062-1066).""" + + @pytest.fixture + def client_file2(self, tmp_path): + """Client and tmp dir for additional result file tests.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue") as mock_queue_class, + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + 
mock_async_instance.close = AsyncMock() + mock_async_instance.get = AsyncMock(return_value=None) + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + mock_queue_class.return_value = MagicMock() + + from earth2studio.serve.server.main import app + from earth2studio.serve.server.workflow import ( + Workflow, + WorkflowParameters, + WorkflowStatus, + workflow_registry, + ) + + class File2Params(WorkflowParameters): + x: int = Field(default=1) + + class File2Workflow(Workflow): + name = "file2_wf" + description = "File2 test" + Parameters = File2Params + + @classmethod + def validate_parameters(cls, parameters): + return File2Params.validate(parameters) + + def run(self, parameters, execution_id): + return {"status": "ok"} + + workflow_registry._workflows["file2_wf"] = File2Workflow + with TestClient(app, raise_server_exceptions=False) as c: + yield c, mock_async_instance, tmp_path, WorkflowStatus + if "file2_wf" in workflow_registry._workflows: + del workflow_registry._workflows["file2_wf"] + + def _completed_exec_data(self): + from earth2studio.serve.server.workflow import WorkflowResult, WorkflowStatus + + return WorkflowResult( + workflow_name="file2_wf", + execution_id="exec_1", + status=WorkflowStatus.COMPLETED, + start_time=datetime.now(timezone.utc).isoformat(), + end_time=datetime.now(timezone.utc).isoformat(), + ) + + def test_get_result_file_value_error_in_exec_data_404(self, client_file2): + """When _get_execution_data raises ValueError, returns 404.""" + client, *_ = client_file2 + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock( + side_effect=ValueError("Execution not found") + ) + mock_reg.get_workflow_class.return_value = mock_wf + response = 
client.get("/v1/infer/file2_wf/exec_1/results/file.nc") + assert response.status_code == 404 + assert "Execution not found" in response.json().get("detail", "") + + def test_get_result_file_redis_none_for_zip_path_503(self, client_file2): + """When filepath == request_id but redis_client is None, returns 503.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", None): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1" + ) + assert response.status_code == 503 + + def test_get_result_file_zip_not_on_disk_404(self, client_file2): + """When zip key in Redis but zip file missing from disk, returns 404.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + mock_async.get = AsyncMock(return_value="missing_zip.zip") + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + with patch("earth2studio.serve.server.main.RESULTS_ZIP_DIR", tmp_path): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1" + ) + assert response.status_code == 404 + assert "disk" in response.json().get("detail", {}).get("error", "").lower() + + def test_get_result_file_zip_stream_success(self, client_file2): + """When zip file exists on disk, returns 200 with streamed content.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + zip_file = tmp_path / "results.zip" + 
zip_file.write_bytes(b"PK\x03\x04fake_zip_content") + mock_async.get = AsyncMock(return_value="results.zip") + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + with patch("earth2studio.serve.server.main.RESULTS_ZIP_DIR", tmp_path): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1" + ) + assert response.status_code == 200 + assert 'results.zip"' in response.headers.get("content-disposition", "") + + def test_get_result_file_zip_stream_with_range_header(self, client_file2): + """Zip streaming with Range header returns 206 and Content-Range.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + zip_file = tmp_path / "ranged.zip" + zip_file.write_bytes(b"A" * 1000) + mock_async.get = AsyncMock(return_value="ranged.zip") + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + with patch("earth2studio.serve.server.main.RESULTS_ZIP_DIR", tmp_path): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1", + headers={"Range": "bytes=0-99"}, + ) + assert response.status_code == 206 + assert "Content-Range" in response.headers + + def test_get_result_file_filepath_with_output_dir_prefix(self, client_file2): + """When filepath starts with output dir name, the prefix is stripped.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + output_dir = tmp_path / "exec_1" + output_dir.mkdir() + data_file = output_dir / "data.txt" + data_file.write_text("contents") + 
mock_async.get = AsyncMock(return_value=str(output_dir)) + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + # filepath starts with output dir name (exec_1/data.txt) + response = client.get( + "/v1/infer/file2_wf/exec_1/results/exec_1/data.txt" + ) + assert response.status_code == 200 + assert response.text == "contents" + + def test_get_result_file_no_mime_type_uses_octet_stream(self, client_file2): + """Files with no recognized MIME type use application/octet-stream.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + output_dir = tmp_path / "output" + output_dir.mkdir() + unknown_file = output_dir / "data.unknownext99999" + unknown_file.write_bytes(b"\x00\x01\x02\x03") + mock_async.get = AsyncMock(return_value=str(output_dir)) + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + with patch("mimetypes.guess_type", return_value=(None, None)): + response = client.get( + "/v1/infer/file2_wf/exec_1/results/data.unknownext99999" + ) + assert response.status_code == 200 + assert "octet-stream" in response.headers.get("content-type", "") + + def test_get_result_file_with_range_header_206(self, client_file2): + """Serving regular file with Range header returns 206 and Content-Range.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + output_dir = tmp_path / "output" + output_dir.mkdir() + data_file = output_dir / "large.txt" + data_file.write_bytes(b"X" * 500) + mock_async.get = 
AsyncMock(return_value=str(output_dir)) + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/large.txt", + headers={"Range": "bytes=0-99"}, + ) + assert response.status_code == 206 + assert "Content-Range" in response.headers + + def test_get_result_file_generic_exception_500(self, client_file2): + """When an unexpected exception occurs in the file handler, returns 500.""" + client, mock_async, tmp_path, _ = client_file2 + exec_data = self._completed_exec_data() + mock_async.get = AsyncMock(side_effect=RuntimeError("unexpected redis error")) + with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg: + with patch("earth2studio.serve.server.main.redis_client", mock_async): + mock_wf = MagicMock() + mock_wf._get_execution_data = MagicMock(return_value=exec_data) + mock_reg.get_workflow_class.return_value = mock_wf + response = client.get( + "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1" + ) + assert response.status_code == 500 + assert "Failed to retrieve file" in response.json().get("detail", {}).get( + "error", "" + ) + + class TestMainEntrypoint: """Test main module entrypoint (covers line 1044).""" diff --git a/test/serve/server/test_utils.py b/test/serve/server/test_utils.py index 3e3b663ff..d99ed48c9 100644 --- a/test/serve/server/test_utils.py +++ b/test/serve/server/test_utils.py @@ -14,14 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import MagicMock, patch +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException from earth2studio.serve.server.utils import ( + create_file_stream, get_inference_request_metadata_key, get_inference_request_output_path_key, get_inference_request_zip_key, get_results_zip_dir_key, get_signed_url_key, + parse_range_header, queue_next_stage, ) @@ -206,10 +212,45 @@ def test_result_zip_stage_object_storage_disabled_queues_finalize_metadata(self) assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0] assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1") + def test_object_storage_stage_with_geocatalog_url_queues_geocatalog(self): + """current_stage=object_storage with geocatalog URL enqueues process_geocatalog_ingestion.""" + mock_redis = MagicMock() + mock_config = MagicMock() + mock_config.object_storage.azure_geocatalog_url = ( + "https://geocatalog.example.com" + ) + mock_config.queue.geocatalog_ingestion_queue_name = "geocatalog_ingestion" + mock_config.queue.default_timeout = "1h" + mock_config.queue.job_timeout = "2h" + mock_job = MagicMock() + mock_job.id = "job_geo" + mock_queue = MagicMock() + mock_queue.enqueue.return_value = mock_job + + with ( + patch( + "earth2studio.serve.server.utils.get_config", return_value=mock_config + ), + patch("earth2studio.serve.server.utils.Queue", return_value=mock_queue), + ): + result = queue_next_stage( + redis_client=mock_redis, + current_stage="object_storage", + workflow_name="wf", + execution_id="exec_1", + output_path_str="/out", + ) + + assert result == "job_geo" + mock_queue.enqueue.assert_called_once() + assert "process_geocatalog_ingestion" in mock_queue.enqueue.call_args[0][0] + assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1") + def test_object_storage_stage_queues_finalize_metadata(self): - """current_stage=object_storage enqueues process_finalize_metadata.""" + 
"""current_stage=object_storage enqueues process_finalize_metadata when geocatalog is not configured.""" mock_redis = MagicMock() mock_config = MagicMock() + mock_config.object_storage.azure_geocatalog_url = None mock_config.queue.finalize_metadata_queue_name = "finalize_metadata" mock_config.queue.default_timeout = "1h" mock_config.queue.job_timeout = "2h" @@ -236,6 +277,37 @@ def test_object_storage_stage_queues_finalize_metadata(self): mock_queue.enqueue.assert_called_once() assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0] + def test_geocatalog_ingestion_stage_queues_finalize_metadata(self): + """current_stage=geocatalog_ingestion enqueues process_finalize_metadata.""" + mock_redis = MagicMock() + mock_config = MagicMock() + mock_config.queue.finalize_metadata_queue_name = "finalize_metadata" + mock_config.queue.default_timeout = "1h" + mock_config.queue.job_timeout = "2h" + mock_job = MagicMock() + mock_job.id = "job_finalize" + mock_queue = MagicMock() + mock_queue.enqueue.return_value = mock_job + + with ( + patch( + "earth2studio.serve.server.utils.get_config", return_value=mock_config + ), + patch("earth2studio.serve.server.utils.Queue", return_value=mock_queue), + ): + result = queue_next_stage( + redis_client=mock_redis, + current_stage="geocatalog_ingestion", + workflow_name="wf", + execution_id="exec_1", + output_path_str="/out", + ) + + assert result == "job_finalize" + mock_queue.enqueue.assert_called_once() + assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0] + assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1") + def test_enqueue_exception_returns_none(self): """When Queue.enqueue raises, queue_next_stage returns None.""" mock_redis = MagicMock() @@ -264,3 +336,207 @@ def test_enqueue_exception_returns_none(self): ) assert result is None + + +class TestParseRangeHeader: + """Tests for parse_range_header.""" + + def test_no_range_header_returns_full_file(self): + """None range_header returns 
full file range with status 200.""" + start, end, content_length, status_code = parse_range_header(None, 1000) + assert start == 0 + assert end == 999 + assert content_length == 1000 + assert status_code == 200 + + def test_explicit_range_returns_partial_content(self): + """bytes=0-99 returns first 100 bytes with status 206.""" + start, end, content_length, status_code = parse_range_header("bytes=0-99", 1000) + assert start == 0 + assert end == 99 + assert content_length == 100 + assert status_code == 206 + + def test_open_ended_range(self): + """bytes=500- returns from byte 500 to end of file.""" + start, end, content_length, status_code = parse_range_header("bytes=500-", 1000) + assert start == 500 + assert end == 999 + assert content_length == 500 + assert status_code == 206 + + def test_suffix_range(self): + """bytes=-200 returns last 200 bytes of file.""" + start, end, content_length, status_code = parse_range_header("bytes=-200", 1000) + assert start == 800 + assert end == 999 + assert content_length == 200 + assert status_code == 206 + + def test_suffix_range_larger_than_file(self): + """bytes=-2000 on a 1000-byte file returns entire file from byte 0.""" + start, end, content_length, status_code = parse_range_header( + "bytes=-2000", 1000 + ) + assert start == 0 + assert end == 999 + assert content_length == 1000 + assert status_code == 206 + + def test_multiple_ranges_uses_first(self): + """Multiple ranges are accepted; only the first range is used.""" + start, end, content_length, status_code = parse_range_header( + "bytes=0-99,200-299", 1000 + ) + assert start == 0 + assert end == 99 + assert content_length == 100 + assert status_code == 206 + + def test_non_bytes_unit_raises_416(self): + """Range header not starting with 'bytes=' raises HTTPException 416.""" + with pytest.raises(HTTPException) as exc_info: + parse_range_header("items=0-99", 1000) + assert exc_info.value.status_code == 416 + + def test_missing_dash_raises_416(self): + """Range spec 
without a dash raises HTTPException 416.""" + with pytest.raises(HTTPException) as exc_info: + parse_range_header("bytes=100", 1000) + assert exc_info.value.status_code == 416 + + def test_start_beyond_file_size_raises_416(self): + """start >= file_size raises HTTPException 416.""" + with pytest.raises(HTTPException) as exc_info: + parse_range_header("bytes=1000-1099", 1000) + assert exc_info.value.status_code == 416 + + def test_end_beyond_file_size_clamped_to_last_byte(self): + """end >= file_size is clamped to file_size-1 per RFC 9110 §14.1.2, returns 206.""" + start, end, content_length, status_code = parse_range_header( + "bytes=0-1000", 1000 + ) + assert start == 0 + assert end == 999 + assert content_length == 1000 + assert status_code == 206 + + def test_end_before_start_raises_416(self): + """end < start raises HTTPException 416.""" + with pytest.raises(HTTPException) as exc_info: + parse_range_header("bytes=200-100", 1000) + assert exc_info.value.status_code == 416 + + def test_non_numeric_range_values_raises_416(self): + """Non-integer range values raise HTTPException 416.""" + with pytest.raises(HTTPException) as exc_info: + parse_range_header("bytes=abc-xyz", 1000) + assert exc_info.value.status_code == 416 + assert "Invalid range values" in exc_info.value.detail["details"] + + +class TestCreateFileStream: + """Tests for create_file_stream.""" + + @pytest.mark.asyncio + async def test_streams_full_file(self): + """Streams entire file content when start=0.""" + file_data = b"hello world" + mock_file = AsyncMock() + mock_file.read.side_effect = [file_data, b""] + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + + with patch( + "earth2studio.serve.server.utils.aiofiles.open", return_value=mock_file + ): + chunks = [ + chunk + async for chunk in create_file_stream( + Path("/fake/file.bin"), 0, len(file_data) + ) + ] + + assert b"".join(chunks) == file_data + mock_file.seek.assert_not_called() 
+ + @pytest.mark.asyncio + async def test_seeks_to_start_for_range_request(self): + """Seeks to the start offset for range requests (start > 0).""" + file_data = b"partial content" + mock_file = AsyncMock() + mock_file.read.side_effect = [file_data, b""] + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + + with patch( + "earth2studio.serve.server.utils.aiofiles.open", return_value=mock_file + ): + chunks = [ + chunk + async for chunk in create_file_stream( + Path("/fake/file.bin"), 512, len(file_data) + ) + ] + + mock_file.seek.assert_called_once_with(512) + assert b"".join(chunks) == file_data + + @pytest.mark.asyncio + async def test_streams_in_multiple_chunks(self): + """Yields multiple chunks when file exceeds chunk size.""" + chunk_size = 1048576 # 1MB + chunk1 = b"A" * chunk_size + chunk2 = b"B" * 512 + mock_file = AsyncMock() + mock_file.read.side_effect = [chunk1, chunk2, b""] + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + + content_length = chunk_size + 512 + with patch( + "earth2studio.serve.server.utils.aiofiles.open", return_value=mock_file + ): + chunks = [ + chunk + async for chunk in create_file_stream( + Path("/fake/file.bin"), 0, content_length + ) + ] + + assert len(chunks) == 2 + assert chunks[0] == chunk1 + assert chunks[1] == chunk2 + + @pytest.mark.asyncio + async def test_stops_when_chunk_is_empty(self): + """Stops streaming early if read returns empty bytes before content_length is reached.""" + mock_file = AsyncMock() + mock_file.read.return_value = b"" + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + + with patch( + "earth2studio.serve.server.utils.aiofiles.open", return_value=mock_file + ): + chunks = [ + chunk + async for chunk in create_file_stream(Path("/fake/file.bin"), 0, 1000) + ] + + assert chunks == [] + + @pytest.mark.asyncio + async 
def test_reraises_exception_on_read_error(self): + """Re-raises exceptions encountered during file streaming.""" + mock_file = AsyncMock() + mock_file.read.side_effect = OSError("disk error") + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + + with patch( + "earth2studio.serve.server.utils.aiofiles.open", return_value=mock_file + ): + with pytest.raises(OSError, match="disk error"): + async for _ in create_file_stream(Path("/fake/file.bin"), 0, 1000): + pass diff --git a/test/serve/server/test_workflow.py b/test/serve/server/test_workflow.py index 8050ef9d4..b3097922a 100644 --- a/test/serve/server/test_workflow.py +++ b/test/serve/server/test_workflow.py @@ -1583,6 +1583,114 @@ def test_auto_register_workflows_with_error(self): self.registry.auto_register_workflows(mock_redis) +class TestWorkflowRegistryExposure: + """Tests for WorkflowRegistry.is_workflow_exposed and list_workflows with exposure filtering.""" + + def setup_method(self): + self.registry = WorkflowRegistry() + self.registry.register(Workflow1) + self.registry.register(Workflow2) + self.registry.register(Workflow3) + + def _make_config(self, exposed=None, warmup=None): + mock_config = MagicMock() + mock_config.workflow_exposure.exposed_workflows = ( + exposed if exposed is not None else [] + ) + mock_config.workflow_exposure.warmup_workflows = ( + warmup if warmup is not None else [] + ) + return mock_config + + # --- is_workflow_exposed --- + + def test_is_workflow_exposed_empty_list_exposes_all(self): + """Empty exposed_workflows means all workflows are exposed.""" + mock_config = self._make_config(exposed=[], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + assert self.registry.is_workflow_exposed("workflow1") is True + assert self.registry.is_workflow_exposed("workflow2") is True + assert self.registry.is_workflow_exposed("unknown_wf") is True + + def 
test_is_workflow_exposed_in_exposed_list(self): + """Workflow in exposed_workflows is exposed.""" + mock_config = self._make_config(exposed=["workflow1", "workflow2"], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + assert self.registry.is_workflow_exposed("workflow1") is True + assert self.registry.is_workflow_exposed("workflow2") is True + + def test_is_workflow_exposed_in_warmup_list_only(self): + """Workflow in warmup_workflows (but not exposed_workflows) is still exposed.""" + mock_config = self._make_config(exposed=["workflow1"], warmup=["workflow2"]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + assert self.registry.is_workflow_exposed("workflow2") is True + + def test_is_workflow_exposed_not_in_any_list(self): + """Workflow not in exposed or warmup lists is not exposed.""" + mock_config = self._make_config(exposed=["workflow1"], warmup=["workflow2"]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + assert self.registry.is_workflow_exposed("workflow3") is False + + # --- list_workflows --- + + def test_list_workflows_exposed_only_empty_list_returns_all(self): + """Empty exposed_workflows with exposed_only=True returns all registered workflows.""" + mock_config = self._make_config(exposed=[], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + result = self.registry.list_workflows(exposed_only=True) + assert set(result.keys()) == {"workflow1", "workflow2", "workflow3"} + + def test_list_workflows_exposed_only_filters_to_exposed_list(self): + """exposed_only=True excludes warmup-only workflows from the listing.""" + mock_config = self._make_config( + exposed=["workflow1", "workflow2"], warmup=["workflow3"] + ) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + result = 
self.registry.list_workflows(exposed_only=True) + assert set(result.keys()) == {"workflow1", "workflow2"} + assert "workflow3" not in result + + def test_list_workflows_exposed_only_false_returns_all(self): + """exposed_only=False returns all registered workflows regardless of config.""" + mock_config = self._make_config(exposed=["workflow1"], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + result = self.registry.list_workflows(exposed_only=False) + assert set(result.keys()) == {"workflow1", "workflow2", "workflow3"} + + def test_list_workflows_default_is_exposed_only(self): + """list_workflows() with no args defaults to exposed_only=True.""" + mock_config = self._make_config(exposed=["workflow1"], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + result = self.registry.list_workflows() + assert set(result.keys()) == {"workflow1"} + + def test_list_workflows_returns_descriptions(self): + """list_workflows includes the description for each returned workflow.""" + mock_config = self._make_config(exposed=[], warmup=[]) + with patch( + "earth2studio.serve.server.config.get_config", return_value=mock_config + ): + result = self.registry.list_workflows() + assert result["workflow1"] == "First workflow" + assert result["workflow2"] == "Second workflow" + + # Test helper functions class TestHelperFunctions: """Test helper functions"""