Description
Is this a new feature, an improvement, or a change to existing functionality?
Improvement
How would you describe the priority of this feature request
Medium
Please provide a clear description of the problem you would like to solve.
Problem
The inference workflows in earth2studio/run.py (deterministic, diagnostic, ensemble) execute autoregressive rollouts using plain Python iteration:
```python
# run.py, lines 140-149 (deterministic workflow)
for step, (x, coords) in enumerate(model):
    x, coords = map_coords(x, coords, output_coords)
    io.write(*split_coords(x, coords))
    pbar.update(1)
    if step == nsteps:
        break
```

Each rollout step calls the same model forward pass with an identical computation graph, but eager-mode PyTorch re-dispatches every operation on every iteration. For models like FCN3, SFNO, Pangu, and DLWP, which repeat the exact same operations at every autoregressive step, this creates unnecessary overhead from:
- Python overhead per iteration (interpreter, dynamic dispatch)
- Redundant CUDA kernel launches without graph capture
- No cross-step kernel fusion opportunities
Proposed Solution
Add optional torch.compile support to the prognostic model forward pass within the inference loop. torch.compile(mode="reduce-overhead") uses CUDA Graphs under the hood and can yield 1.5–3x speedups on autoregressive rollouts by:
- Capturing the computation graph once and replaying it
- Eliminating Python overhead per step
- Enabling cross-kernel optimizations via Triton
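A minimal sketch of the intended pattern, using a toy model as a stand-in for a real prognostic (the model and shapes here are hypothetical; `torch._dynamo.config.suppress_errors` is enabled so the sketch falls back to eager if graph capture fails in a given environment):

```python
import torch
import torch.nn as nn
import torch._dynamo

# Fall back to eager automatically if compilation fails anywhere,
# so this sketch stays runnable across environments.
torch._dynamo.config.suppress_errors = True

# Toy stand-in for a prognostic model's step function (hypothetical;
# real Earth2Studio prognostics also carry coordinate metadata).
model = nn.Sequential(nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 16))

# Compile once; subsequent calls replay the captured graph
# (CUDA Graphs when on GPU, plain inductor-generated code on CPU).
step_fn = torch.compile(model, mode="reduce-overhead")

x = torch.randn(4, 16)
with torch.no_grad():
    for _ in range(8):  # autoregressive rollout: output feeds the next input
        x = step_fn(x)
print(tuple(x.shape))  # (4, 16)
```

The first call through `step_fn` pays the compilation cost; the remaining rollout steps replay the captured graph, which is where the per-step savings come from.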
Implementation Approach
- Add a compile parameter to the deterministic(), diagnostic(), and ensemble() workflows in run.py:

```python
def deterministic(
    time, nsteps, prognostic, data, io,
    output_coords=OrderedDict({}),
    device=None,
    verbose=True,
    compile: bool = False,  # NEW
) -> IOBackend:
```

- Wrap the model's forward method with torch.compile when enabled:
```python
if compile:
    # Compile the model's forward method for static graph replay
    prognostic._forward = torch.compile(
        prognostic._forward, mode="reduce-overhead"
    )
```

- Handle dynamic shapes: the coordinate system and batch dimensions are static during rollout, so torch.compile should work without dynamic=True. A warmup step handles the initial compilation.
- Graceful fallback: If compilation fails (e.g., unsupported ops in some models), catch the error and fall back to eager mode with a warning.
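The fallback could look roughly like this (a sketch, not the actual run.py helper; `maybe_compile` is a hypothetical name):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def maybe_compile(fn, enabled: bool):
    """Return a compiled version of fn, or fn unchanged on failure.

    Sketch only. Note that many torch.compile problems surface lazily
    at the first call rather than here, so a production version would
    also guard the warmup forward pass.
    """
    if not enabled:
        return fn
    try:
        return torch.compile(fn, mode="reduce-overhead")
    except Exception as e:
        logger.warning("torch.compile failed (%s); falling back to eager", e)
        return fn
```

Since torch.compile itself returns quickly (compilation is deferred to the first call), the try/except here mainly catches setup errors; the warmup step is where most real failures would be caught.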
Key Considerations
- Not all models benefit equally: Models using custom CUDA kernels (e.g., CorrDiff with diffusion sampling) may not see gains. The feature should be opt-in.
- First step warmup: torch.compile incurs a one-time compilation cost on the first forward pass. For short rollouts (nsteps < 5), this overhead may negate benefits.
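A back-of-envelope way to reason about that break-even point (the numbers below are illustrative assumptions, not measurements):

```python
def breakeven_steps(compile_cost_s: float, eager_step_s: float,
                    speedup: float) -> float:
    """Steps needed before the one-time compilation cost pays off.

    compile_cost_s: one-time torch.compile warmup cost (seconds)
    eager_step_s:   per-step eager latency (seconds)
    speedup:        compiled speedup factor (e.g. 2.0 for 2x)
    """
    # Time saved on each compiled step relative to eager
    saved_per_step = eager_step_s * (1.0 - 1.0 / speedup)
    return compile_cost_s / saved_per_step

# Hypothetical example: 5 s compile, 2 s/step eager, 2x speedup
# -> saves 1 s/step, so it pays off after 5 steps
print(breakeven_steps(5.0, 2.0, 2.0))  # 5.0
```

In other words, whether nsteps < 5 negates the benefit depends on the ratio of compile cost to per-step savings for the specific model.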
- Memory: CUDA Graphs reserve static memory pools, which raises peak usage. Models already near GPU memory limits may need to run with the default compile mode (no CUDA Graphs) instead of reduce-overhead.
- Compatibility: Requires PyTorch ≥ 2.0. Earth2Studio already requires torch as a core dependency.
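One simple way to gate the feature (a sketch; COMPILE_AVAILABLE is a hypothetical name, and a feature check avoids parsing version strings):

```python
import torch

# torch.compile was introduced in PyTorch 2.0; checking for the
# attribute avoids parsing torch.__version__.
COMPILE_AVAILABLE = hasattr(torch, "compile")
```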