🚀[FEA]: Add torch.compile support for autoregressive inference rollouts in run.py #721

@abhaygoudannavar

Description

Is this a new feature, an improvement, or a change to existing functionality?

Improvement

How would you describe the priority of this feature request

Medium

Please provide a clear description of the problem you would like to solve.

Problem

The inference workflows in earth2studio/run.py (deterministic, diagnostic, ensemble) execute autoregressive rollouts using plain Python iteration:

# run.py, lines 140-149 (deterministic workflow)
for step, (x, coords) in enumerate(model):
    x, coords = map_coords(x, coords, output_coords)
    io.write(*split_coords(x, coords))
    pbar.update(1)
    if step == nsteps:
        break

Each rollout step calls the same model forward pass with an identical computation graph, but eager-mode PyTorch re-dispatches every operation from Python on each iteration. For models like FCN3, SFNO, Pangu, and DLWP, which repeat the exact same operations at every autoregressive step, this creates unnecessary overhead from:

  • Python overhead per iteration (interpreter, dynamic dispatch)
  • Redundant CUDA kernel launches without graph capture
  • No cross-step kernel fusion opportunities

Proposed Solution

Add optional torch.compile support to the prognostic model forward pass within the inference loop. torch.compile(mode="reduce-overhead") uses CUDA Graphs under the hood and can yield 1.5–3x speedups on autoregressive rollouts by:

  • Capturing the computation graph once and replaying it
  • Eliminating Python overhead per step
  • Enabling cross-kernel optimizations via Triton
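
As a minimal sketch of the idea (the TinyStep module and shapes are illustrative stand-ins, not Earth2Studio code), the forward pass is compiled once and then replayed across rollout steps; the warmup call is guarded so environments where compilation fails still run in eager mode:

```python
import torch


class TinyStep(torch.nn.Module):
    """Toy stand-in for a prognostic forward pass with a static graph."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.linear = torch.nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 0.1 * torch.tanh(self.linear(x))


model = TinyStep()
step = torch.compile(model, mode="reduce-overhead")

x = torch.randn(1, 4)
try:
    # Warmup: the first call triggers the actual compilation.
    x = step(x)
except Exception:
    # Compilation can fail (e.g. no working compiler backend); fall back to eager.
    step = model
    x = step(x)

# Autoregressive rollout: each output feeds the next step's input,
# so the captured graph is simply replayed.
for _ in range(3):
    x = step(x)
print(tuple(x.shape))  # (1, 4)
```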

Implementation Approach

Add a compile parameter to deterministic(), diagnostic(), and ensemble() workflows in run.py:

def deterministic(
    time, nsteps, prognostic, data, io,
    output_coords=OrderedDict({}),
    device=None,
    verbose=True,
    compile: bool = False,  # NEW
) -> IOBackend:

  • Wrap the prognostic model's forward method with torch.compile when enabled:

if compile:
    # Compile the model's forward method for static graph replay
    prognostic._forward = torch.compile(
        prognostic._forward, mode="reduce-overhead"
    )
  • Handle dynamic shapes: The coordinate system and batch dimensions are static during rollout, so torch.compile should work without dynamic=True. A warmup step handles the initial compilation.
  • Graceful fallback: If compilation fails (e.g., unsupported ops in some models), catch the error and fall back to eager mode with a warning.
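
The opt-in flag and graceful fallback could be combined into one small helper; `maybe_compile` is a hypothetical name, and the warmup call sits inside the try block because torch.compile defers compilation (and therefore most failures) to the first invocation:

```python
import warnings

import torch


def maybe_compile(forward_fn, warmup_args: tuple, enabled: bool = False):
    """Return a compiled forward pass, or the original on opt-out/failure.

    Hypothetical helper sketching the proposal; not an Earth2Studio API.
    """
    if not enabled:
        return forward_fn
    compiled = torch.compile(forward_fn, mode="reduce-overhead")
    try:
        compiled(*warmup_args)  # compilation actually happens on this first call
        return compiled
    except Exception as err:  # unsupported ops, missing compiler backend, etc.
        warnings.warn(f"torch.compile failed, falling back to eager: {err}")
        return forward_fn


# Usage with a toy forward pass standing in for prognostic._forward:
model = torch.nn.Linear(8, 8)
step = maybe_compile(model, (torch.randn(2, 8),), enabled=True)
y = step(torch.randn(2, 8))
print(tuple(y.shape))  # (2, 8)
```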

Key Considerations

  • Not all models benefit equally: Models using custom CUDA kernels (e.g., CorrDiff with diffusion sampling) may not see gains. The feature should be opt-in.
  • First step warmup: torch.compile incurs a one-time compilation cost on the first forward pass. For short rollouts (nsteps < 5), this overhead may negate benefits.
  • Memory: CUDA Graphs require static memory pools, so models near the GPU memory limit may need the "reduce-overhead" mode (and its graph capture) disabled in favor of the default compile mode.
  • Compatibility: Requires PyTorch ≥ 2.0. Earth2Studio already requires torch as a core dependency.
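
These constraints suggest a small gate before wrapping the model; the sketch below is illustrative only, and the `min_steps` threshold is an assumption rather than a proposed default:

```python
import torch


def should_compile(nsteps: int, requested: bool, min_steps: int = 5) -> bool:
    """Decide whether compiling the rollout is likely to pay off.

    Illustrative heuristic: the threshold is an assumption, not policy.
    """
    if not requested:
        return False  # the feature is strictly opt-in
    # torch.compile requires PyTorch >= 2.0
    if int(torch.__version__.split(".")[0]) < 2:
        return False
    # For short rollouts, the one-time compilation cost can outweigh the gains
    return nsteps >= min_steps


print(should_compile(nsteps=20, requested=True))
```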

Metadata

Assignees

No one assigned

    Labels

    ? - Needs Triage (need team to review and classify)
    enhancement (new feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests