Fix bug in gradient of Blockwise'd Scan #1482
Conversation
```python
with config.change_flags(compute_test_value="off"):
    safe_inputs = [
        tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
```
This line was the problematic one: `shape=(None,) * len(sig)`
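A simplified sketch (hypothetical helper names, not PyTensor's actual code) of why that line mattered: building every dummy core dimension as `None` discards dimensions the core signature pins to a literal size such as 1, which is exactly the broadcastable information the scan gradient needed.

```python
# Hypothetical sketch, not PyTensor's implementation: how a dummy shape
# derived from a core signature can either keep or discard pinned dims.

def dummy_shape_buggy(core_sig):
    # Old behaviour: every core dim becomes unknown (None).
    return (None,) * len(core_sig)

def dummy_shape_fixed(core_sig):
    # Fixed behaviour: keep literal sizes (e.g. 1) from the signature,
    # leave symbolic (named) dims unknown.
    return tuple(int(d) if str(d).isdigit() else None for d in core_sig)

sig = ("1", "n")  # first core dim pinned to size 1, second symbolic
print(dummy_shape_buggy(sig))  # (None, None): broadcastable info lost
print(dummy_shape_fixed(sig))  # (1, None): the shape=(1,) dim survives
```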
```python
# FIXME: These core_outputs do not depend on core_inputs, not pretty.
# It's not necessarily a problem because if they are referenced by the gradient,
# they get replaced later in vectorize. But if the Op were to make any decision
# by introspecting the dependencies of output on inputs it would fail badly!
```
This is also fixed
Pull Request Overview

This PR fixes a gradient bug in `Blockwise` when batching scans by refactoring the `L_op` implementation, and adds targeted tests for core-type gradients and scan gradients.

- Refactored `Blockwise.L_op` to simplify and correct core gradient extraction and batching logic.
- Renamed `test_op` to `my_test_op` in existing tests.
- Added two new tests: `test_blockwise_grad_core_type` and `test_scan_gradient_core_type`.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/tensor/test_blockwise.py | Renamed `test_op`, imported `scan`, and added two new gradient tests. |
| pytensor/tensor/blockwise.py | Completely refactored the `L_op` method to remove the old helper and improve batching of core gradients. |
Comments suppressed due to low confidence (1)
pytensor/tensor/blockwise.py:353
- The new `core_inputs` comprehension no longer preserves `NullType` or `DisconnectedType` inputs as the old `_bgrad` helper did via `as_core`. If an input is a null/disconnected gradient, it should be passed through unchanged, otherwise downstream gradient logic may break.

  `core_inputs = [`
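A minimal sketch of the pass-through behaviour this comment asks for. The class names here are illustrative stand-ins for PyTensor's `NullType` and `DisconnectedType` gradient variables, not the library's actual API.

```python
# Hypothetical sketch of the reviewer's point: sentinel gradient values
# must be forwarded unchanged; only real gradients get converted to
# core (un-batched) variables. Class and function names are illustrative.

class NullGrad:
    """Stand-in for a NullType gradient (gradient is undefined)."""

class DisconnectedGrad:
    """Stand-in for a DisconnectedType gradient (output doesn't depend on input)."""

def as_core(grad, make_core):
    # Pass sentinel gradients through untouched...
    if isinstance(grad, (NullGrad, DisconnectedGrad)):
        return grad
    # ...and only convert real gradients to core variables.
    return make_core(grad)

null = NullGrad()
print(as_core(null, make_core=lambda g: ("core", g)) is null)   # True
print(as_core(3.0, make_core=lambda g: ("core", g)))            # ('core', 3.0)
```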
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            main    #1482     +/-  ##
=========================================
- Coverage   82.01%   82.00%   -0.01%
=========================================
  Files         214      214
  Lines       50426    50414      -12
  Branches     8903     8902       -1
=========================================
- Hits        41355    41343      -12
  Misses       6863     6863
  Partials     2208     2208
```
Thanks a lot @ricardoV94 !! Definitely unblocks us over on pymc-extras 🤩
I'd say it was much more than a one-liner though 😉
The fix was a one-liner, I just cleaned up stuff besides it. Check the PR commit by commit and you'll see it ;)
@ricardoV94, using this over on

```
[(d__logp/dP0_diag_log__),
 (d__logp/dinitial_trend),
 (d__logp/dar_params_logodds__),
 (d__logp/dsigma_trend_log__),
 (d__logp/dsigma_ar_log__),
 (d__logp/dsigma_obs_log__)]
```

But now... The following example will trigger the error, just using a small dataset of 15 data points and a batch size of 5 (one per president) (agg.csv). I'm running pytensor main (…):

```python
import numpy as np
import pandas as pd
import pymc as pm
import pymc_extras.statespace as pmss
import pytensor
import pytensor.tensor as pt
import xarray as xr

presidents = agg.president.unique()

mod = pmss.structural.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += pmss.structural.AutoregressiveComponent(order=1)
mod += pmss.structural.MeasurementError(name="obs")

ss_mod = mod.build(
    name="president",
    batch_coords={"president": presidents},  # this is gonna be leftmost dimension
)

ss_array = (
    agg.set_index(["president", "month_id"])["approve_pr"].unstack("month_id").to_numpy()[..., None]
)  # dims=(president, timesteps, obs_dim)

initial_trend_dims, sigma_trend_dims, ar_param_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords | ss_mod.batch_coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5, dims="president")
    P0 = pm.Deterministic(
        "P0", pt.eye(ss_mod.k_states)[None] * P0_diag[..., None, None], dims=("president", *P0_dims)
    )
    initial_trend = pm.Normal("initial_trend", dims=("president", *initial_trend_dims))
    ar_params = pm.Beta("ar_params", alpha=3, beta=3, dims=("president", *ar_param_dims))
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=("president", *sigma_trend_dims))
    sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, dims="president")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, dims="president")

    ss_mod.build_statespace_graph(ss_array)
    idata = pm.sample()  # nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"})
```

This (or the …) crashes with:

```
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
```
This PR fixes a bug found in the work to add batch dimensions to the Statespace module in pymc-extras. The gradient of a Blockwise'd scan with a specific broadcastable signature in the inner graph (i.e., `shape=(1,)`) was failing because `Blockwise` was creating a dummy node with variables that didn't respect the core shapes (i.e., `shape=(None,)`).

The fix is a one-liner (second commit).
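The role the core shapes play can be illustrated with NumPy's gufunc-style signatures. This is an analogy, not the PR's test: `np.vectorize` with a `signature` is the NumPy counterpart of `Blockwise`, and it checks that inputs actually match the core shapes the signature declares.

```python
import numpy as np

# Analogous illustration (not the PR's test): a gufunc-style signature
# declares the core shapes, and batched inputs must respect them.
core = np.vectorize(lambda a, b: (a * b).sum(), signature="(n),(n)->()")

# Batch dim 5, consistent core dim n=3: fine, one scalar per batch entry.
print(core(np.ones((5, 3)), np.ones((5, 3))).shape)  # (5,)

# Inconsistent core dims violate the signature and fail loudly here;
# Blockwise's shape=(None,) dummy inputs instead broke the gradient silently.
try:
    core(np.ones((5, 3)), np.ones((5, 4)))
except ValueError as err:
    print("signature enforced:", err)
```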
The last commit refactors the `L_op` implementation, since the helper function isn't used anywhere else and some parts that were copied from `Elemwise` don't make sense (such as worrying that the core op's `L_op` might return `None`).

📚 Documentation preview 📚: https://pytensor--1482.org.readthedocs.build/en/1482/