@@ -344,84 +344,66 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
+            core_inputs = [
                 tensor(
                     dtype=inp.type.dtype,
-                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                    shape=inp.type.shape[batch_ndim:],
                 )
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
-            core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
-            core_outputs = core_node.outputs
-
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
+
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
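The final loop applies the standard broadcasting rule for gradients: any batch dimension an input broadcast from size 1 must be summed out of that input's gradient, with `keepdims=True` so the gradient keeps the input's rank. A NumPy sketch of the rule for a batched matmul (shapes chosen purely for illustration):

```python
import numpy as np

# Forward: x's batch axis (size 1) broadcasts against y's (size 5).
x = np.random.rand(1, 3, 3)
y = np.random.rand(5, 3, 3)
out = x @ y  # shape (5, 3, 3)

# Backward: dL/dx for each batch element is g_out @ y^T; because x was
# broadcast, its gradient sums those contributions over the batch axis,
# with keepdims=True so the result keeps x's shape (1, 3, 3).
g_out = np.ones_like(out)
g_x = (g_out @ y.transpose(0, 2, 1)).sum(axis=0, keepdims=True)
assert g_x.shape == x.shape
```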