6
6
import scipy .linalg
7
7
8
8
import pytensor
9
- from pytensor import In , config , function
9
+ from pytensor import In , config , function , scan
10
10
from pytensor .compile import get_default_mode , get_mode
11
11
from pytensor .gradient import grad
12
12
from pytensor .graph import Apply , Op
13
- from pytensor .graph .replace import vectorize_node
13
+ from pytensor .graph .replace import vectorize_graph , vectorize_node
14
14
from pytensor .raise_op import assert_op
15
15
from pytensor .tensor import diagonal , dmatrix , log , ones_like , scalar , tensor , vector
16
16
from pytensor .tensor .blockwise import Blockwise , vectorize_node_fallback
@@ -162,13 +162,13 @@ def perform(self, *args, **kwargs):
162
162
raise NotImplementedError ("Test Op should not be present in final graph" )
163
163
164
164
165
- test_op = MyTestOp ()
165
+ my_test_op = MyTestOp ()
166
166
167
167
168
168
def test_vectorize_node_default_signature ():
169
169
vec = tensor (shape = (None ,))
170
170
mat = tensor (shape = (5 , None ))
171
- node = test_op .make_node (vec , mat )
171
+ node = my_test_op .make_node (vec , mat )
172
172
173
173
vect_node = vectorize_node (node , mat , mat )
174
174
assert isinstance (vect_node .op , Blockwise ) and isinstance (
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
179
179
with pytest .raises (
180
180
ValueError , match = "Signature not provided nor found in core_op MyTestOp"
181
181
):
182
- Blockwise (test_op )
182
+ Blockwise (my_test_op )
183
183
184
- vect_node = Blockwise (test_op , signature = "(m),(n)->(m),(n)" ).make_node (vec , mat )
184
+ vect_node = Blockwise (my_test_op , signature = "(m),(n)->(m),(n)" ).make_node (vec , mat )
185
185
assert vect_node .outputs [0 ].type .shape == (
186
186
5 ,
187
187
None ,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
198
198
inp_test = np .zeros ((5 , 4 , 3 ), dtype = config .floatX )
199
199
200
200
# Shape can be inferred from inputs
201
- op = Blockwise (test_op , signature = "(m, n) -> (n, m)" )
201
+ op = Blockwise (my_test_op , signature = "(m, n) -> (n, m)" )
202
202
out = op (inp )
203
203
assert out .type .shape == (5 , None , None )
204
204
@@ -210,7 +210,7 @@ def test_blockwise_shape():
210
210
assert tuple (shape_fn (inp_test )) == (5 , 3 , 4 )
211
211
212
212
# Shape can only be partially inferred from inputs
213
- op = Blockwise (test_op , signature = "(m, n) -> (m, k)" )
213
+ op = Blockwise (my_test_op , signature = "(m, n) -> (m, k)" )
214
214
out = op (inp )
215
215
assert out .type .shape == (5 , None , None )
216
216
@@ -233,7 +233,7 @@ def test_blockwise_shape():
233
233
inp1_test = np .zeros ((7 , 1 , 4 , 3 ), dtype = config .floatX )
234
234
inp2_test = np .zeros ((1 , 5 , 4 , 3 ), dtype = config .floatX )
235
235
236
- op = Blockwise (test_op , signature = "(m, n), (m, n) -> (n, m), (m, k)" )
236
+ op = Blockwise (my_test_op , signature = "(m, n), (m, n) -> (n, m), (m, k)" )
237
237
outs = op (inp1 , inp2 )
238
238
assert outs [0 ].type .shape == (7 , 5 , None , None )
239
239
assert outs [1 ].type .shape == (7 , 5 , None , None )
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
650
650
np .ones (12 , dtype = config .floatX ),
651
651
strict = True ,
652
652
)
653
+
654
+
655
def test_blockwise_grad_core_type():
    """Gradient of a Blockwise is built against the core (unbatched) input type.

    The core op below hard-asserts that it only ever sees a length-2 core
    dimension; if Blockwise handed the batched type to ``L_op`` the assert
    inside ``L_op`` would trip.
    """

    class StrictCoreTypeOp(Op):
        # Core op that refuses anything but a trailing core dimension of 2.
        def make_node(self, x):
            assert x.type.shape[-1] == 2
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = inputs[0] + 1

        def L_op(self, inputs, outputs, output_grads):
            [x] = inputs
            # The gradient graph must be constructed from the core type,
            # i.e. shape (2,), not the batched (5, 2) type.
            assert x.type.shape == (2,)
            return [x.zeros_like()]

    blocked = Blockwise(StrictCoreTypeOp(), signature="(a)->(a)")

    inp = tensor("x", shape=(5, 2), dtype="float64")
    out = blocked(inp)
    assert out.type.shape == (5, 2)

    inp_grad = grad(out.sum(), inp)
    assert inp_grad.type.shape == (5, 2)

    evaluated = inp_grad.eval({inp: np.ones((5, 2))})
    np.testing.assert_allclose(evaluated, np.zeros((5, 2)))
682
+
683
+
684
def test_scan_gradient_core_type():
    """Gradient through a vectorized Scan (identity step) is all ones."""
    steps = 3
    seq = tensor("seq", shape=(steps, 1), dtype="float64")

    # Identity step: the scan simply passes each sequence element through.
    def identity_step(s):
        return s

    out, _ = scan(identity_step, sequences=[seq], n_steps=steps)

    # Vectorize the core graph over a new leading batch dimension.
    batched_seq = tensor("vec_seq", shape=(None, steps, 1), dtype="float64")
    batched_out = vectorize_graph(out, replace={seq: batched_seq})
    batched_grad = grad(batched_out.sum(), batched_seq)

    test_value = np.ones((4, steps, 1))
    np.testing.assert_allclose(
        batched_grad.eval({batched_seq: test_value}),
        test_value,
    )
0 commit comments