@@ -344,84 +344,67 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
+            core_inputs = [
                 tensor(
                     dtype=inp.type.dtype,
-                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                    shape=inp.type.shape[batch_ndim:],
                 )
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
-            core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
+            core_node = self._create_dummy_core_node(core_inputs)
             core_outputs = core_node.outputs
 
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize gradients to batch inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.