Templated-alias substitution duplicates body across rate cells (no CSE) -- ~80x perf hit #111

@jc-macdonald

Summary

When an alias foo[ax1, ..., axN] is substituted into a transitions: rate
that has additional template axes beyond ax1..axN, the alias body is
re-emitted in full into every consuming cell of the rate. XLA does not
collapse the duplicates, so per-step compute scales with the product of
the extra template axes.

For the COVID19_USA SMH R19 base discrete config (axes age=4, vax=2, loc=51,
imm=11) this produces an ~80x per-draw runtime regression after wiring in
the contact-kernel FoI:

stage                          now        before
warmup (compile)               ~82 s      seconds
per-draw engine.run dispatch   ~4000 ms   ~50 ms
block_until_ready              ~1 ms      similar

JAX_LOG_COMPILES=1 confirms compiles happen only on draw 0 — so it is the
emitted XLA program itself that is doing roughly 80x more arithmetic per
timestep, not a recompile-per-draw bug.

Reproducer

YAML shape (simplified):

aliases:
  foi[age]: "r0[time] * gamma * chi_omicron *
             apply_along(age=ap,
               contact_kernel[age, age=ap]
               * apply_along(vax=v, loc=l, I[age=ap, vax=v, loc=l])
               / apply_along(vax=v, loc=l, N[age=ap, vax=v, loc=l]))"
transitions:
  - {from: "X[age, vax, loc, imm]", to: "E[age, vax, loc]",
     rate: "foi[age] * theta[imm]"}

The X→E transition exists at every (age, vax, loc, imm) cell:
4 * 2 * 51 * 11 = 4488 cells. After PR #110 each cell substitutes the
full foi[age] body with its row's age bound; the inner contraction
(constant for a given age) is re-emitted 4488 times rather than computed
once per age and re-used.
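
The cell count above can be checked with a few lines of Python. This is only an arithmetic sketch using the axis sizes quoted in this issue; the variable names are illustrative and are not taken from the codebase.

```python
from math import prod

# Axis sizes quoted in the issue for the SMH R19 base discrete config.
axes = {"age": 4, "vax": 2, "loc": 51, "imm": 11}
alias_axes = ("age",)  # LHS template axes of foi[age]

cells = prod(axes.values())               # consuming cells of the X->E rate
rows = prod(axes[a] for a in alias_axes)  # distinct alias rows
copies_per_row = cells // rows            # product of the extra axes

print(cells, rows, copies_per_row)  # 4488 4 1122
```

Each of the 4 distinct foi[age] bodies is therefore emitted 1122 times (2 * 51 * 11), which is the product of the extra template axes.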

For hierarchical configs with foi[age, loc], the multiplier is 51x
larger (the loc count is part of the alias LHS template).

Expected behavior

When an alias is substituted into a rate, the alias output should be
hoisted into a per-(LHS-row) temporary that is computed once and indexed
into each consuming cell of the rate -- equivalent to manual CSE on the
alias output.

In the example above we expect 4 evaluations of the foi[age] body
(one per age coord), not 4488.
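
The difference between the two behaviors can be illustrated in plain Python by counting how often a stand-in for the alias body runs. This is a toy model, not the engine's code path: foi_body and theta are hypothetical placeholders, and only the axis sizes come from this issue.

```python
from itertools import product

# Axis sizes quoted in the issue; everything else is a hypothetical stand-in.
AXES = {"age": 4, "vax": 2, "loc": 51, "imm": 11}
CELLS = list(product(*(range(n) for n in AXES.values())))

counter = {"body": 0}  # number of times the alias body is evaluated


def foi_body(age):
    """Placeholder for the foi[age] contraction; depends only on age."""
    counter["body"] += 1
    return 0.1 * (age + 1)


def theta(imm):
    """Placeholder for theta[imm]."""
    return 1.0 / (imm + 1)


def rates_inlined():
    # Current behavior: the body is evaluated again in every consuming cell.
    return {c: foi_body(c[0]) * theta(c[3]) for c in CELLS}


def rates_hoisted():
    # Expected behavior: one evaluation per alias row, then plain indexing.
    foi = [foi_body(age) for age in range(AXES["age"])]
    return {c: foi[c[0]] * theta(c[3]) for c in CELLS}


counter["body"] = 0
inlined = rates_inlined()
inlined_evals = counter["body"]

counter["body"] = 0
hoisted = rates_hoisted()
hoisted_evals = counter["body"]

print(inlined_evals, hoisted_evals, inlined == hoisted)  # 4488 4 True
```

Both variants produce identical rates; only the number of body evaluations changes, which is the property the hoisting is meant to guarantee.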

Suggested fix sketch

In _expand_alias_templates (or wherever per-row substitutions happen
into transitions: rates):

  1. Detect that an alias template foo[ax1..axN] is consumed by a rate
    whose template axes include ax1..axN plus extras.
  2. Emit the per-row alias bodies as anonymous intermediates: either
    one scalar _foo__row<i> per row, or a single shaped intermediate
    _foo[ax1..axN].
  3. Replace the inlined body in each consuming cell with a reference to
    the intermediate, so XLA sees a single computation per row shared
    across all extra-axis cells.

The vectorize pass already supports shaped intermediates; the change is
in how the template-substitution emits them.
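
The three steps above can be sketched on a string-based toy model. This is an assumption-heavy illustration, not the real _expand_alias_templates: the expression strings, the BODY and RATE templates, and both emit_* helpers are hypothetical, and only the _foo__row<i> naming follows the sketch in this issue.

```python
from itertools import product

ALIAS = "foi"
# Toy alias body with one row axis (age); not the real foi expression.
BODY = "r0 * sum(K[{age}, ap] * I[ap] / N[ap])"
# Toy rate template with one extra axis (imm) beyond the alias axes.
RATE = "foi[{age}] * theta[{imm}]"


def emit_inlined(cells):
    """Current behavior: paste the full body into every consuming cell."""
    return [
        RATE.format(age=a, imm=i).replace(
            f"{ALIAS}[{a}]", "(" + BODY.format(age=a) + ")"
        )
        for a, i in cells
    ]


def emit_hoisted(cells):
    """Proposed behavior: one named intermediate per row, referenced by cells."""
    intermediates = {}
    rates = []
    for a, i in cells:
        name = f"_{ALIAS}__row{a}"
        # Step 2: emit the row body once, the first time the row is seen.
        intermediates.setdefault(name, BODY.format(age=a))
        # Step 3: the cell refers to the intermediate instead of the body.
        rates.append(RATE.format(age=a, imm=i).replace(f"{ALIAS}[{a}]", name))
    return intermediates, rates


cells = list(product(range(2), range(3)))  # 2 ages x 3 imm levels
inlined = emit_inlined(cells)
intermediates, hoisted = emit_hoisted(cells)

print(sum(r.count("sum(") for r in inlined))  # 6
print(len(intermediates), hoisted[0])         # 2 _foi__row0 * theta[0]
```

Counting occurrences of the body in the emitted expressions, as done in the last lines, is also one possible shape for the regression test: the body should appear once per row in the intermediates and not at all in the consuming cells.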

Acceptance criteria

  • A minimal regression test that asserts the inner alias body appears
    exactly once per row (not once per consuming cell) in the normalized
    RHS / vectorized graph.
  • COVID19_USA SMH R19 base discrete prior-predictive draw drops back to
    ~50 ms per draw on CPU (currently ~4000 ms).
  • All existing tests still pass.

Context

Found while wiring the POLYMOD-fitted contact_kernel[age, age]
posterior into all 14 SMH R19 configs after PR #108 / PR #110 made the
same-axis-twice contraction expressible. Correctness is fine; this is
purely a performance / codegen-duplication issue.
