
Commit 9f80bdc

Fix bug in gradient of Blockwise'd Scan (#1482)
* Avoid pytest warning for variable name
* Respect core type shape in gradient of Blockwise
* Refactor Blockwise L_op
1 parent b218ffe commit 9f80bdc
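For context, the scenario this commit fixes, condensed from the new `test_scan_gradient_core_type` regression test added below: a Scan is vectorized over a new batch dimension (which, per the commit title, wraps it in a Blockwise), and taking the gradient through that Blockwise'd Scan previously failed because the core gradient was built with fully unknown core shapes. A minimal sketch, assuming a PyTensor build that includes this commit:

import numpy as np
from pytensor import scan
from pytensor.gradient import grad
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor import tensor

# Build a trivial Scan over a (3, 1) sequence, then vectorize it over a new
# leading batch dimension, as the regression test does.
n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan(lambda s: s, sequences=[seq], n_steps=n_steps)

vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
vec_out = vectorize_graph(out, replace={seq: vec_seq})

# Before this commit, Blockwise's gradient rebuilt the core inputs with fully
# unknown shapes, so the inner op's gradient saw the wrong core types.
g = grad(vec_out.sum(), vec_seq)
print(g.eval({vec_seq: np.ones((4, n_steps, 1))}))  # expected: all ones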

File tree

pytensor/tensor/blockwise.py
tests/tensor/test_blockwise.py

2 files changed: +99 −66 lines changed

pytensor/tensor/blockwise.py

Lines changed: 42 additions & 57 deletions
@@ -344,81 +344,66 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
-                tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
             core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                tensor(
+                    dtype=inp.type.dtype,
+                    shape=inp.type.shape[batch_ndim:],
+                )
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
-            core_outputs = core_node.outputs
-
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
+
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.

tests/tensor/test_blockwise.py

Lines changed: 57 additions & 9 deletions
@@ -6,11 +6,11 @@
 import scipy.linalg
 
 import pytensor
-from pytensor import In, config, function
+from pytensor import In, config, function, scan
 from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph, vectorize_node
 from pytensor.raise_op import assert_op
 from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
@@ -162,13 +162,13 @@ def perform(self, *args, **kwargs):
         raise NotImplementedError("Test Op should not be present in final graph")
 
 
-test_op = MyTestOp()
+my_test_op = MyTestOp()
 
 
 def test_vectorize_node_default_signature():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(5, None))
-    node = test_op.make_node(vec, mat)
+    node = my_test_op.make_node(vec, mat)
 
     vect_node = vectorize_node(node, mat, mat)
     assert isinstance(vect_node.op, Blockwise) and isinstance(
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
     with pytest.raises(
         ValueError, match="Signature not provided nor found in core_op MyTestOp"
     ):
-        Blockwise(test_op)
+        Blockwise(my_test_op)
 
-    vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
+    vect_node = Blockwise(my_test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
     assert vect_node.outputs[0].type.shape == (
         5,
         None,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
     inp_test = np.zeros((5, 4, 3), dtype=config.floatX)
 
     # Shape can be inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (n, m)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (n, m)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -210,7 +210,7 @@ def test_blockwise_shape():
     assert tuple(shape_fn(inp_test)) == (5, 3, 4)
 
     # Shape can only be partially inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (m, k)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -233,7 +233,7 @@ def test_blockwise_shape():
     inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX)
     inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX)
 
-    op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
     outs = op(inp1, inp2)
     assert outs[0].type.shape == (7, 5, None, None)
     assert outs[1].type.shape == (7, 5, None, None)
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
             np.ones(12, dtype=config.floatX),
             strict=True,
         )
+
+
+def test_blockwise_grad_core_type():
+    class StrictCoreTypeOp(Op):
+        def make_node(self, x):
+            assert x.type.shape[-1] == 2
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = inputs[0] + 1
+
+        def L_op(self, inputs, outputs, output_grads):
+            [x] = inputs
+            assert x.type.shape == (2,)
+            return [x.zeros_like()]
+
+    strict_core_type_op = StrictCoreTypeOp()
+    block_strict_core_type_op = Blockwise(strict_core_type_op, signature="(a)->(a)")
+
+    x = tensor("x", shape=(5, 2), dtype="float64")
+    y = block_strict_core_type_op(x)
+    assert y.type.shape == (5, 2)
+
+    grad_y = grad(y.sum(), x)
+    assert grad_y.type.shape == (5, 2)
+    np.testing.assert_allclose(
+        grad_y.eval({x: np.ones((5, 2))}),
+        np.zeros((5, 2)),
+    )
+
+
+def test_scan_gradient_core_type():
+    n_steps = 3
+    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
+    out, _ = scan(
+        lambda s: s,
+        sequences=[seq],
+        n_steps=n_steps,
+    )
+
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
+    vec_out = vectorize_graph(out, replace={seq: vec_seq})
+    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
+
+    np.testing.assert_allclose(
+        grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
+        np.ones((4, n_steps, 1)),
+    )
