
Commit 1401b84

Refactor Blockwise L_op
1 parent 61b2475 commit 1401b84


2 files changed: +42 -60 lines


pytensor/tensor/blockwise.py

Lines changed: 40 additions & 58 deletions
@@ -344,84 +344,66 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
+            core_inputs = [
                 tensor(
                     dtype=inp.type.dtype,
-                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                    shape=inp.type.shape[batch_ndim:],
                 )
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
-            core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
-            core_outputs = core_node.outputs
-
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
+
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.

tests/tensor/test_blockwise.py

Lines changed: 2 additions & 2 deletions
@@ -683,14 +683,14 @@ def L_op(self, inputs, outputs, output_grads):
 
 def test_scan_gradient_core_type():
     n_steps = 3
-    seq = tensor("seq", shape=(n_steps, 1))
+    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
     out, _ = scan(
         lambda s: s,
         sequences=[seq],
         n_steps=n_steps,
     )
 
-    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1))
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
     vec_out = vectorize_graph(out, replace={seq: vec_seq})
     grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
 
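As a usage note, a hypothetical continuation of test_scan_gradient_core_type (not part of this commit; the batch size of 2 is arbitrary) would compile and evaluate the vectorized gradient, reusing vec_seq and grad_sit_sot0 from the test above:

    import numpy as np
    from pytensor import function

    # Compile the batched gradient and evaluate it for a batch of two sequences;
    # the gradient has the same shape as vec_seq, here (2, 3, 1).
    fn = function([vec_seq], grad_sit_sot0)
    print(fn(np.zeros((2, 3, 1), dtype="float64")).shape)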