From 05f4ca6d1c9d84053fdd13e5dfafdbd8ec34d592 Mon Sep 17 00:00:00 2001
From: Nicholas Geneva
Date: Wed, 18 Mar 2026 06:27:59 +0000
Subject: [PATCH 1/3] adding claude commands

---
 .claude/commands/docs-serve.md | 32 +++++++++++++++++++++
 .claude/commands/docs.md | 29 +++++++++++++++++++
 .claude/commands/format.md | 7 +++++
 .claude/commands/license.md | 18 ++++++++++++
 .claude/commands/lint.md | 8 ++++++
 .claude/commands/test-full.md | 17 ++++++++++++
 .claude/commands/test.md | 51 ++++++++++++++++++++++++++++++++++
 .gitignore | 3 ++
 CLAUDE.md | 46 ++++++++++++++++++++++++++++++
 9 files changed, 211 insertions(+)
 create mode 100644 .claude/commands/docs-serve.md
 create mode 100644 .claude/commands/docs.md
 create mode 100644 .claude/commands/format.md
 create mode 100644 .claude/commands/license.md
 create mode 100644 .claude/commands/lint.md
 create mode 100644 .claude/commands/test-full.md
 create mode 100644 .claude/commands/test.md
 create mode 100644 CLAUDE.md

diff --git a/.claude/commands/docs-serve.md b/.claude/commands/docs-serve.md
new file mode 100644
index 000000000..ae9bcab0f
--- /dev/null
+++ b/.claude/commands/docs-serve.md
@@ -0,0 +1,32 @@
+Build the HTML documentation and then serve it locally so it can be viewed in a browser.
+
+Ask the user which build mode they want if not specified:
+
+**Step 1 — build (choose one):**
+
+Fast build (skips example gallery):
+
+```bash
+make docs
+```
+
+Full build (includes all examples, slow):
+
+```bash
+make docs-full
+```
+
+Targeted single-example build (ask the user for the filename stem if not provided):
+
+```bash
+make docs-dev FILENAME=
+```
+
+**Step 2 — serve:**
+
+```bash
+uv run python -m http.server 8000 --directory docs/_build/html
+```
+
+The docs will be available at <http://localhost:8000>. Inform the user of the URL and that they
+can stop the server with Ctrl-C.
diff --git a/.claude/commands/docs.md b/.claude/commands/docs.md new file mode 100644 index 000000000..e2331d22b --- /dev/null +++ b/.claude/commands/docs.md @@ -0,0 +1,29 @@ +Build the HTML documentation. Ask the user which build mode they want if not specified: + +**Fast build** — skips the example gallery, quickest iteration: + +```bash +make docs +``` + +**Full build** — includes the complete example gallery (slow, requires all extras): + +```bash +make docs-full +``` + +**Targeted example build** — builds a single example page by filename stem, useful when +developing or debugging a specific example. Ask the user for the example filename if not +provided: + +```bash +make docs-dev FILENAME= +``` + +For example, to build only `examples/run_inference.py`: + +```bash +make docs-dev FILENAME=run_inference +``` + +All builds output to `docs/_build/html/`. Report any Sphinx warnings or errors encountered. diff --git a/.claude/commands/format.md b/.claude/commands/format.md new file mode 100644 index 000000000..22787761b --- /dev/null +++ b/.claude/commands/format.md @@ -0,0 +1,7 @@ +Run the black formatter on the entire codebase: + +```bash +make format +``` + +Report whether any files were reformatted or if everything was already compliant. diff --git a/.claude/commands/license.md b/.claude/commands/license.md new file mode 100644 index 000000000..1aa0b97af --- /dev/null +++ b/.claude/commands/license.md @@ -0,0 +1,18 @@ +Check that every Python source file carries the required SPDX Apache-2.0 license header: + +```bash +make license +``` + +The expected header is: + +```python +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# ... +``` + +Report any files that are missing or have an incorrect header. 
diff --git a/.claude/commands/lint.md b/.claude/commands/lint.md new file mode 100644 index 000000000..a34caef83 --- /dev/null +++ b/.claude/commands/lint.md @@ -0,0 +1,8 @@ +Run all pre-commit lint checks on the codebase (large-file check, trailing whitespace, +end-of-file fixer, debug statements, markdownlint, test naming, pyupgrade, ruff, and mypy): + +```bash +make lint +``` + +Report every failure with its file path and a short explanation of what needs to be fixed. diff --git a/.claude/commands/test-full.md b/.claude/commands/test-full.md new file mode 100644 index 000000000..7b484d5aa --- /dev/null +++ b/.claude/commands/test-full.md @@ -0,0 +1,17 @@ +Run the full test suite across all tox environments with coverage enabled. + +This requires all optional extras to be available. tox handles syncing each environment's +extras automatically via `uv sync`. + +```bash +make pytest-full +``` + +After the run completes, print a coverage summary: + +```bash +make coverage +``` + +Report overall pass/fail counts and the final coverage percentage. Flag any modules below the +90% coverage threshold. diff --git a/.claude/commands/test.md b/.claude/commands/test.md new file mode 100644 index 000000000..115cca8d3 --- /dev/null +++ b/.claude/commands/test.md @@ -0,0 +1,51 @@ +Run the test suite for a specific tox environment. + +## Choosing the right environment + +Each tox environment installs a specific set of uv extras before running. When the user asks +to test a file or feature, check the top of the relevant source file for +`OptionalDependencyFailure("group-name")` calls — the group name tells you which extra (and +therefore which tox environment) is required. 
+
+| TOX_ENV | uv extras | Test paths |
+|---|---|---|
+| `test` | _(none)_ | `test/io`, `test/run`, `test/utils/`, `test/models/test_auto_models.py` |
+| `test-data` | `data`, `cbottle` | `test/data`, `test/lexicon` |
+| `test-perturb` | `perturbation` | `test/perturbation` |
+| `test-stats` | `statistics`, `utils` | `test/statistics`, `test/utils/test_interp.py` |
+| `test-px-models` | `all` | `test/models/px` (except ACE2) |
+| `test-dx-models` | `all` | `test/models/dx` |
+| `test-da-models` | `all` | `test/models/da` |
+| `test-px-models-ace2` | `ace2` | `test/models/px/test_ace2.py` |
+| `test-serve` | `serve` | `test/serve` |
+
+> **Note**: The `test` environment also covers `test/models/test_batch.py` and
+> `test/utils/test_imports.py`. Check `tox.ini` for the full authoritative list.
+
+If the user has not specified a `TOX_ENV`, infer it from the file or feature being tested using
+the table above.
+
+## Running via tox (recommended)
+
+```bash
+make pytest TOX_ENV=
+```
+
+tox handles the `uv sync --extra ` automatically before running pytest.
+
+## Running a single test directly (faster, no tox overhead)
+
+First sync the required extras manually (note `--extra` takes one extra name and is repeated for multiple extras):
+
+```bash
+uv sync --extra   # e.g. uv sync --extra data --extra cbottle
+```
+
+Then run pytest directly:
+
+```bash
+uv run pytest test/path/to/test_file.py -v
+uv run pytest test/path/to/test_file.py -k "test_name" -v
+```
+
+Report the test results, including any failures with their tracebacks.
diff --git a/.gitignore b/.gitignore
index afe1d7abd..b6c82f78d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,6 @@ docs/examples
 
 # Pytestmon
 .testmondata*
+
+.claude/agent-memory
+.claude/agents
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..09caf8e51
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,46 @@
+# Earth2Studio – Claude Code Guide
+
+## Project Rules
+
+The authoritative coding rules live in `.cursor/rules/`. Read the relevant rule
+file(s) before writing or reviewing code:
+
+| Rule file | Topic |
+|---|---|
+| `e2s-000-python-style-guide.mdc` | Style, formatting, type hints, license header |
+| `e2s-001-dependency-management.mdc` | Adding / updating dependencies |
+| `e2s-002-api-documentation.mdc` | Docstrings and public API docs |
+| `e2s-003-unit-testing.mdc` | Writing and structuring tests |
+| `e2s-004-data-sources.mdc` | Implementing `DataSource` classes |
+| `e2s-005-forecast-sources.mdc` | Implementing `ForecastSource` classes |
+| `e2s-006-dataframe-sources.mdc` | Implementing `DataFrameSource` classes |
+| `e2s-007-forecast-frame-sources.mdc` | Implementing `ForecastFrameSource` classes |
+| `e2s-008-lexicon-usage.mdc` | Variable lexicons and coordinate conventions |
+| `e2s-009-prognostic-models.mdc` | Implementing prognostic models |
+| `e2s-010-diagnostic-models.mdc` | Implementing diagnostic models |
+| `e2s-011-examples.mdc` | Writing gallery examples |
+| `e2s-012-time-tolerance.mdc` | Time tolerance patterns |
+| `e2s-013-assimilation-models.mdc` | Implementing data assimilation models |
+
+## Key Conventions (quick reference)
+
+- **License header** — every `.py` file in `earth2studio/` must start with the SPDX Apache-2.0
+  header.
+- **Type hints** — all public functions must be fully typed; the codebase is mypy-clean.
+- **Logging** — use `loguru.logger`, never `print()`, inside `earth2studio/`.
+- **Formatting** — black; run `/format` or `make format`.
+- **Linting** — ruff + mypy; run `/lint` or `make lint` before opening a PR.
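As a quick illustration of the type-hint convention, a public helper is annotated end to end (hypothetical function, not part of the codebase):

```python
from collections.abc import Sequence


def weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    """Return the weighted mean of values (fully typed, mypy-clean)."""
    if len(values) != len(weights):
        raise ValueError("values and weights must have the same length")
    total = float(sum(weights))
    if total == 0.0:
        raise ValueError("weights must not sum to zero")
    return sum(v * w for v, w in zip(values, weights)) / total
```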
+
+## Custom Commands
+
+Use these slash commands (defined in `.claude/commands/`) to run common tasks:
+
+| Command | Action |
+|---|---|
+| `/format` | Auto-format code with black |
+| `/lint` | Run all linters (ruff, mypy, pre-commit checks) |
+| `/license` | Check SPDX license headers on all Python files |
+| `/test` | Run tests for a specific tox environment |
+| `/test-full` | Run the full test suite with coverage |
+| `/docs` | Build docs — fast, full (with examples), or targeted single example |
+| `/docs-serve` | Build docs then serve locally at <http://localhost:8000> |

From f58d9121c0b3882492faac50b8f36b5b75b7ff0f Mon Sep 17 00:00:00 2001
From: Nicholas Geneva
Date: Wed, 18 Mar 2026 06:28:13 +0000
Subject: [PATCH 2/3] Adding assimilation cursor rule

---
 .cursor/rules/e2s-013-assimilation-models.mdc | 578 ++++++++++++++++++
 1 file changed, 578 insertions(+)
 create mode 100644 .cursor/rules/e2s-013-assimilation-models.mdc

diff --git a/.cursor/rules/e2s-013-assimilation-models.mdc b/.cursor/rules/e2s-013-assimilation-models.mdc
new file mode 100644
index 000000000..1e15eec11
--- /dev/null
+++ b/.cursor/rules/e2s-013-assimilation-models.mdc
@@ -0,0 +1,578 @@
+---
+globs: earth2studio/models/da/*.py
+description: Earth2Studio data assimilation model implementation guidelines following the AssimilationModel protocol
+---
+
+# Earth2Studio Data Assimilation Models
+
+This rule enforces data assimilation (DA) model implementation standards for Earth2Studio. DA models ingest sparse observations (DataFrames) and/or gridded state arrays (DataArrays) and produce an analysis. See [base.py](mdc:earth2studio/models/da/base.py) for the Protocol definition.
+ +## AssimilationModel Protocol + +DA models must implement the `AssimilationModel` Protocol from [base.py](mdc:earth2studio/models/da/base.py): + +- `__call__(*args) -> tuple[pd.DataFrame | xr.DataArray, ...]` — Stateless single-step inference +- `create_generator(*args) -> Generator[...]` — Stateful generator using Python's send protocol +- `init_coords() -> tuple[FrameSchema | CoordSystem, ...] | None` — Initialization data schema (or `None` for stateless models) +- `input_coords() -> tuple[FrameSchema | CoordSystem, ...]` — Input observation schema(s) +- `output_coords(input_coords, ...) -> tuple[FrameSchema | CoordSystem, ...]` — Output coordinate system +- `to(device) -> AssimilationModel` — Move model to device + +**Key difference from prognostic/diagnostic models**: inputs and outputs are `xr.DataArray` or `pd.DataFrame` objects (not raw `torch.Tensor` + `CoordSystem` tuples), because observations are typically sparse and tabular. + +## Required Base Classes + +DA models must inherit from: + +1. **`torch.nn.Module`**: PyTorch module base class +2. **`AutoModelMixin`**: Provides automatic checkpoint loading (from [mixin.py](mdc:earth2studio/models/auto/mixin.py)) + +```python +from earth2studio.models.auto import AutoModelMixin, Package +from earth2studio.models.da.base import AssimilationModel + +@check_optional_dependencies() +class MyDAModel(torch.nn.Module, AutoModelMixin): + """My data assimilation model.""" + pass +``` + +**Note**: DA models do NOT use `PrognosticMixin` or `batch_func`/`batch_coords` decorators — the observation-centric API does not use the same batching conventions as px/dx models. + +## Optional Dependencies + +DA models typically require optional dependencies. 
Always use the standard pattern: + +```python +from earth2studio.utils.imports import ( + OptionalDependencyFailure, + check_optional_dependencies, +) + +try: + from some_package import SomeClass +except ImportError: + OptionalDependencyFailure("da-mymodel") + SomeClass = None + +@check_optional_dependencies() +class MyDAModel(torch.nn.Module, AutoModelMixin): + ... +``` + +**Key Points**: +- Decorate the class itself with `@check_optional_dependencies()` +- Also decorate `load_model` with `@check_optional_dependencies()` +- Use `OptionalDependencyFailure("da-")` — the group name must match the entry in `pyproject.toml` + +## Coordinate Types: FrameSchema vs CoordSystem + +DA models use two coordinate types (from [type.py](mdc:earth2studio/utils/type.py)): + +- **`FrameSchema`** — describes a DataFrame's expected columns; keys are column names, values are representative numpy arrays (dtype and allowed values). Use for observation DataFrames. +- **`CoordSystem`** — `OrderedDict` mapping dimension names to coordinate arrays. Use for gridded tensor inputs/outputs (xr.DataArray with known dims). + +```python +from earth2studio.utils.type import CoordSystem, FrameSchema + +# FrameSchema example (tabular observation input) +obs_schema = FrameSchema({ + "time": np.empty(0, dtype="datetime64[ns]"), + "lat": np.empty(0, dtype=np.float32), + "lon": np.empty(0, dtype=np.float32), + "observation": np.empty(0, dtype=np.float32), + "variable": np.array(["u10m", "v10m", "t2m"], dtype=str), +}) + +# CoordSystem example (gridded xr.DataArray output) +out_coords = CoordSystem(OrderedDict({ + "time": request_time, + "variable": np.array(OUTPUT_VARIABLES, dtype=str), + "lat": np.linspace(90, -90, 181), + "lon": np.linspace(0, 360, 360, endpoint=False), +})) +``` + +## init_coords + +Return `None` for stateless models (no initial state needed). 
Return a tuple of `FrameSchema | CoordSystem` for models that require initialization data (e.g., an initial high-resolution state field): + +```python +def init_coords(self) -> tuple[CoordSystem] | None: + """Initialization coordinate system. + + Returns None if the model is stateless, or a tuple of coordinate + systems / frame schemas describing the initial state expected by + create_generator. + + Returns + ------- + tuple[CoordSystem] | None + Initialization schemas, or None for stateless models + """ + return None # stateless model + + # or for stateful models: + return ( + OrderedDict({ + "time": np.empty(0), + "lead_time": np.array([np.timedelta64(0, "h")]), + "variable": np.array(self.variables), + "hrrr_y": self.hrrr_y, + "hrrr_x": self.hrrr_x, + }), + ) +``` + +## input_coords + +Return a **tuple** of `FrameSchema | CoordSystem` — one entry per positional argument that `__call__` accepts. Use `np.empty(0, dtype=...)` for unbounded dimensions (like individual observations) and `np.array([...])` for enumerated allowed values. + +```python +def input_coords(self) -> tuple[FrameSchema, FrameSchema]: + """Input coordinate system specifying required DataFrame fields. + + Returns + ------- + tuple[FrameSchema, FrameSchema] + (conventional_schema, satellite_schema) + """ + conv_schema = FrameSchema({ + "time": np.empty(0, dtype="datetime64[ns]"), + "lat": np.empty(0, dtype=np.float32), + "lon": np.empty(0, dtype=np.float32), + "observation": np.empty(0, dtype=np.float32), + "variable": np.array(["u", "v", "t"], dtype=str), + }) + return (conv_schema,) +``` + +## output_coords + +Accept the input coordinate tuple plus any model-specific kwargs (e.g., `request_time`). Return a tuple of output `CoordSystem | FrameSchema`: + +```python +def output_coords( + self, + input_coords: tuple[FrameSchema | CoordSystem, ...], + request_time: np.ndarray | None = None, + **kwargs: Any, +) -> tuple[CoordSystem]: + """Output coordinate system. 
+ + Parameters + ---------- + input_coords : tuple[FrameSchema | CoordSystem, ...] + Input coordinate system(s) + request_time : np.ndarray | None, optional + Analysis valid time(s) + + Returns + ------- + tuple[CoordSystem] + Output coordinate system(s) + """ + if request_time is None: + request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]") + + return ( + CoordSystem(OrderedDict({ + "time": request_time, + "variable": np.array(self.OUTPUT_VARIABLES, dtype=str), + "lat": np.linspace(90, -90, 181), + "lon": np.linspace(0, 360, 360, endpoint=False), + })), + ) +``` + +**Key Points**: +- Validate input_coords using `handshake_dim` / `handshake_size` / `handshake_coords` where the schema is a `CoordSystem` +- For `FrameSchema` inputs, validate required columns with `validate_observation_fields` from [utils.py](mdc:earth2studio/models/da/utils.py) + +## __call__ (stateless inference) + +Accepts typed arguments matching `input_coords`, runs inference, and returns a tuple of `xr.DataArray | pd.DataFrame`. Use `@torch.inference_mode()` unless gradients are needed (e.g., DPS guidance requires gradients through the denoiser — see `StormCastSDA`). + +```python +@torch.inference_mode() +def __call__( + self, + obs: pd.DataFrame | None = None, +) -> xr.DataArray: + """Run single-step assimilation. + + Parameters + ---------- + obs : pd.DataFrame | None, optional + Observation DataFrame + + Returns + ------- + xr.DataArray + Analysis output on the same device as the model + """ + if obs is None: + raise ValueError("obs must be provided") + + request_time = obs.attrs.get("request_time") + if request_time is None: + raise ValueError( + "Observation DataFrame must have 'request_time' in attrs. " + "This is typically set by earth2studio data sources." + ) + + # ... process observations, run model ... 
+ prediction = self._forward(inputs) + + (out_coords,) = self.output_coords(self.input_coords(), request_time=request_time) + return self.build_output(prediction, out_coords) +``` + +**Key Points**: +- Expect `request_time` in `obs.attrs` — this is set by earth2studio data sources and must be validated early +- Output data should live on the same device as the model (use cupy for GPU, numpy for CPU) +- Use `@torch.inference_mode()` unless gradient flow is required (document the reason if omitted) + +## create_generator (stateful iteration) + +Implements Python's bidirectional generator protocol. The generator receives observation data via `.send()` and yields analysis results. Always prime the generator by yielding `None` (or the initial state) first: + +```python +def create_generator( + self, + x: xr.DataArray, # initial state, if required +) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + """Creates a stateful generator for sequential data assimilation. + + Yields the current analysis state and receives the next observations + via generator.send(). + + Parameters + ---------- + x : xr.DataArray + Initial state DataArray + + Yields + ------ + xr.DataArray + Analysis at each step + + Receives + -------- + pd.DataFrame | None + Observations for the next step. Pass None for steps with no data. 
+ + Example + ------- + >>> gen = model.create_generator(x0) + >>> state = next(gen) # yields initial state / primes generator + >>> state = gen.send(obs_df) # step 1 with observations + >>> state = gen.send(None) # step 2 without observations + >>> gen.close() # clean up + """ + # Yield initial state to prime the generator + obs = yield x + + try: + while True: + result = self.__call__(x, obs) + x = result + obs = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") +``` + +**Key Points**: +- Always yield the initial state / `None` first so the caller can prime the generator with `next(gen)` before the first `.send()` +- Handle `GeneratorExit` for clean-up logic (e.g., releasing GPU resources) +- Stateless models can implement `create_generator` by delegating each step to `__call__` +- For stateless models that have no meaningful initial state, yield `None` initially: `inputs = yield None` + +## DA Utilities + +Use the helper functions from [utils.py](mdc:earth2studio/models/da/utils.py): + +```python +from earth2studio.models.da.utils import ( + dfseries_to_torch, + filter_time_range, + validate_observation_fields, +) + +# Validate required columns in a DataFrame +validate_observation_fields(obs_df, ["time", "lat", "lon", "observation", "variable"]) + +# Filter observations within a time tolerance window +filtered = filter_time_range( + obs_df, + request_time=request_time, + tolerance=(np.timedelta64(-3, "h"), np.timedelta64(3, "h")), + time_column="time", +) + +# Zero-copy conversion of a DataFrame column to a torch tensor (uses cudf dlpack if available) +tensor = dfseries_to_torch(obs_df["observation"], dtype=torch.float32, device=device) +``` + +## cudf / cupy Support + +DA models should support both pandas/numpy (CPU) and cudf/cupy (GPU) DataFrames where practical. 
Use the standard optional import pattern: + +```python +try: + import cudf +except ImportError: + cudf = None # type: ignore[assignment, misc] + +try: + import cupy as cp +except ImportError: + cp = None +``` + +When returning `xr.DataArray` output, use cupy arrays for GPU models: + +```python +if self.device.type == "cuda" and cp is not None: + data = cp.asarray(out_tensor) +else: + data = out_tensor.cpu().numpy() +``` + +## Device Management + +Register a `device_buffer` to track the current device without extra parameters: + +```python +def __init__(self, ...): + super().__init__() + self.register_buffer("device_buffer", torch.empty(0)) + +@property +def device(self) -> torch.device: + return self.device_buffer.device +``` + +All other tensors that are not `torch.nn.Parameter` but need to follow `.to(device)` should also be registered as buffers via `self.register_buffer(...)`. + +## AutoModelMixin Implementation + +### load_default_package + +```python +@classmethod +def load_default_package(cls) -> Package: + """Default pre-trained model package. + + Returns + ------- + Package + Model package with default checkpoint location + """ + return Package( + "hf://nvidia/my-da-model@", + cache_options={"same_names": True}, + ) +``` + +### load_model + +```python +@classmethod +@check_optional_dependencies() +def load_model( + cls, + package: Package, + time_tolerance: TimeTolerance = np.timedelta64(3, "h"), +) -> AssimilationModel: + """Load assimilation model from package. + + Parameters + ---------- + package : Package + Package containing model checkpoint and statistics + time_tolerance : TimeTolerance, optional + Observation time tolerance window + + Returns + ------- + AssimilationModel + Loaded assimilation model + """ + model = SomeModel.from_checkpoint(package.resolve("model.mdlus")) + model.eval() + + # Load normalization stats, static fields, etc. 
+ stats = np.load(package.resolve("stats.npy")) + + return cls(model=model, stats=stats, time_tolerance=time_tolerance) +``` + +**Key Points**: +- Decorate `load_model` with `@check_optional_dependencies()` +- Use `package.resolve("filename")` to get cached file paths +- Call `.eval()` on loaded neural network modules +- Only expose essential parameters — do not over-populate the API + +## Complete Example Structure + +```python +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Generator +from typing import Any + +import numpy as np +import pandas as pd +import torch +import xarray as xr +from loguru import logger + +from earth2studio.models.auto import AutoModelMixin, Package +from earth2studio.models.da.base import AssimilationModel +from earth2studio.models.da.utils import filter_time_range, validate_observation_fields +from earth2studio.utils.imports import OptionalDependencyFailure, check_optional_dependencies +from earth2studio.utils.type import CoordSystem, FrameSchema, TimeTolerance + +try: + import cupy as cp +except ImportError: + cp = None + +try: + from some_package import CoreModel +except ImportError: + OptionalDependencyFailure("da-mymodel") + CoreModel = None + + +@check_optional_dependencies() +class MyDAModel(torch.nn.Module, AutoModelMixin): + """My data assimilation model.""" + + OUTPUT_VARIABLES = ["u10m", "v10m", "t2m"] + + def __init__( + self, + model: torch.nn.Module, + time_tolerance: TimeTolerance = np.timedelta64(3, "h"), + ) -> None: + super().__init__() + self._model = model + self.register_buffer("device_buffer", torch.empty(0)) + self._tolerance = ( + -abs(time_tolerance) if isinstance(time_tolerance, np.timedelta64) + else time_tolerance + ) + + @property + def device(self) -> torch.device: + return self.device_buffer.device + + def init_coords(self) -> None: + return None # stateless model + + def input_coords(self) -> tuple[FrameSchema]: + return ( + FrameSchema({ + "time": 
np.empty(0, dtype="datetime64[ns]"), + "lat": np.empty(0, dtype=np.float32), + "lon": np.empty(0, dtype=np.float32), + "observation": np.empty(0, dtype=np.float32), + "variable": np.array(self.OUTPUT_VARIABLES, dtype=str), + }), + ) + + def output_coords( + self, + input_coords: tuple[FrameSchema], + request_time: np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[CoordSystem]: + if request_time is None: + request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]") + return ( + CoordSystem(OrderedDict({ + "time": request_time, + "variable": np.array(self.OUTPUT_VARIABLES, dtype=str), + "lat": np.linspace(90, -90, 181), + "lon": np.linspace(0, 360, 360, endpoint=False), + })), + ) + + @torch.inference_mode() + def __call__(self, obs: pd.DataFrame | None = None) -> xr.DataArray: + if obs is None: + raise ValueError("obs must be provided") + validate_observation_fields(obs, ["time", "lat", "lon", "observation", "variable"]) + + request_time = obs.attrs.get("request_time") + if request_time is None: + raise ValueError("obs.attrs must contain 'request_time'") + + # ... process and run model ... + prediction = self._model(...) 
+ + (out_coords,) = self.output_coords(self.input_coords(), request_time=request_time) + out = prediction.squeeze().cpu().numpy() + return xr.DataArray(data=out, dims=list(out_coords.keys()), coords=out_coords) + + def create_generator( + self, + ) -> Generator[xr.DataArray, pd.DataFrame | None, None]: + inputs = yield None # prime generator + try: + while True: + result = self.__call__(inputs) + inputs = yield result + except GeneratorExit: + logger.debug("MyDAModel generator clean up complete.") + + @classmethod + def load_default_package(cls) -> Package: + return Package( + "hf://nvidia/my-da-model@", + cache_options={"same_names": True}, + ) + + @classmethod + @check_optional_dependencies() + def load_model(cls, package: Package) -> AssimilationModel: + model = CoreModel.from_checkpoint(package.resolve("model.mdlus")) + model.eval() + return cls(model=model) +``` + +## Differences from Prognostic and Diagnostic Models + +| Aspect | Prognostic (px) | Diagnostic (dx) | Assimilation (da) | +|---|---|---|---| +| Primary input | `torch.Tensor` + `CoordSystem` | `torch.Tensor` + `CoordSystem` | `pd.DataFrame` / `xr.DataArray` | +| Primary output | `torch.Tensor` + `CoordSystem` | `torch.Tensor` + `CoordSystem` | `xr.DataArray` / `pd.DataFrame` | +| Time integration | Yes (`create_iterator`) | No | Optional (`create_generator`) | +| `@batch_func` / `@batch_coords` | Yes | Yes | **No** | +| `PrognosticMixin` | Yes | No | No | +| Init data | No | No | Optional (`init_coords`) | +| Stateful iteration | via `create_iterator` generator | N/A | via `create_generator` send protocol | + +## Reminders + +- **Always inherit from**: `torch.nn.Module`, `AutoModelMixin` +- **Always decorate the class** with `@check_optional_dependencies()` +- **Always decorate `load_model`** with `@check_optional_dependencies()` +- **`input_coords` and `output_coords` return tuples**, even for single inputs/outputs +- **Use `FrameSchema` for tabular DataFrame inputs** and `CoordSystem` for gridded 
inputs/outputs +- **`request_time` must be validated** in `__call__` — expect it in `obs.attrs` +- **Return output on the model's device** — use cupy for GPU, numpy for CPU +- **Use `validate_observation_fields`** to check required DataFrame columns early +- **Use `filter_time_range`** for time-window filtering of observations +- **Use `dfseries_to_torch`** for zero-copy cudf→torch column conversion +- **Prime `create_generator` with `yield None`** (or the initial state) before the loop +- **Handle `GeneratorExit`** in `create_generator` for clean-up +- **Register `device_buffer`** to track device via a `device` property +- **Do NOT use `@batch_func` or `@batch_coords`** — these are px/dx conventions only +- **Do NOT use `@torch.inference_mode()`** if the forward pass requires gradients (e.g., DPS guidance); document the reason if omitted +- **Documentation**: Add the DA model to [models.rst](mdc:docs/modules/models.rst) in the `earth2studio.models.da` section, maintaining alphabetical order +- DO NOT attempt to make a general base class with intent to reuse the wrapper +- DO NOT over-populate the `load_model()` API — only expose essential parameters From d01effc7747a1a65f75dfc2b7d09c4c9e049ca97 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:15:21 -0700 Subject: [PATCH 3/3] Refine reminders and guidelines in assimilation models --- .cursor/rules/e2s-013-assimilation-models.mdc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.cursor/rules/e2s-013-assimilation-models.mdc b/.cursor/rules/e2s-013-assimilation-models.mdc index 1e15eec11..62c90618c 100644 --- a/.cursor/rules/e2s-013-assimilation-models.mdc +++ b/.cursor/rules/e2s-013-assimilation-models.mdc @@ -558,21 +558,20 @@ class MyDAModel(torch.nn.Module, AutoModelMixin): ## Reminders -- **Always inherit from**: `torch.nn.Module`, `AutoModelMixin` -- **Always decorate the class** with 
`@check_optional_dependencies()` -- **Always decorate `load_model`** with `@check_optional_dependencies()` +- **Always decorate the class** with `@check_optional_dependencies()` if extra dependency group is needed +- **Always decorate `load_model`** with `@check_optional_dependencies()` if extra dependency group is needed - **`input_coords` and `output_coords` return tuples**, even for single inputs/outputs - **Use `FrameSchema` for tabular DataFrame inputs** and `CoordSystem` for gridded inputs/outputs - **`request_time` must be validated** in `__call__` — expect it in `obs.attrs` -- **Return output on the model's device** — use cupy for GPU, numpy for CPU +- **Return output on the model's device** - If a dense array output use cupy for GPU, numpy for CPU - **Use `validate_observation_fields`** to check required DataFrame columns early - **Use `filter_time_range`** for time-window filtering of observations - **Use `dfseries_to_torch`** for zero-copy cudf→torch column conversion - **Prime `create_generator` with `yield None`** (or the initial state) before the loop - **Handle `GeneratorExit`** in `create_generator` for clean-up - **Register `device_buffer`** to track device via a `device` property -- **Do NOT use `@batch_func` or `@batch_coords`** — these are px/dx conventions only +- **Do NOT use `@batch_func` or `@batch_coords`** - these are px/dx conventions only - **Do NOT use `@torch.inference_mode()`** if the forward pass requires gradients (e.g., DPS guidance); document the reason if omitted - **Documentation**: Add the DA model to [models.rst](mdc:docs/modules/models.rst) in the `earth2studio.models.da` section, maintaining alphabetical order - DO NOT attempt to make a general base class with intent to reuse the wrapper -- DO NOT over-populate the `load_model()` API — only expose essential parameters +- DO NOT over-populate the `load_model()` API - only expose essential parameters