Merged
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.14.3
     hooks:
-      - id: ruff
+      - id: ruff-check
         types_or: [python, jupyter]
         args: ["--fix", "--show-fixes"]
       - id: ruff-format
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ The format is based on Keep a Changelog, and the project follows Semantic Versio

## [Unreleased]

### Added

- `load_export(path)` reader API that loads any export produced by `export_batch(...)` — both combined (single file) and per-item (directory) layouts — and returns a structured `ExportResult`. Failed points are NaN-filled rather than dropped, partial model runs are surfaced via `status="partial"`, and `ExportResult.embedding(model)` provides a typed shortcut to the embedding array.

### Fixed

- `device="auto"` now correctly selects MPS on Apple Silicon instead of silently falling back to CPU. Follows the PyTorch-recommended priority (`cuda > mps > cpu`), giving an approximately 4x speedup on Apple M-series hardware for all API calls that use the default device.
3 changes: 2 additions & 1 deletion docs/api.md
@@ -7,7 +7,7 @@ If you want installation and first-run examples, start with [Quickstart](quickst

## Core Entry Points

-Most users only need four public entry points: `get_embedding(...)`, `get_embeddings_batch(...)`, `export_batch(...)`, and `inspect_provider_patch(...)`.
+Most users only need five public entry points: `get_embedding(...)`, `get_embeddings_batch(...)`, `export_batch(...)`, `load_export(...)`, and `inspect_provider_patch(...)`.

---

@@ -18,6 +18,7 @@ Most users only need four public entry points: `get_embedding(...)`, `get_embedd
| understand spatial/temporal/output specs | [API: Specs and Data Structures](api_specs.md) |
| get one embedding or batch embeddings | [API: Embedding](api_embedding.md) |
| build export pipelines and datasets | [API: Export](api_export.md) |
| read back a saved export | [API: Load](api_load.md) |
| inspect raw provider patches before inference | [API: Inspect](api_inspect.md) |

---
2 changes: 1 addition & 1 deletion docs/api_export.md
@@ -2,7 +2,7 @@

This page covers dataset export APIs.

-Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), and [API: Inspect](api_inspect.md).
+Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), [API: Load](api_load.md), and [API: Inspect](api_inspect.md).

---

194 changes: 194 additions & 0 deletions docs/api_load.md
@@ -0,0 +1,194 @@
# API: Load

This page covers the export reader API for loading files produced by [`export_batch`](api_export.md).

Related pages: [API: Specs and Data Structures](api_specs.md), [API: Embedding](api_embedding.md), and [API: Export](api_export.md).

---

## load_export (primary / recommended) { #load_export }

### Signature

```python
load_export(
    path: Union[str, os.PathLike],
) -> ExportResult
```

Use `load_export(...)` to read any export produced by [`export_batch`](api_export.md) — both **combined** (single file) and **per-item** (directory) layouts are supported. The layout is detected automatically.

### Mental Model

`load_export(...)` answers one question: *where is the export?*

- Pass a **file** (`.npz`, `.nc`, or `.json`) to load a **combined** export.
- Pass a **directory** to load a **per-item** export.

Everything else — layout detection, key parsing, NaN-fill for partial failures — is handled automatically.
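
The two rules above can be sketched as a small path check. This is only an illustration of the documented behavior, not the library's implementation: `guess_layout` and `COMBINED_SUFFIXES` are hypothetical names, and the real reader may also inspect file contents before deciding.

```python
import tempfile
from pathlib import Path

# File suffixes this page lists for combined exports.
COMBINED_SUFFIXES = {".npz", ".nc", ".json"}


def guess_layout(path):
    """Hypothetical helper: classify a path as "combined" or "per_item"."""
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"no such export path: {p}")
    if p.is_dir():
        return "per_item"
    if p.suffix.lower() in COMBINED_SUFFIXES:
        return "combined"
    raise ValueError(f"not a recognised export file: {p}")


# Demo on a throwaway directory holding one empty placeholder file.
with tempfile.TemporaryDirectory() as tmp:
    combined_file = Path(tmp) / "run.npz"
    combined_file.touch()
    dir_layout = guess_layout(tmp)             # directory -> per-item
    file_layout = guess_layout(combined_file)  # known suffix -> combined
```

The sketch raises the same exception types that the Raises table on this page documents for a missing or unrecognised path.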

### Default Pattern

```python
from rs_embed import load_export

# Combined export (single file)
result = load_export("exports/run.npz")

# Per-item export (directory of p00000.npz, p00001.npz, ...)
result = load_export("exports/per_item_run/")
```

---

## Parameters

| Parameter | Meaning |
| --------- | --------------------------------------------------------------------------------- |
| `path` | Path to a `.npz`/`.nc`/`.json` file (combined) or a directory (per-item export). |

### Raises

| Exception | When |
| ----------------- | --------------------------------------------------------------------- |
| `FileNotFoundError` | Path does not exist. |
| `ValueError` | Path exists but cannot be interpreted as an rs-embed export. |
| `ImportError` | NetCDF export requested but `xarray` is not installed. |

---

## Return Value: ExportResult { #ExportResult }

`load_export(...)` always returns an `ExportResult`.

```python
@dataclass
class ExportResult:
    layout: str                     # "combined" or "per_item"
    spatials: list[dict]            # one dict per spatial point
    temporal: dict | None           # temporal spec used at export time
    n_items: int                    # number of spatial points
    status: str                     # "ok" | "partial" | "failed"
    models: dict[str, ModelResult]  # keyed by model name
    manifest: dict                  # raw manifest for advanced use
```

### Convenience Methods

```python
result.embedding("remoteclip") # → np.ndarray, shape (n_items, dim)
result.ok_models # → list[str] — models with status "ok"
result.failed_models # → list[str] — models with status "failed"
```

`embedding(model)` raises `KeyError` if the model was not part of the export and `ValueError` if the model failed for every point.
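
When a missing or fully failed model should not abort a script, the two documented exceptions can be caught in a thin wrapper. The sketch below assumes only the behavior described above. `embedding_or_none` is a hypothetical helper, and `_StubResult` is a stand-in used to exercise it without a real export on disk.

```python
def embedding_or_none(result, model):
    """Return result.embedding(model), or None if it is unavailable."""
    try:
        return result.embedding(model)
    except KeyError:
        # The model was not part of the export.
        return None
    except ValueError:
        # The model failed for every point, so there is no array.
        return None


class _StubResult:
    """Stand-in for ExportResult, used only to exercise the helper."""

    def __init__(self, arrays, failed=()):
        self._arrays = arrays
        self._failed = set(failed)

    def embedding(self, model):
        if model in self._failed:
            raise ValueError(f"{model} failed for every point")
        return self._arrays[model]  # raises KeyError if the model is absent


stub = _StubResult({"remoteclip": [[0.1, 0.2]]}, failed=["prithvi"])
```

Returning `None` hides which of the two cases occurred; code that needs to tell them apart should catch the exceptions separately.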

---

## ModelResult { #ModelResult }

Each entry in `result.models` is a `ModelResult`:

```python
@dataclass
class ModelResult:
    name: str                      # canonical model identifier
    status: str                    # "ok" | "partial" | "failed"
    embeddings: np.ndarray | None  # (n_items, dim) or (n_items, C, H, W)
    inputs: np.ndarray | None      # (n_items, C, H, W), or None if not saved
    meta: list[dict]               # per-point embedding metadata
    error: str | None              # error string for fully-failed models
```

**Status values:**

| Status | Meaning |
| --------- | ------------------------------------------------ |
| `"ok"` | All points succeeded. |
| `"partial"` | Some points succeeded; failed points are NaN-filled in `embeddings`. |
| `"failed"` | All points failed; `embeddings` is `None`. |
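
The three values in the table follow from per-point outcomes. A minimal sketch of that mapping, assuming one boolean success flag per spatial point: `summarize_status` is a hypothetical name, and treating an empty list as `"failed"` is an assumption of this sketch rather than documented behavior.

```python
def summarize_status(succeeded):
    """Map per-point success flags to "ok", "partial", or "failed"."""
    flags = list(succeeded)
    if flags and all(flags):
        return "ok"       # every point succeeded
    if any(flags):
        return "partial"  # at least one success and at least one failure
    return "failed"       # no point succeeded (or no points at all)


all_ok = summarize_status([True, True, True])
mixed = summarize_status([True, False, True])
none_ok = summarize_status([False, False])
```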

---

## Common Patterns

### Load and inspect a combined export

```python
from rs_embed import load_export

result = load_export("exports/combined_run.npz")

print(result.n_items) # number of spatial points
print(result.ok_models) # models that succeeded
print(result.temporal) # {'start': '2022-06-01', 'end': '2022-09-01'}

emb = result.embedding("remoteclip") # shape (n_items, dim)
```

### Access inputs when save_inputs=True

```python
result = load_export("exports/combined_run.npz")
mr = result.models["prithvi"]
if mr.inputs is not None:
    print(mr.inputs.shape)  # (n_items, C, H, W)
```

### Load a per-item export directory

```python
result = load_export("exports/per_item_run/")
print(result.layout) # "per_item"
print(result.n_items) # number of files found

emb = result.embedding("remoteclip") # (n_items, dim) — stacked from per-file arrays
```

### Handle partial failures

```python
result = load_export("exports/combined_run.npz")

if result.failed_models:
    print("Failed:", result.failed_models)

for name in result.ok_models:
    emb = result.embedding(name)
    print(f"{name}: {emb.shape}")
```
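
For a model with status `"partial"`, failed points stay in the array as NaN-filled rows, so downstream code usually needs to mask them. The sketch below uses plain nested lists so that it runs without NumPy. `valid_rows` is a hypothetical helper and assumes the 2-D `(n_items, dim)` shape.

```python
import math


def valid_rows(embeddings):
    """Return the indices of rows that contain no NaN values."""
    return [
        i
        for i, row in enumerate(embeddings)
        if not any(math.isnan(v) for v in row)
    ]


emb = [
    [0.1, 0.2],
    [float("nan"), float("nan")],  # a failed point, NaN-filled
    [0.3, 0.4],
]
kept = valid_rows(emb)
```

With a NumPy array of the same shape, the equivalent boolean mask is `~np.isnan(emb).any(axis=1)`.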

### Load via the manifest JSON

Pass the `.json` manifest path if that is what you have — `load_export` finds the paired array file automatically:

```python
result = load_export("exports/combined_run.json")
```

---

## Relationship to export_batch

`load_export` is the read-side counterpart to `export_batch`. Every file produced by `export_batch` can be read back with `load_export` without manually parsing NPZ keys or manifest JSON.

```python
from rs_embed import export_batch, load_export, ExportTarget, ExportConfig, PointBuffer, TemporalSpec

# Write
export_batch(
    spatials=[PointBuffer(121.5, 31.2, 2048)],
    temporal=TemporalSpec.range("2022-06-01", "2022-09-01"),
    models=["remoteclip"],
    target=ExportTarget.combined("exports/run"),
    config=ExportConfig(save_inputs=True),
)

# Read back
result = load_export("exports/run.npz")
emb = result.embedding("remoteclip") # shape (1, dim)
```

!!! tip "Simple rule"
    Pass a file path for combined exports, a directory path for per-item exports.
    Everything else is automatic.
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -89,6 +89,7 @@ nav:
- Specs & Data Structures: api_specs.md
- Embedding API: api_embedding.md
- Export API: api_export.md
- Load API: api_load.md
- Inspect API: api_inspect.md
- Extending:
- Overview: extending.md
5 changes: 5 additions & 0 deletions src/rs_embed/__init__.py
@@ -34,6 +34,7 @@
)
from .export import export_npz
from .inspect import inspect_gee_patch, inspect_provider_patch
from .load import ExportResult, ModelResult, load_export
from .model import Model
from .pipelines.exporter import BatchExporter

@@ -64,6 +65,10 @@
# Export API
"export_batch",
"export_npz",
# Load API
"load_export",
"ExportResult",
"ModelResult",
# Inspection
"inspect_provider_patch",
# Backward-compatible alias for inspect_provider_patch
6 changes: 0 additions & 6 deletions src/rs_embed/api.py
@@ -70,12 +70,6 @@
 from .tools.model_defaults import (
     resolve_sensor_for_model as _resolve_sensor_for_model,
 )
-from .tools.normalization import (
-    # Re-exported so `from rs_embed.api import ...` in tests/downstream still works.
-    _default_provider_backend_for_api,  # noqa: F401
-    _probe_model_describe,  # noqa: F401
-    _resolve_embedding_api_backend,  # noqa: F401
-)
 from .tools.normalization import (
     normalize_backend_name as _normalize_backend_name,
 )