diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ddf7eda..fe7a3ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.3 hooks: - - id: ruff + - id: ruff-check types_or: [python, jupyter] args: ["--fix", "--show-fixes"] - id: ruff-format diff --git a/CHANGELOG.md b/CHANGELOG.md index 2068a75..5359d51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on Keep a Changelog, and the project follows Semantic Versio ## [Unreleased] +### Added + +- `load_export(path)` reader API that loads any export produced by `export_batch(...)` — both combined (single file) and per-item (directory) layouts — and returns a structured `ExportResult`. Failed points are NaN-filled rather than dropped, partial model runs are surfaced via `status="partial"`, and `ExportResult.embedding(model)` provides a typed shortcut to the embedding array. + ### Fixed - `device="auto"` now correctly selects MPS on Apple Silicon instead of silently falling back to CPU. Follows the PyTorch-recommended priority (`cuda > mps > cpu`), giving an approximately 4x speedup on Apple M-series hardware for all API calls that use the default device. diff --git a/docs/api.md b/docs/api.md index 0a5c47b..b06fba4 100644 --- a/docs/api.md +++ b/docs/api.md @@ -7,7 +7,7 @@ If you want installation and first-run examples, start with [Quickstart](quickst ## Core Entry Points -Most users only need four public entry points: `get_embedding(...)`, `get_embeddings_batch(...)`, `export_batch(...)`, and `inspect_provider_patch(...)`. +Most users only need five public entry points: `get_embedding(...)`, `get_embeddings_batch(...)`, `export_batch(...)`, `load_export(...)`, and `inspect_provider_patch(...)`. --- @@ -18,6 +18,7 @@ Most users only need four public entry points: `get_embedding(...)`, `get_embedd | understand spatial/temporal/output specs | [API: Specs and Data Structures](api_specs.md) | | get one embedding or batch embeddings | [API: Embedding](api_embedding.md) | | build export pipelines and datasets | [API: Export](api_export.md) | +| read back a saved export | [API: Load](api_load.md) | | inspect raw provider patches before inference | [API: Inspect](api_inspect.md) | --- diff --git a/docs/api_export.md b/docs/api_export.md index 6c7de2d..32876b5 100644 --- a/docs/api_export.md +++ b/docs/api_export.md @@ -2,7 +2,7 @@ This page covers dataset export APIs. -Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), and [API: Inspect](api_inspect.md). +Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), [API: Load](api_load.md), and [API: Inspect](api_inspect.md). --- diff --git a/docs/api_load.md b/docs/api_load.md new file mode 100644 index 0000000..37d195b --- /dev/null +++ b/docs/api_load.md @@ -0,0 +1,194 @@ +# API: Load + +This page covers the export reader API for loading files produced by [`export_batch`](api_export.md). + +Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), and [API: Export](api_export.md). + +--- + +## load_export (primary / recommended) { #load_export } + +### Signature + +```python +load_export( + path: Union[str, os.PathLike], +) -> ExportResult +``` + +Use `load_export(...)` to read any export produced by [`export_batch`](api_export.md) — both **combined** (single file) and **per-item** (directory) layouts are supported. The layout is detected automatically. + +### Mental Model + +`load_export(...)` answers one question: *where is the export?* + +- Pass a **file** (`.npz`, `.nc`, or `.json`) to load a **combined** export. +- Pass a **directory** to load a **per-item** export. + +Everything else — layout detection, key parsing, NaN-fill for partial failures — is handled automatically. + +### Default Pattern + +```python +from rs_embed import load_export + +# Combined export (single file) +result = load_export("exports/run.npz") + +# Per-item export (directory of p00000.npz, p00001.npz, ...) +result = load_export("exports/per_item_run/") +``` + +--- + +## Parameters + +| Parameter | Meaning | +| --------- | --------------------------------------------------------------------------------- | +| `path` | Path to a `.npz`/`.nc`/`.json` file (combined) or a directory (per-item export). | + +### Raises + +| Exception | When | +| ----------------- | --------------------------------------------------------------------- | +| `FileNotFoundError` | Path does not exist. | +| `ValueError` | Path exists but cannot be interpreted as an rs-embed export. | +| `ImportError` | NetCDF export requested but `xarray` is not installed. | + +--- + +## Return Value: ExportResult { #ExportResult } + +`load_export(...)` always returns an `ExportResult`. + +```python +@dataclass +class ExportResult: + layout: str # "combined" or "per_item" + spatials: list[dict] # one dict per spatial point + temporal: dict | None # temporal spec used at export time + n_items: int # number of spatial points + status: str # "ok" | "partial" | "failed" + models: dict[str, ModelResult] # keyed by model name + manifest: dict # raw manifest for advanced use +``` + +### Convenience Methods + +```python +result.embedding("remoteclip") # → np.ndarray, shape (n_items, dim) +result.ok_models # → list[str] — models with status "ok" +result.failed_models # → list[str] — models with status "failed" +``` + +`embedding(model)` raises `KeyError` if the model was not part of the export and `ValueError` if the model failed for every point. + +--- + +## ModelResult { #ModelResult } + +Each entry in `result.models` is a `ModelResult`: + +```python +@dataclass +class ModelResult: + name: str # canonical model identifier + status: str # "ok" | "partial" | "failed" + embeddings: np.ndarray | None # (n_items, dim) or (n_items, C, H, W) + inputs: np.ndarray | None # (n_items, C, H, W) — None if not saved + meta: list[dict] # per-point embedding metadata + error: str | None # error string for fully-failed models +``` + +**Status values:** + +| Status | Meaning | +| --------- | ------------------------------------------------ | +| `"ok"` | All points succeeded. | +| `"partial"` | Some points succeeded; failed points are NaN-filled in `embeddings`. | +| `"failed"` | All points failed; `embeddings` is `None`. | + +--- + +## Common Patterns + +### Load and inspect a combined export + +```python +from rs_embed import load_export + +result = load_export("exports/combined_run.npz") + +print(result.n_items) # number of spatial points +print(result.ok_models) # models that succeeded +print(result.temporal) # {'start': '2022-06-01', 'end': '2022-09-01'} + +emb = result.embedding("remoteclip") # shape (n_items, dim) +``` + +### Access inputs when save_inputs=True + +```python +result = load_export("exports/combined_run.npz") +mr = result.models["prithvi"] +if mr.inputs is not None: + print(mr.inputs.shape) # (n_items, C, H, W) +``` + +### Load a per-item export directory + +```python +result = load_export("exports/per_item_run/") +print(result.layout) # "per_item" +print(result.n_items) # number of files found + +emb = result.embedding("remoteclip") # (n_items, dim) — stacked from per-file arrays +``` + +### Handle partial failures + +```python +result = load_export("exports/combined_run.npz") + +if result.failed_models: + print("Failed:", result.failed_models) + +for name in result.ok_models: + emb = result.embedding(name) + print(f"{name}: {emb.shape}") +``` + +### Load via the manifest JSON + +Pass the `.json` manifest path if that is what you have — `load_export` finds the paired array file automatically: + +```python +result = load_export("exports/combined_run.json") +``` + +--- + +## Relationship to export_batch + +`load_export` is the read-side counterpart to `export_batch`. Every file produced by `export_batch` can be read back with `load_export` without manually parsing NPZ keys or manifest JSON. + +```python +from rs_embed import export_batch, load_export, ExportTarget, ExportConfig, PointBuffer, TemporalSpec + +# Write +export_batch( + spatials=[PointBuffer(121.5, 31.2, 2048)], + temporal=TemporalSpec.range("2022-06-01", "2022-09-01"), + models=["remoteclip"], + target=ExportTarget.combined("exports/run"), + config=ExportConfig(save_inputs=True), +) + +# Read back +result = load_export("exports/run.npz") +emb = result.embedding("remoteclip") # shape (1, dim) +``` + +!!! tip "Simple rule" + Pass a file path for combined exports, a directory path for per-item exports. + Everything else is automatic. diff --git a/mkdocs.yml b/mkdocs.yml index c93910e..af32a91 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -89,6 +89,7 @@ nav: - Specs & Data Structures: api_specs.md - Embedding API: api_embedding.md - Export API: api_export.md + - Load API: api_load.md - Inspect API: api_inspect.md - Extending: - Overview: extending.md diff --git a/src/rs_embed/__init__.py b/src/rs_embed/__init__.py index 6647d95..11f231a 100644 --- a/src/rs_embed/__init__.py +++ b/src/rs_embed/__init__.py @@ -34,6 +34,7 @@ ) from .export import export_npz from .inspect import inspect_gee_patch, inspect_provider_patch +from .load import ExportResult, ModelResult, load_export from .model import Model from .pipelines.exporter import BatchExporter @@ -64,6 +65,10 @@ # Export API "export_batch", "export_npz", + # Load API + "load_export", + "ExportResult", + "ModelResult", # Inspection "inspect_provider_patch", # Backward-compatible alias for inspect_provider_patch diff --git a/src/rs_embed/api.py b/src/rs_embed/api.py index 58138d2..0aa37c5 100644 --- a/src/rs_embed/api.py +++ b/src/rs_embed/api.py @@ -70,12 +70,6 @@ from .tools.model_defaults import ( resolve_sensor_for_model as _resolve_sensor_for_model, ) -from .tools.normalization import ( - # Re-exported so `from rs_embed.api import ...` in tests/downstream still works. - _default_provider_backend_for_api, # noqa: F401 - _probe_model_describe, # noqa: F401 - _resolve_embedding_api_backend, # noqa: F401 -) from .tools.normalization import ( normalize_backend_name as _normalize_backend_name, ) diff --git a/src/rs_embed/load.py b/src/rs_embed/load.py new file mode 100644 index 0000000..52d7dee --- /dev/null +++ b/src/rs_embed/load.py @@ -0,0 +1,512 @@ +"""Utilities for loading exports produced by :func:`rs_embed.export_batch`. + +Two export layouts are supported: + +* **combined** — a single ``.npz``/``.nc`` file (plus a ``.json`` manifest) + containing all spatial points and models in one array file. +* **per_item** — a directory of ``p.npz``/``.nc`` and ``.json`` files, + one pair per spatial point. + +The public entry point is :func:`load_export`, which auto-detects the layout +from the path argument and returns an :class:`ExportResult`. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from .tools.serialization import sanitize_key + +# --------------------------------------------------------------------------- +# Public data types +# --------------------------------------------------------------------------- + +_EMBEDDING_KEY_SINGLE = "embedding__{model}" +_EMBEDDING_KEY_BATCH = "embeddings__{model}" +_INPUT_KEY_SINGLE = "input_chw__{model}" +_INPUT_KEY_BATCH = "inputs_bchw__{model}" + + +@dataclass +class ModelResult: + """Loaded embeddings and metadata for one model across all spatial points. + + Attributes + ---------- + name : str + Canonical model identifier (e.g. ``"remoteclip"``). + status : str + Aggregate status: ``"ok"``, ``"partial"``, or ``"failed"``. + embeddings : np.ndarray or None + Float32 array of shape ``(n_items, dim)`` for pooled output, or + ``(n_items, C, H, W)`` for grid output. ``None`` when the model + failed for every point. Individual failed points within a + partially-succeeded run are filled with ``NaN``. + inputs : np.ndarray or None + Raw input patches if they were saved during export + (``ExportConfig(save_inputs=True)``). Shape + ``(n_items, C, H, W)``. ``None`` when inputs were not saved. + meta : list[dict] + Per-point embedding metadata dicts (one entry per spatial point). + Empty dicts are used for failed points. + error : str or None + Error string if *every* point for this model failed, else ``None``. + """ + + name: str + status: str + embeddings: np.ndarray | None + inputs: np.ndarray | None + meta: list[dict[str, Any]] + error: str | None = None + + +@dataclass +class ExportResult: + """Loaded export from :func:`rs_embed.export_batch`. + + Attributes + ---------- + layout : str + ``"combined"`` or ``"per_item"``. + spatials : list[dict] + Spatial specs (one per point), as JSON-serializable dicts matching + the original :class:`~rs_embed.BBox` / :class:`~rs_embed.PointBuffer` + fields. + temporal : dict or None + Temporal spec used during export, or ``None`` if not specified. + n_items : int + Number of spatial points. + status : str + Overall export status: ``"ok"``, ``"partial"``, or ``"failed"``. + models : dict[str, ModelResult] + Loaded results keyed by model name. + manifest : dict + Raw manifest dict for advanced inspection. + """ + + layout: str + spatials: list[dict[str, Any]] + temporal: dict[str, Any] | None + n_items: int + status: str + models: dict[str, ModelResult] + manifest: dict[str, Any] + + # ------------------------------------------------------------------ + # Convenience accessors + # ------------------------------------------------------------------ + + def embedding(self, model: str) -> np.ndarray: + """Return the embedding array for *model*. + + Parameters + ---------- + model : str + Model name as it appears in :attr:`models`. + + Returns + ------- + np.ndarray + Shape ``(n_items, dim)`` for pooled, ``(n_items, C, H, W)`` + for grid. + + Raises + ------ + KeyError + If *model* was not part of this export. + ValueError + If *model* failed for every point (``status == "failed"``). + """ + if model not in self.models: + available = sorted(self.models) + raise KeyError(f"Model {model!r} not found in export. Available models: {available}") + result = self.models[model] + if result.embeddings is None: + raise ValueError( + f"Model {model!r} has no embeddings (status={result.status!r}). " + f"Error: {result.error}" + ) + return result.embeddings + + @property + def ok_models(self) -> list[str]: + """Model names whose status is ``"ok"``.""" + return [name for name, r in self.models.items() if r.status == "ok"] + + @property + def failed_models(self) -> list[str]: + """Model names whose status is ``"failed"``.""" + return [name for name, r in self.models.items() if r.status == "failed"] + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def load_export(path: str | os.PathLike[str]) -> ExportResult: + """Load an export produced by :func:`rs_embed.export_batch`. + + Auto-detects the export layout: + + * Pass a **file path** (``.npz``, ``.nc``, or ``.json``) to load a + **combined** export (all spatial points in one file). + * Pass a **directory path** to load a **per-item** export (one file + pair per spatial point). + + Parameters + ---------- + path : str or path-like + Path to the export file or directory. + + Returns + ------- + ExportResult + Loaded embeddings, inputs, metadata, and spatial information. + + Raises + ------ + FileNotFoundError + If *path* does not exist. + ValueError + If the path exists but cannot be interpreted as an rs-embed export. + + Examples + -------- + Load a combined export: + + >>> result = load_export("embeddings.npz") + >>> result.ok_models + ['prithvi', 'remoteclip'] + >>> emb = result.embedding("prithvi") # shape (n_items, dim) + + Load a per-item export directory: + + >>> result = load_export("embeddings/") + >>> result.n_items + 50 + """ + path = os.fspath(path) + if not os.path.exists(path): + raise FileNotFoundError(f"Export path does not exist: {path!r}") + if os.path.isdir(path): + return _load_per_item(path) + return _load_combined(path) + + +# --------------------------------------------------------------------------- +# Combined layout +# --------------------------------------------------------------------------- + + +def _resolve_combined_paths(path: str) -> tuple[str, str, str]: + """Return ``(arrays_path, json_path, fmt)`` for a combined export path. + + Accepts a ``.npz``, ``.nc``, or ``.json`` path and finds the paired files. + """ + base, ext = os.path.splitext(path) + ext = ext.lower() + + if ext == ".json": + # Find the paired array file + for candidate_ext, fmt in ((".npz", "npz"), (".nc", "netcdf")): + arrays_path = base + candidate_ext + if os.path.exists(arrays_path): + return arrays_path, path, fmt + raise FileNotFoundError( + f"Found manifest {path!r} but no paired .npz or .nc file at {base!r}." + ) + + if ext == ".npz": + fmt = "npz" + elif ext in (".nc", ".netcdf"): + fmt = "netcdf" + else: + raise ValueError( + f"Unrecognised file extension {ext!r}. " + "Expected .npz, .nc, or .json for a combined export." + ) + + json_path = base + ".json" + if not os.path.exists(json_path): + raise FileNotFoundError( + f"Array file {path!r} found but no paired manifest at {json_path!r}." + ) + return path, json_path, fmt + + +def _load_arrays(arrays_path: str, fmt: str) -> dict[str, np.ndarray]: + """Load all arrays from an npz or netcdf file.""" + if fmt == "npz": + with np.load(arrays_path, allow_pickle=False) as payload: + return {str(k): np.asarray(payload[k]) for k in payload.files} + if fmt == "netcdf": + try: + import xarray as xr + except ImportError as exc: + raise ImportError( + "xarray is required to load NetCDF exports. Install with: pip install xarray" + ) from exc + ds = xr.open_dataset(arrays_path) + try: + return {str(k): np.asarray(ds[k].values) for k in ds.data_vars} + finally: + ds.close() + raise ValueError(f"Unknown format {fmt!r}.") # pragma: no cover + + +def _load_json(json_path: str) -> dict[str, Any]: + """Load and validate a JSON manifest.""" + try: + with open(json_path, encoding="utf-8") as f: + data = json.load(f) + except (OSError, json.JSONDecodeError) as exc: + raise ValueError(f"Failed to load manifest {json_path!r}: {exc}") from exc + if not isinstance(data, dict): + raise ValueError(f"Manifest {json_path!r} is not a JSON object.") + return data + + +def _build_combined_model_result( + entry: dict[str, Any], + arrays: dict[str, np.ndarray], + n_items: int, +) -> ModelResult: + """Build a ModelResult from one entry in a combined manifest's models list.""" + name = str(entry.get("model", "")) + status = str(entry.get("status", "failed")) + error = entry.get("error") or None + + key = sanitize_key(name) + + # Embeddings + emb_meta = entry.get("embeddings") or {} + emb_key = emb_meta.get("npz_key") or _EMBEDDING_KEY_BATCH.format(model=key) + embeddings = np.asarray(arrays[emb_key], dtype=np.float32) if emb_key in arrays else None + + # Inputs (optional) + inp_meta = entry.get("inputs") or {} + inp_key = inp_meta.get("npz_key") if isinstance(inp_meta, dict) else None + inp_key = inp_key or _INPUT_KEY_BATCH.format(model=key) + inputs = np.asarray(arrays[inp_key], dtype=np.float32) if inp_key in arrays else None + + # Per-point metadata list + raw_metas = entry.get("metas") or [] + meta: list[dict[str, Any]] = [] + for i in range(n_items): + m = raw_metas[i] if i < len(raw_metas) else {} + meta.append(m if isinstance(m, dict) else {}) + + return ModelResult( + name=name, + status=status, + embeddings=embeddings, + inputs=inputs, + meta=meta, + error=str(error) if error is not None else None, + ) + + +def _load_combined(path: str) -> ExportResult: + arrays_path, json_path, fmt = _resolve_combined_paths(path) + arrays = _load_arrays(arrays_path, fmt) + manifest = _load_json(json_path) + + n_items = int(manifest.get("n_items", 0)) + spatials: list[dict[str, Any]] = list(manifest.get("spatials") or []) + temporal = manifest.get("temporal") or None + status = str(manifest.get("status", "ok")) + + model_entries: list[dict[str, Any]] = list(manifest.get("models") or []) + models: dict[str, ModelResult] = {} + for entry in model_entries: + result = _build_combined_model_result(entry, arrays, n_items) + models[result.name] = result + + return ExportResult( + layout="combined", + spatials=spatials, + temporal=temporal, + n_items=n_items, + status=status, + models=models, + manifest=manifest, + ) + + +# --------------------------------------------------------------------------- +# Per-item layout +# --------------------------------------------------------------------------- + + +def _find_per_item_files(directory: str) -> list[tuple[str, str, str]]: + """Return sorted ``(arrays_path, json_path, fmt)`` tuples for a per-item dir. + + Files must match the pattern ``p.(npz|nc)`` with a paired ``.json``. + """ + entries: list[tuple[int, str, str, str]] = [] + for fname in os.listdir(directory): + base, ext = os.path.splitext(fname) + ext = ext.lower() + if not (base.startswith("p") and base[1:].isdigit()): + continue + if ext not in (".npz", ".nc", ".netcdf"): + continue + fmt = "npz" if ext == ".npz" else "netcdf" + json_path = os.path.join(directory, base + ".json") + if not os.path.exists(json_path): + continue + arrays_path = os.path.join(directory, fname) + index = int(base[1:]) + entries.append((index, arrays_path, json_path, fmt)) + + if not entries: + raise ValueError( + f"No per-item export files found in {directory!r}. " + "Expected files matching 'p.npz' or 'p.nc' " + "with paired '.json' manifests." + ) + + entries.sort(key=lambda t: t[0]) + return [(arrays_path, json_path, fmt) for _, arrays_path, json_path, fmt in entries] + + +def _stack_per_item_embeddings( + per_point: list[np.ndarray | None], +) -> np.ndarray | None: + """Stack per-point embedding arrays into ``(n_items, ...)``. + + Failed points (``None``) are filled with NaN using the shape inferred + from the first successful array. + """ + sample = next((a for a in per_point if a is not None), None) + if sample is None: + return None + fill = np.full(sample.shape, fill_value=np.nan, dtype=np.float32) + rows = [a.astype(np.float32) if a is not None else fill for a in per_point] + return np.stack(rows, axis=0) + + +def _load_per_item(directory: str) -> ExportResult: + files = _find_per_item_files(directory) + n_items = len(files) + + spatials: list[dict[str, Any]] = [] + temporal: dict[str, Any] | None = None + overall_status: str | None = None + + # Collect per-point data keyed by model name + per_model_embeddings: dict[str, list[np.ndarray | None]] = {} + per_model_inputs: dict[str, list[np.ndarray | None]] = {} + per_model_meta: dict[str, list[dict[str, Any]]] = {} + per_model_status: dict[str, list[str]] = {} + per_model_error: dict[str, str | None] = {} + + manifest_first: dict[str, Any] = {} + + for i, (arrays_path, json_path, fmt) in enumerate(files): + arrays = _load_arrays(arrays_path, fmt) + manifest = _load_json(json_path) + + if i == 0: + manifest_first = manifest + temporal = manifest.get("temporal") or None + + spatial = manifest.get("spatial") + spatials.append(spatial if isinstance(spatial, dict) else {}) + + for entry in manifest.get("models") or []: + name = str(entry.get("model", "")) + if not name: + continue + + key = sanitize_key(name) + entry_status = str(entry.get("status", "failed")) + + # Embeddings + emb_meta = entry.get("embedding") or {} + emb_key = ( + emb_meta.get("npz_key") if isinstance(emb_meta, dict) else None + ) or _EMBEDDING_KEY_SINGLE.format(model=key) + emb_arr = arrays.get(emb_key) + + # Inputs + inp_meta = entry.get("input") or {} + inp_key = ( + inp_meta.get("npz_key") if isinstance(inp_meta, dict) else None + ) or _INPUT_KEY_SINGLE.format(model=key) + inp_arr = arrays.get(inp_key) + + per_model_embeddings.setdefault(name, [None] * n_items) + per_model_inputs.setdefault(name, [None] * n_items) + per_model_meta.setdefault(name, [{}] * n_items) + per_model_status.setdefault(name, []) + per_model_error.setdefault(name, None) + + per_model_embeddings[name][i] = ( + np.asarray(emb_arr, dtype=np.float32) if emb_arr is not None else None + ) + per_model_inputs[name][i] = ( + np.asarray(inp_arr, dtype=np.float32) if inp_arr is not None else None + ) + raw_meta = entry.get("meta") + per_model_meta[name][i] = raw_meta if isinstance(raw_meta, dict) else {} + per_model_status[name].append(entry_status) + + if entry_status == "failed" and per_model_error[name] is None: + per_model_error[name] = entry.get("error") or None + + # Build per-model results + models: dict[str, ModelResult] = {} + for name in per_model_embeddings: + emb_arrays = per_model_embeddings[name] + inp_arrays = per_model_inputs[name] + statuses = per_model_status[name] + + embeddings = _stack_per_item_embeddings(emb_arrays) + inputs = _stack_per_item_embeddings(inp_arrays) + + n_ok = sum(1 for s in statuses if s == "ok") + n_failed = sum(1 for s in statuses if s == "failed") + if n_failed == 0: + agg_status = "ok" + elif n_ok == 0: + agg_status = "failed" + else: + agg_status = "partial" + + models[name] = ModelResult( + name=name, + status=agg_status, + embeddings=embeddings if agg_status != "failed" else None, + inputs=inputs if any(a is not None for a in inp_arrays) else None, + meta=per_model_meta[name], + error=str(per_model_error[name]) if agg_status == "failed" else None, + ) + + # Aggregate overall status + all_statuses = [r.status for r in models.values()] + n_ok_models = sum(1 for s in all_statuses if s == "ok") + n_failed_models = sum(1 for s in all_statuses if s == "failed") + if n_failed_models == 0: + overall_status = "ok" if all_statuses else "ok" + elif n_ok_models == 0: + overall_status = "failed" + else: + overall_status = "partial" + + return ExportResult( + layout="per_item", + spatials=spatials, + temporal=temporal, + n_items=n_items, + status=overall_status, + models=models, + manifest=manifest_first, + ) diff --git a/tests/test_backend_resolution.py b/tests/test_backend_resolution.py index c5ade98..379958a 100644 --- a/tests/test_backend_resolution.py +++ b/tests/test_backend_resolution.py @@ -8,7 +8,7 @@ from unittest.mock import patch -from rs_embed.api import ( +from rs_embed.tools.normalization import ( _default_provider_backend_for_api, _resolve_embedding_api_backend, ) diff --git a/tests/test_device_resolution.py b/tests/test_device_resolution.py index 487d7b3..8c76de5 100644 --- a/tests/test_device_resolution.py +++ b/tests/test_device_resolution.py @@ -12,7 +12,6 @@ from rs_embed.embedders.runtime_utils import resolve_device_auto_torch - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -121,9 +120,7 @@ def test_auto_does_not_return_cpu_on_accelerated_hardware(self): pytest.skip("No GPU/MPS available on this machine") result = resolve_device_auto_torch("auto") - assert result != "cpu", ( - f"Expected cuda or mps on accelerated hardware, got {result!r}" - ) + assert result != "cpu", f"Expected cuda or mps on accelerated hardware, got {result!r}" def test_auto_returns_mps_on_apple_silicon(self): """On Apple Silicon with MPS available, 'auto' must return 'mps' (not 'cpu').""" diff --git a/tests/test_load_export.py b/tests/test_load_export.py new file mode 100644 index 0000000..b269927 --- /dev/null +++ b/tests/test_load_export.py @@ -0,0 +1,651 @@ +"""Tests for rs_embed.load — load_export() reader API.""" + +from __future__ import annotations + +import json +import os + +import numpy as np +import pytest + +from rs_embed.load import ( + ModelResult, + _build_combined_model_result, + _find_per_item_files, + _load_arrays, + _resolve_combined_paths, + _stack_per_item_embeddings, + load_export, +) + +# ══════════════════════════════════════════════════════════════════════ +# Fixtures / helpers +# ══════════════════════════════════════════════════════════════════════ + +N_ITEMS = 3 +N_DIM = 16 +N_BANDS = 4 +IMG_SIZE = 8 + + +def _make_embeddings(n: int = N_ITEMS, dim: int = N_DIM) -> np.ndarray: + return np.random.rand(n, dim).astype(np.float32) + + +def _make_inputs(n: int = N_ITEMS, c: int = N_BANDS, h: int = IMG_SIZE) -> np.ndarray: + return np.random.rand(n, c, h, h).astype(np.float32) + + +def _combined_npz( + directory: str, + stem: str = "run", + *, + model: str = "remoteclip", + save_inputs: bool = False, + status: str = "ok", +) -> tuple[str, str]: + """Write a minimal combined .npz + .json and return (npz_path, json_path).""" + from rs_embed.tools.serialization import sanitize_key + + key = sanitize_key(model) + embs = _make_embeddings() + arrays: dict[str, np.ndarray] = {f"embeddings__{key}": embs} + if save_inputs: + arrays[f"inputs_bchw__{key}"] = _make_inputs() + + npz_path = os.path.join(directory, f"{stem}.npz") + np.savez(npz_path, **arrays) + + metas = [{"sha1": f"abc{i}"} for i in range(N_ITEMS)] + manifest = { + "n_items": N_ITEMS, + "status": status, + "spatials": [{"type": "PointBuffer", "lon": 121.5 + i, "lat": 31.2} for i in range(N_ITEMS)], + "temporal": {"start": "2022-06-01", "end": "2022-09-01"}, + "models": [ + { + "model": model, + "status": status, + "embeddings": {"npz_key": f"embeddings__{key}"}, + **({"inputs": {"npz_key": f"inputs_bchw__{key}"}} if save_inputs else {}), + "metas": metas, + } + ], + } + json_path = os.path.join(directory, f"{stem}.json") + with open(json_path, "w") as f: + json.dump(manifest, f) + + return npz_path, json_path + + +def _per_item_dir( + directory: str, + n: int = N_ITEMS, + model: str = "remoteclip", + save_inputs: bool = False, +) -> str: + """Write n per-item .npz + .json files and return the directory.""" + from rs_embed.tools.serialization import sanitize_key + + key = sanitize_key(model) + for i in range(n): + emb = np.random.rand(N_DIM).astype(np.float32) + arrays: dict[str, np.ndarray] = {f"embedding__{key}": emb} + if save_inputs: + arrays[f"input_chw__{key}"] = np.random.rand(N_BANDS, IMG_SIZE, IMG_SIZE).astype(np.float32) + + npz_path = os.path.join(directory, f"p{i:05d}.npz") + np.savez(npz_path, **arrays) + + manifest = { + "spatial": {"type": "PointBuffer", "lon": 121.5 + i, "lat": 31.2}, + "temporal": {"start": "2022-06-01", "end": "2022-09-01"}, + "models": [ + { + "model": model, + "status": "ok", + "embedding": {"npz_key": f"embedding__{key}"}, + **({"input": {"npz_key": f"input_chw__{key}"}} if save_inputs else {}), + "meta": {"sha1": f"abc{i}"}, + } + ], + } + json_path = os.path.join(directory, f"p{i:05d}.json") + with open(json_path, "w") as f: + json.dump(manifest, f) + + return directory + + +# ══════════════════════════════════════════════════════════════════════ +# _resolve_combined_paths +# ══════════════════════════════════════════════════════════════════════ + + +class TestResolveCombinedPaths: + def test_npz_resolves(self, tmp_path): + npz = tmp_path / "run.npz" + npz.touch() + js = tmp_path / "run.json" + js.touch() + arrays_p, json_p, fmt = _resolve_combined_paths(str(npz)) + assert arrays_p == str(npz) + assert json_p == str(js) + assert fmt == "npz" + + def test_nc_resolves(self, tmp_path): + nc = tmp_path / "run.nc" + nc.touch() + js = tmp_path / "run.json" + js.touch() + arrays_p, json_p, fmt = _resolve_combined_paths(str(nc)) + assert fmt == "netcdf" + + def test_json_finds_npz(self, tmp_path): + js = tmp_path / "run.json" + js.touch() + npz = tmp_path / "run.npz" + npz.touch() + arrays_p, json_p, fmt = _resolve_combined_paths(str(js)) + assert arrays_p == str(npz) + assert fmt == "npz" + + def test_json_finds_nc_when_no_npz(self, tmp_path): + js = tmp_path / "run.json" + js.touch() + nc = tmp_path / "run.nc" + nc.touch() + arrays_p, json_p, fmt = _resolve_combined_paths(str(js)) + assert fmt == "netcdf" + + def test_json_missing_array_raises(self, tmp_path): + js = tmp_path / "run.json" + js.touch() + with pytest.raises(FileNotFoundError, match="paired"): + _resolve_combined_paths(str(js)) + + def test_npz_missing_json_raises(self, tmp_path): + npz = tmp_path / "run.npz" + npz.touch() + with pytest.raises(FileNotFoundError, match="manifest"): + _resolve_combined_paths(str(npz)) + + def test_unknown_extension_raises(self, tmp_path): + f = tmp_path / "run.csv" + f.touch() + with pytest.raises(ValueError, match="extension"): + _resolve_combined_paths(str(f)) + + +# ══════════════════════════════════════════════════════════════════════ +# _load_arrays +# ══════════════════════════════════════════════════════════════════════ + + +class TestLoadArrays: + def test_npz_round_trip(self, tmp_path): + arr = np.array([1.0, 2.0], dtype=np.float32) + path = str(tmp_path / "data.npz") + np.savez(path, foo=arr) + loaded = _load_arrays(path, "npz") + assert "foo" in loaded + np.testing.assert_array_equal(loaded["foo"], arr) + + def test_unknown_format_raises(self, tmp_path): + with pytest.raises(ValueError, match="format"): + _load_arrays(str(tmp_path / "x.bin"), "bin") + + +# ══════════════════════════════════════════════════════════════════════ +# _stack_per_item_embeddings +# ══════════════════════════════════════════════════════════════════════ + + +class TestStackPerItemEmbeddings: + def test_all_present(self): + arrs = [np.ones((4,), dtype=np.float32) * i for i in range(3)] + stacked = _stack_per_item_embeddings(arrs) + assert stacked.shape == (3, 4) + np.testing.assert_array_equal(stacked[1], np.ones(4) * 1) + + def test_all_none_returns_none(self): + assert _stack_per_item_embeddings([None, None]) is None + + def test_nan_fill_for_failed_point(self): + arrs = [np.ones((4,), dtype=np.float32), None, np.ones((4,), dtype=np.float32) * 2] + stacked = _stack_per_item_embeddings(arrs) + assert stacked.shape == (3, 4) + assert np.all(np.isnan(stacked[1])) + np.testing.assert_array_almost_equal(stacked[2], np.full(4, 2.0)) + + def test_output_dtype_is_float32(self): + arrs = [np.ones((8,), dtype=np.float64)] + stacked = _stack_per_item_embeddings(arrs) + assert stacked.dtype == np.float32 + + +# ══════════════════════════════════════════════════════════════════════ +# _build_combined_model_result +# ══════════════════════════════════════════════════════════════════════ + + +class TestBuildCombinedModelResult: + def _make_entry_and_arrays(self, model: str = "remoteclip", status: str = "ok"): + from rs_embed.tools.serialization import sanitize_key + + key = sanitize_key(model) + embs = _make_embeddings() + arrays = {f"embeddings__{key}": embs} + entry = { + "model": model, + "status": status, + "embeddings": {"npz_key": f"embeddings__{key}"}, + "metas": [{"sha1": f"x{i}"} for i in range(N_ITEMS)], + } + return entry, arrays + + def test_basic(self): + entry, arrays = self._make_entry_and_arrays() + result = _build_combined_model_result(entry, arrays, N_ITEMS) + assert result.name == "remoteclip" + assert result.status == "ok" + assert result.embeddings is not None + assert result.embeddings.shape == (N_ITEMS, N_DIM) + assert result.inputs is None + assert len(result.meta) == N_ITEMS + + def test_inputs_loaded_when_present(self): + from rs_embed.tools.serialization import sanitize_key + + model = "remoteclip" + key = sanitize_key(model) + embs = _make_embeddings() + inps = _make_inputs() + arrays = {f"embeddings__{key}": embs, f"inputs_bchw__{key}": inps} + entry = { + "model": model, + "status": "ok", + "embeddings": {"npz_key": f"embeddings__{key}"}, + "inputs": {"npz_key": f"inputs_bchw__{key}"}, + "metas": [{} for _ in range(N_ITEMS)], + } + result = _build_combined_model_result(entry, arrays, N_ITEMS) + assert result.inputs is not None + assert result.inputs.shape == (N_ITEMS, N_BANDS, IMG_SIZE, IMG_SIZE) + + def test_failed_status_propagated(self): + entry, arrays = self._make_entry_and_arrays(status="failed") + # Remove embeddings to simulate failed run + result = _build_combined_model_result(entry, {}, N_ITEMS) + assert result.status == "failed" + assert result.embeddings is None + + def test_meta_padded_when_short(self): + from rs_embed.tools.serialization import sanitize_key + + model = "prithvi" + key = sanitize_key(model) + embs = _make_embeddings() + arrays = {f"embeddings__{key}": embs} + entry = { + "model": model, + "status": "ok", + "embeddings": {"npz_key": f"embeddings__{key}"}, + "metas": [{"sha1": "abc"}], # only 1, but n_items=3 + } + result = _build_combined_model_result(entry, arrays, N_ITEMS) + assert len(result.meta) == N_ITEMS + assert result.meta[1] == {} + + +# ══════════════════════════════════════════════════════════════════════ +# _find_per_item_files +# ══════════════════════════════════════════════════════════════════════ + + +class TestFindPerItemFiles: + def test_finds_and_sorts(self, tmp_path): + d = str(tmp_path) + for i in [2, 0, 1]: + (tmp_path / f"p{i:05d}.npz").touch() + (tmp_path / f"p{i:05d}.json").touch() + files = _find_per_item_files(d) + indices = [int(os.path.basename(a).lstrip("p").split(".")[0]) for a, _, _ in files] + assert indices == [0, 1, 2] + + def test_skips_orphaned_npz(self, tmp_path): + (tmp_path / "p00000.npz").touch() + # no p00000.json + with pytest.raises(ValueError, match="No per-item"): + _find_per_item_files(str(tmp_path)) + + def test_empty_directory_raises(self, tmp_path): + with pytest.raises(ValueError, match="No per-item"): + _find_per_item_files(str(tmp_path)) + + def test_ignores_unrelated_files(self, tmp_path): + (tmp_path / "p00000.npz").touch() + (tmp_path / "p00000.json").touch() + (tmp_path / "README.md").touch() + (tmp_path / "other.npz").touch() + files = _find_per_item_files(str(tmp_path)) + assert len(files) == 1 + + +# ══════════════════════════════════════════════════════════════════════ +# load_export — path routing +# ══════════════════════════════════════════════════════════════════════ + + +class TestLoadExportRouting: + def test_missing_path_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + load_export(str(tmp_path / "nonexistent.npz")) + + def test_directory_routes_to_per_item(self, tmp_path): + _per_item_dir(str(tmp_path)) + result = load_export(str(tmp_path)) + assert result.layout == "per_item" + + def test_npz_file_routes_to_combined(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + assert result.layout == "combined" + + def test_json_file_routes_to_combined(self, tmp_path): + _, json_path = _combined_npz(str(tmp_path)) + result = load_export(json_path) + assert result.layout == "combined" + + +# ══════════════════════════════════════════════════════════════════════ +# load_export — combined layout +# ══════════════════════════════════════════════════════════════════════ + + +class TestLoadCombined: + def test_basic_combined(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + assert result.layout == "combined" + assert result.n_items == N_ITEMS + assert result.status == "ok" + assert len(result.spatials) == N_ITEMS + assert result.temporal is not None + + def test_models_loaded(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + assert "remoteclip" in result.models + mr = result.models["remoteclip"] + assert mr.status == "ok" + assert mr.embeddings.shape == (N_ITEMS, N_DIM) + + def test_ok_models_property(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + assert result.ok_models == ["remoteclip"] + assert result.failed_models == [] + + def test_embedding_accessor(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + emb = result.embedding("remoteclip") + assert emb.shape == (N_ITEMS, N_DIM) + + def test_embedding_unknown_model_raises_keyerror(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + with pytest.raises(KeyError, match="prithvi"): + result.embedding("prithvi") + + def test_inputs_loaded_when_saved(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path), save_inputs=True) + result = load_export(npz_path) + mr = result.models["remoteclip"] + assert mr.inputs is not None + assert mr.inputs.shape == (N_ITEMS, N_BANDS, IMG_SIZE, IMG_SIZE) + + def test_inputs_none_when_not_saved(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path), save_inputs=False) + result = load_export(npz_path) + assert result.models["remoteclip"].inputs is None + + def test_manifest_preserved(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + assert "n_items" in result.manifest + + def test_two_models(self, tmp_path): + from rs_embed.tools.serialization import sanitize_key + + models = ["remoteclip", "prithvi"] + arrays: dict[str, np.ndarray] = {} + model_entries = [] + for m in models: + key = sanitize_key(m) + embs = _make_embeddings() + arrays[f"embeddings__{key}"] = embs + model_entries.append( + { + "model": m, + "status": "ok", + "embeddings": {"npz_key": f"embeddings__{key}"}, + "metas": [{} for _ in range(N_ITEMS)], + } + ) + + npz_path = str(tmp_path / "run.npz") + np.savez(npz_path, **arrays) + manifest = { + "n_items": N_ITEMS, + "status": "ok", + "spatials": [{} for _ in range(N_ITEMS)], + "temporal": None, + "models": model_entries, + } + with open(str(tmp_path / "run.json"), "w") as f: + json.dump(manifest, f) + + result = load_export(npz_path) + assert set(result.models) == {"remoteclip", "prithvi"} + assert sorted(result.ok_models) == ["prithvi", "remoteclip"] + + +# ══════════════════════════════════════════════════════════════════════ +# load_export — per-item layout +# ══════════════════════════════════════════════════════════════════════ + + +class TestLoadPerItem: + def test_basic_per_item(self, tmp_path): + _per_item_dir(str(tmp_path)) + result = load_export(str(tmp_path)) + assert result.layout == "per_item" + assert result.n_items == N_ITEMS + assert result.status == "ok" + + def test_embeddings_stacked(self, tmp_path): + _per_item_dir(str(tmp_path)) + result = load_export(str(tmp_path)) + mr = result.models["remoteclip"] + assert mr.embeddings.shape == (N_ITEMS, N_DIM) + + def test_inputs_loaded_when_saved(self, tmp_path): + _per_item_dir(str(tmp_path), save_inputs=True) + result = load_export(str(tmp_path)) + mr = result.models["remoteclip"] + assert mr.inputs is not None + assert mr.inputs.shape == (N_ITEMS, N_BANDS, IMG_SIZE, IMG_SIZE) + + def test_spatials_collected(self, tmp_path): + _per_item_dir(str(tmp_path)) + result = load_export(str(tmp_path)) + assert len(result.spatials) == N_ITEMS + assert result.spatials[0]["type"] == "PointBuffer" + + def test_temporal_from_first_item(self, tmp_path): + _per_item_dir(str(tmp_path)) + result = load_export(str(tmp_path)) + assert result.temporal is not None + assert result.temporal["start"] == "2022-06-01" + + def test_partial_failure_nan_fill(self, tmp_path): + """Failed middle point is NaN-filled in the stacked array.""" + from rs_embed.tools.serialization import sanitize_key + + model = "remoteclip" + key = sanitize_key(model) + d = str(tmp_path) + + for i in range(N_ITEMS): + is_failed = i == 1 + emb = np.ones(N_DIM, dtype=np.float32) * i + arrays: dict[str, np.ndarray] = {} + if not is_failed: + arrays[f"embedding__{key}"] = emb + npz_path = os.path.join(d, f"p{i:05d}.npz") + np.savez(npz_path, **arrays) + + manifest = { + "spatial": {"type": "PointBuffer", "lon": 121.5 + i, "lat": 31.2}, + "temporal": None, + "models": [ + { + "model": model, + "status": "failed" if is_failed else "ok", + "embedding": {"npz_key": f"embedding__{key}"}, + "meta": {}, + } + ], + } + json_path = os.path.join(d, f"p{i:05d}.json") + with open(json_path, "w") as f: + json.dump(manifest, f) + + result = load_export(d) + mr = result.models[model] + assert mr.status == "partial" + assert mr.embeddings is not None + assert np.all(np.isnan(mr.embeddings[1])) + np.testing.assert_array_equal(mr.embeddings[0], np.zeros(N_DIM)) + + def test_all_failed_model_has_none_embeddings(self, tmp_path): + from rs_embed.tools.serialization import sanitize_key + + model = "remoteclip" + key = sanitize_key(model) + d = str(tmp_path) + + for i in range(N_ITEMS): + npz_path = os.path.join(d, f"p{i:05d}.npz") + np.savez(npz_path) # no embedding arrays + manifest = { + "spatial": {}, + "temporal": None, + "models": [{"model": model, "status": "failed", "embedding": {}, "meta": {}}], + } + json_path = os.path.join(d, f"p{i:05d}.json") + with open(json_path, "w") as f: + json.dump(manifest, f) + + result = load_export(d) + mr = result.models[model] + assert mr.status == "failed" + assert mr.embeddings is None + + def test_ok_and_failed_models_properties(self, tmp_path): + from rs_embed.tools.serialization import sanitize_key + + d = str(tmp_path) + for model, ok in [("remoteclip", True), ("prithvi", False)]: + key = sanitize_key(model) + for i in range(N_ITEMS): + arrays: dict[str, np.ndarray] = {} + if ok: + arrays[f"embedding__{key}"] = np.ones(N_DIM, dtype=np.float32) + npz_path = os.path.join(d, f"p{i:05d}.npz") + # Need to merge arrays across models + if os.path.exists(npz_path): + with np.load(npz_path) as existing: + existing_arrays = dict(existing) + arrays.update(existing_arrays) + np.savez(npz_path, **arrays) + + json_path = os.path.join(d, f"p{i:05d}.json") + if os.path.exists(json_path): + with open(json_path) as f: + manifest = json.load(f) + manifest["models"].append( + { + "model": model, + "status": "ok" if ok else "failed", + "embedding": {"npz_key": f"embedding__{key}"}, + "meta": {}, + } + ) + else: + manifest = { + "spatial": {}, + "temporal": None, + "models": [ + { + "model": model, + "status": "ok" if ok else "failed", + "embedding": {"npz_key": f"embedding__{key}"}, + "meta": {}, + } + ], + } + with open(json_path, "w") as f: + json.dump(manifest, f) + + result = load_export(d) + assert result.ok_models == ["remoteclip"] + assert result.failed_models == ["prithvi"] + + +# ══════════════════════════════════════════════════════════════════════ +# ExportResult — error cases +# ══════════════════════════════════════════════════════════════════════ + + +class TestExportResultErrors: + def test_embedding_failed_model_raises_valueerror(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path), status="failed") + result = load_export(npz_path) + # status=failed means no embeddings array + mr = result.models["remoteclip"] + mr_failed = ModelResult( + name="remoteclip", + status="failed", + embeddings=None, + inputs=None, + meta=[], + error="fetch error", + ) + result.models["remoteclip"] = mr_failed + with pytest.raises(ValueError, match="no embeddings"): + result.embedding("remoteclip") + + def test_embedding_unknown_model_raises_keyerror(self, tmp_path): + npz_path, _ = _combined_npz(str(tmp_path)) + result = load_export(npz_path) + with pytest.raises(KeyError): + result.embedding("no_such_model") + + +# ══════════════════════════════════════════════════════════════════════ +# Public API surface +# ══════════════════════════════════════════════════════════════════════ + + +def test_top_level_exports(): + """load_export, ExportResult, ModelResult are importable from rs_embed.""" + import rs_embed + + assert hasattr(rs_embed, "load_export") + assert hasattr(rs_embed, "ExportResult") + assert hasattr(rs_embed, "ModelResult") + assert rs_embed.load_export is load_export