Summary
When an alias foo[ax1, ..., axN] is substituted into a transitions: rate
that has additional template axes beyond ax1..axN, the alias body is
re-emitted in full into every consuming cell of the rate. XLA does not
collapse the duplicates, so per-step compute scales with the product of
the extra template axes.
For the COVID19_USA SMH R19 base discrete config (axes age=4, vax=2, loc=51,
imm=11) this produces an ~80x per-draw runtime regression after wiring in
the contact-kernel FoI:
| stage | now | before |
|---|---|---|
| warmup (compile) | ~82 s | seconds |
| per-draw engine.run dispatch | ~4000 ms | ~50 ms |
| block_until_ready | ~1 ms | similar |
JAX_LOG_COMPILES=1 confirms that compilation happens only on draw 0, so the
emitted XLA program itself is doing roughly 80x more arithmetic per
timestep; this is not a recompile-per-draw bug.
Reproducer
YAML shape (simplified):
aliases:
  foi[age]: "r0[time] * gamma * chi_omicron *
    apply_along(age=ap,
      contact_kernel[age, age=ap]
      * apply_along(vax=v, loc=l, I[age=ap, vax=v, loc=l])
      / apply_along(vax=v, loc=l, N[age=ap, vax=v, loc=l]))"
transitions:
  - {from: "X[age, vax, loc, imm]", to: "E[age, vax, loc]",
     rate: "foi[age] * theta[imm]"}
The X→E transition exists at every (age, vax, loc, imm) cell:
4 * 2 * 51 * 11 = 4488 cells. After PR #110 each cell substitutes the
full foi[age] body with its row's age bound; the inner contraction
(constant for a given age) is re-emitted 4488 times rather than computed
once per age and re-used.
For hierarchical configs with foi[age, loc]: the multiplier is 51x
larger (loc count is part of the alias LHS template).
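The cell count above can be checked with a few lines of Python. This is only arithmetic on the axis sizes quoted in this issue; the helper name `duplication` is made up for illustration and is not part of the codebase.

```python
from math import prod

# Axis sizes quoted above for the COVID19_USA SMH R19 base discrete config.
AXES = {"age": 4, "vax": 2, "loc": 51, "imm": 11}


def duplication(alias_axes, rate_axes, sizes=AXES):
    """Return (consuming cells, distinct alias rows, body copies per row)."""
    cells = prod(sizes[a] for a in rate_axes)
    rows = prod(sizes[a] for a in alias_axes)
    return cells, rows, cells // rows


rate_axes = ["age", "vax", "loc", "imm"]
print(duplication(["age"], rate_axes))  # (4488, 4, 1122)
```

For foi[age] substituted into the X→E rate, each of the 4 distinct bodies is emitted 1122 times (the product of the vax, loc and imm sizes).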
Expected behavior
When an alias is substituted into a rate, the alias output should be
hoisted into a per-(LHS-row) temporary that is computed once and indexed
into each consuming cell of the rate -- equivalent to manual CSE on the
alias output.
In the example above we expect 4 evaluations of the foi[age] body
(one per age coord), not 4488.
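A toy model of the expected behavior, in plain Python rather than the real engine: `foi_body` and `theta` are placeholder functions (assumptions, not the actual alias body or parameter), and a counter records how often the body runs when it is inlined per cell versus hoisted into a per-age temporary.

```python
from itertools import product

SIZES = {"age": 4, "vax": 2, "loc": 51, "imm": 11}
calls = {"n": 0}


def foi_body(age):
    """Placeholder for the foi[age] contraction; counts its evaluations."""
    calls["n"] += 1
    return 0.1 * (age + 1)


def theta(imm):
    """Placeholder for theta[imm]."""
    return 1.0 / (imm + 1)


cells = list(product(*(range(SIZES[a]) for a in ("age", "vax", "loc", "imm"))))

# Inlined: the body is evaluated again in every consuming cell.
calls["n"] = 0
inlined = {c: foi_body(c[0]) * theta(c[3]) for c in cells}
inlined_calls = calls["n"]

# Hoisted: one temporary per age row, indexed from every consuming cell.
calls["n"] = 0
foi_tmp = [foi_body(age) for age in range(SIZES["age"])]
hoisted = {c: foi_tmp[c[0]] * theta(c[3]) for c in cells}
hoisted_calls = calls["n"]

print(inlined_calls, hoisted_calls)  # 4488 4
assert inlined == hoisted  # same rates, far fewer body evaluations
```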
Suggested fix sketch
In _expand_alias_templates (or wherever per-row substitutions happen
into transitions: rates):
- Detect that an alias template foo[ax1..axN] is consumed by a rate
  whose template axes include ax1..axN plus extras.
- Emit the per-row alias bodies as anonymous intermediates
  _foo__row<i> (shape: scalar; or a single shaped intermediate
  _foo[ax1..axN]).
- Replace the inlined body in each consuming cell with a reference to
  the intermediate, so XLA sees a single computation per row shared
  across all extra-axis cells.
The vectorize pass already supports shaped intermediates; the change is
in how the template-substitution emits them.
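The three steps above could look roughly like the following toy pass over string expressions. Everything here is hypothetical: `hoist_alias`, its arguments, and the `contract(...)` placeholder body are illustrations only and do not reflect the real `_expand_alias_templates` signature or IR. Only the `_foo__row<i>` naming is taken from the sketch above.

```python
from itertools import product


def hoist_alias(name, alias_axes, body_for_row, rate_axes, rate_rest, sizes):
    """Toy hoisting pass (hypothetical, not project code).

    Returns (intermediates, cells): one named intermediate per alias row,
    and one rate expression per consuming cell that references it by name.
    """
    # Step 1: the alias axes must be a subset of the rate's template axes.
    if not set(alias_axes) <= set(rate_axes):
        raise ValueError("alias axes are not all present in the rate template")

    # Step 2: emit each per-row alias body exactly once.
    intermediates = {}
    for row in product(*(range(sizes[a]) for a in alias_axes)):
        tmp = f"_{name}__row" + "_".join(map(str, row))
        intermediates[tmp] = body_for_row(dict(zip(alias_axes, row)))

    # Step 3: every consuming cell references the intermediate by name.
    cells = {}
    for cell in product(*(range(sizes[a]) for a in rate_axes)):
        coords = dict(zip(rate_axes, cell))
        tmp = f"_{name}__row" + "_".join(str(coords[a]) for a in alias_axes)
        cells[cell] = f"{tmp} * {rate_rest(coords)}"
    return intermediates, cells


sizes = {"age": 4, "vax": 2, "loc": 51, "imm": 11}
intermediates, cells = hoist_alias(
    "foi",
    ["age"],
    lambda r: f"contract(contact_kernel[{r['age']}, :], I, N)",  # placeholder
    ["age", "vax", "loc", "imm"],
    lambda c: f"theta[{c['imm']}]",
    sizes,
)
print(len(intermediates), len(cells))  # 4 4488
print(cells[(2, 0, 7, 3)])  # _foi__row2 * theta[3]
```

In this toy the per-cell expressions stay small and the expensive body text exists only in the 4 intermediates; whether scalar per-row intermediates or a single shaped `_foo[ax1..axN]` is preferable depends on the vectorize pass.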
Acceptance criteria
- A minimal regression test that asserts the inner alias body appears
exactly once per row (not once per consuming cell) in the normalized
RHS / vectorized graph.
- COVID19_USA SMH R19 base discrete prior-predictive draw drops back to
~50 ms per draw on CPU (currently ~4000 ms).
- All existing tests still pass.
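One possible shape for the first criterion, assuming the normalized RHS can be dumped as a mapping from names to expression strings. That representation, the `BODY(` marker, and both helper names are assumptions for this sketch, not the project's test API.

```python
def count_occurrences(program, marker):
    """Total occurrences of `marker` across all emitted expressions."""
    return sum(expr.count(marker) for expr in program.values())


def assert_body_once_per_row(program, marker, n_rows):
    found = count_occurrences(program, marker)
    assert found == n_rows, f"alias body emitted {found} times, expected {n_rows}"


# Toy emitted programs for 4 age rows and 2 * 51 * 11 extra-axis cells per row.
n_age, n_extra = 4, 2 * 51 * 11
hoisted = {f"_foi__row{a}": f"BODY({a})" for a in range(n_age)}
hoisted.update(
    {f"rate_{a}_{k}": f"_foi__row{a} * theta"
     for a in range(n_age) for k in range(n_extra)}
)
inlined = {
    f"rate_{a}_{k}": f"BODY({a}) * theta"
    for a in range(n_age) for k in range(n_extra)
}

assert_body_once_per_row(hoisted, "BODY(", n_age)  # passes
try:
    assert_body_once_per_row(inlined, "BODY(", n_age)
except AssertionError as err:
    print(err)  # alias body emitted 4488 times, expected 4
```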
Context
Found while wiring the POLYMOD-fitted contact_kernel[age, age]
posterior into all 14 SMH R19 configs after PR #108 / PR #110 made the
same-axis-twice contraction expressible. Correctness is fine; this is
purely a performance / codegen-duplication issue.