
Fix bug in gradient of Blockwise'd Scan #1482


Merged: 3 commits merged into pymc-devs:main from fix_blockwise_scan_bug on Jun 16, 2025

Conversation

@ricardoV94 (Member) commented on Jun 16, 2025

This PR fixes a bug found while adding batch dimensions to the Statespace module in pymc-extras. The gradient of a Blockwise'd Scan with a broadcastable core dimension in the inner graph (i.e., shape=(1,)) was failing, because Blockwise was creating a dummy node with variables that didn't respect the core shapes (it used shape=(None,) instead).

The fix is a one-liner (second commit).

The last commit refactors the L_op implementation, since the helper function isn't used anywhere else and some parts that were copied from Elemwise don't make sense for Blockwise (such as worrying that the core op's L_op might return None).
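
A minimal sketch of the failing pattern described above (an illustration based on the description, not the PR's actual test; the variable names and the inner function are assumptions):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Inner graph: a Scan over a sequence whose core element has a broadcastable
# shape of (1,)
n_steps = 3
seq = pt.tensor("seq", shape=(n_steps, 1))
out, _ = pytensor.scan(lambda s: 2 * s, sequences=[seq], n_steps=n_steps)

# Adding a leading batch dimension wraps the Scan in a Blockwise
vec_seq = pt.tensor("vec_seq", shape=(None, n_steps, 1))
vec_out = vectorize_graph(out, replace={seq: vec_seq})

# Before the fix, taking this gradient failed because the dummy core inputs
# were built with shape=(None,) instead of the true core shape (1,)
g = pt.grad(vec_out.sum(), vec_seq)
print(g.eval({vec_seq: np.ones((4, n_steps, 1), dtype=vec_seq.dtype)}))
```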


📚 Documentation preview 📚: https://pytensor--1482.org.readthedocs.build/en/1482/

```python
with config.change_flags(compute_test_value="off"):
    safe_inputs = [
        tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
```
@ricardoV94 (Member Author) commented:
This line was the problematic one: `shape=(None,) * len(sig)`
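
A hedged illustration of why that line is problematic (not the PR's code; the variable names here are made up): building the dummy core input with all-None dims discards the broadcastable (1,) information that the inner graph's gradient relies on.

```python
import pytensor.tensor as pt

core_inp = pt.tensor("core_inp", shape=(1,))  # true core type: broadcastable dim

# What the quoted line built: all core dims unknown
dummy_old = pt.tensor("dummy", shape=(None,) * core_inp.ndim)
# What the fix amounts to conceptually: respect the known core shape
dummy_new = pt.tensor("dummy", shape=core_inp.type.shape)

assert dummy_old.type != core_inp.type  # static (1,) shape information was lost
assert dummy_new.type == core_inp.type
```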

@ricardoV94 requested review from jessegrabowski and removed the request on June 16, 2025 09:28
Comment on lines 371 to 374 (removed):

```python
# FIXME: These core_outputs do not depend on core_inputs, not pretty
# It's not neccessarily a problem because if they are referenced by the gradient,
# they get replaced later in vectorize. But if the Op was to make any decision
# by introspecting the dependencies of output on inputs it would fail badly!
```
@ricardoV94 (Member Author) commented:

This is also fixed

@ricardoV94 force-pushed the fix_blockwise_scan_bug branch from 7f6d58f to 61c0bf6 on June 16, 2025 09:31
@ricardoV94 force-pushed the fix_blockwise_scan_bug branch from 61c0bf6 to 1401b84 on June 16, 2025 10:27
@Copilot (Copilot AI) left a comment:

Pull Request Overview

This PR fixes a gradient bug in Blockwise when batching scans by refactoring the L_op implementation and adds targeted tests for core-type gradients and scan gradients.

  • Refactored Blockwise.L_op to simplify and correct core gradient extraction and batching logic.
  • Renamed test_op to my_test_op in existing tests and added two new tests:
    • test_blockwise_grad_core_type
    • test_scan_gradient_core_type

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Files reviewed:
  • tests/tensor/test_blockwise.py: Renamed test_op, imported scan, and added two new gradient tests.
  • pytensor/tensor/blockwise.py: Completely refactored the L_op method to remove the old helper and improve batching of core gradients.
Comments suppressed due to low confidence (1)

pytensor/tensor/blockwise.py:353

  • The new core_inputs comprehension no longer preserves NullType or DisconnectedType inputs as the old _bgrad helper did via as_core. If an input is a null/disconnected gradient, it should be passed through unchanged, otherwise downstream gradient logic may break.
core_inputs = [
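
A hedged sketch of the pass-through behavior Copilot is flagging at that core_inputs comprehension, modeled on the old as_core helper mentioned above (an illustration, not the code the PR ended up with):

```python
from pytensor.gradient import DisconnectedType, NullType

def as_core(t, core_t):
    # Null/disconnected gradients carry no tensor information, so they are
    # passed through unchanged instead of being mapped to a dummy core variable
    if isinstance(t.type, (NullType, DisconnectedType)):
        return t
    return core_t
```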


codecov bot commented Jun 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.00%. Comparing base (b218ffe) to head (1401b84).
Report is 1 commit behind head on main.


```
@@            Coverage Diff             @@
##             main    #1482      +/-   ##
==========================================
- Coverage   82.01%   82.00%   -0.01%
==========================================
  Files         214      214
  Lines       50426    50414      -12
  Branches     8903     8902       -1
==========================================
- Hits        41355    41343      -12
  Misses       6863     6863
  Partials     2208     2208
```
Files with missing lines | Coverage Δ
pytensor/tensor/blockwise.py | 89.27% <100.00%> (-0.53%) ⬇️

@AlexAndorra left a comment:

Thanks a lot @ricardoV94 !! Definitely unblocks us over on pymc-extras 🤩
I'd say it was much more than a one-liner though 😉

@AlexAndorra AlexAndorra merged commit 9f80bdc into pymc-devs:main Jun 16, 2025
73 checks passed
@ricardoV94 (Member Author) commented:

> Thanks a lot @ricardoV94 !! Definitely unblocks us over on pymc-extras 🤩 I'd say it was much more than a one-liner though 😉

The fix was a one-liner, I just cleaned up stuff around it. Check the PR commit by commit and you'll see it ;)

@AlexAndorra commented:

@ricardoV94 , using this over on statespace, and the error we had before did disappear 🍾
pt.grad(model_1.logp(), model_1.value_vars) is now running fine, giving back:

```
[(d__logp/dP0_diag_log__),
 (d__logp/dinitial_trend),
 (d__logp/dar_params_logodds__),
 (d__logp/dsigma_trend_log__),
 (d__logp/dsigma_ar_log__),
 (d__logp/dsigma_obs_log__)]
```

But now... pm.sample just hangs for a long time before giving up and crashing the kernel, which is weird for a fairly simple model with a batch size of only 5.
The problem already appears when calling model_1.logp_dlogp_function(ravel_inputs=True).
Could that be related to scan too? Because I don't think we're doing anything crazy over on statespace, at least to the best of my knowledge.

The following example will trigger the error, using just a small dataset of 15 data points and a batch size of 5 (one per president; the data is in the attached agg.csv). I'm running pytensor main (2.31.3+18.gd3bbc20aa) and PyMC 5.23.0 from conda locally.

import numpy as np
import pandas as pd
import pymc as pm
import pymc_extras.statespace as pmss
import pytensor
import pytensor.tensor as pt
import xarray as xr

# agg is the attached agg.csv dataset (approval data with president/month_id columns)
agg = pd.read_csv("agg.csv")
presidents = agg.president.unique()

mod = pmss.structural.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += pmss.structural.AutoregressiveComponent(order=1)
mod += pmss.structural.MeasurementError(name="obs")

ss_mod = mod.build(
    name="president",
    batch_coords={"president": presidents},  # this is gonna be leftmost dimension
)

ss_array = (
    agg.set_index(["president", "month_id"])["approve_pr"].unstack("month_id").to_numpy()[..., None]
)  # dims=(president, timesteps, obs_dim)

initial_trend_dims, sigma_trend_dims, ar_param_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords | ss_mod.batch_coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5, dims="president")
    P0 = pm.Deterministic(
        "P0", pt.eye(ss_mod.k_states)[None] * P0_diag[..., None, None], dims=("president", *P0_dims)
    )

    initial_trend = pm.Normal("initial_trend", dims=("president", *initial_trend_dims))
    ar_params = pm.Beta("ar_params", alpha=3, beta=3, dims=("president", *ar_param_dims))

    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=("president", *sigma_trend_dims))
    sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, dims="president")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, dims="president")

    ss_mod.build_statespace_graph(ss_array)
    idata = pm.sample()  # nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"})

This (or the logp_dlogp_function call above) will trigger a kernel crash:

```
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
```
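
One hedged way to narrow down whether the hang happens during graph rewriting/compilation rather than during sampling (a generic debugging sketch, not something from this thread; it assumes the model_1 context from the example above):

```python
import pytensor

# Build the joint log-probability and its gradient explicitly
logp = model_1.logp()
dlogp = pytensor.grad(logp, wrt=model_1.value_vars)

# If compiling this function also hangs, the problem is in rewriting/compilation,
# not in NUTS itself; FAST_COMPILE skips most rewrites, so it gives a comparison point
f = pytensor.function(model_1.value_vars, [logp, *dlogp], mode="FAST_COMPILE")
```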
