55
66from pytensor import Variable
77from pytensor .compile import optdb
8- from pytensor .graph import Constant , FunctionGraph , node_rewriter
8+ from pytensor .graph import Constant , FunctionGraph , node_rewriter , vectorize_graph
99from pytensor .graph .rewriting .basic import NodeRewriter , copy_stack_trace
1010from pytensor .npy_2_compat import normalize_axis_index , normalize_axis_tuple
1111from pytensor .scalar import basic as ps
@@ -119,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node):
119119 the remaining entries of ``idxs`` (if any), modified to skip the
120120 second-to-last dimension of ``B`` (because dot sums over this dimension).
121121 """
122- if not isinstance (node .op , Subtensor ):
123- return
124- if not (node .inputs [0 ].owner and isinstance (node .inputs [0 ].owner .op , Dot )):
122+ x , * idx_vars = node .inputs
123+ if not (
124+ x .owner is not None
125+ and (
126+ isinstance (x .owner .op , Dot )
127+ or (
128+ isinstance (x .owner .op , Blockwise )
129+ and isinstance (x .owner .op .core_op , Dot )
130+ )
131+ )
132+ ):
125133 return
126134 # If there is other node that use the outputs of the dot
127135 # We don't want to compute twice the sub part.
128- if len (fgraph .clients [node . inputs [ 0 ] ]) > 1 :
136+ if len (fgraph .clients [x ]) > 1 :
129137 return
130138
131- a = node .inputs [0 ].owner .inputs [0 ]
132- b = node .inputs [0 ].owner .inputs [1 ]
139+ a = x .owner .inputs [0 ]
140+ b = x .owner .inputs [1 ]
141+ idx_list = indices_from_subtensor (idx_vars , node .op .idx_list )
133142
134- idx_list = get_idx_list (node .inputs , node .op .idx_list )
143+ batch_ndim = (
144+ x .owner .op .batch_ndim (x .owner ) if isinstance (x .owner .op , Blockwise ) else 0
145+ )
146+
147+ if batch_ndim :
148+ batch_idx_list , idx_list = idx_list [:batch_ndim ], idx_list [batch_ndim :]
149+ if not idx_list :
150+ # Indexing only over batch dimensions of Blockwise, that can be handled by another rewrite
151+ return None
152+ # We perform the rest of the rewrite on dummy a, b that correspond to the core case
153+ a = a .type .clone (shape = a .type .shape [batch_ndim :])()
154+ b = b .type .clone (shape = b .type .shape [batch_ndim :])()
135155
136- num_a_indices = min (a .ndim - 1 , len (idx_list ))
156+ a_ndim = a .ndim
157+ b_ndim = b .ndim
158+ num_a_indices = min (a_ndim - 1 , len (idx_list ))
137159 a_indices = idx_list [:num_a_indices ]
138160 b_indices = idx_list [num_a_indices :]
139161
@@ -142,26 +164,22 @@ def local_subtensor_of_dot(fgraph, node):
142164 # This wasn't necessary for a, because we just omitted the last index.
143165 # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
144166 # (dot also handles b.ndim < 2 as a special case)
145- if b . ndim > 1 and len (b_indices ) >= b . ndim - 1 :
167+ if b_ndim > 1 and len (b_indices ) >= b_ndim - 1 :
146168 b_indices = (
147- b_indices [: b . ndim - 2 ]
169+ b_indices [: b_ndim - 2 ]
148170 + (slice (None , None , None ),)
149- + b_indices [b . ndim - 2 :]
171+ + b_indices [b_ndim - 2 :]
150172 )
151173
152- a_sub = a .__getitem__ (tuple (a_indices ))
153- b_sub = b .__getitem__ (tuple (b_indices )) if b_indices else b
174+ a_sub = a [tuple (a_indices )]
175+ b_sub = b [tuple (b_indices )] if b_indices else b
176+ r = dot (a_sub , b_sub )
154177
155- # Copy over previous output stacktrace to a_sub and b_sub,
156- # because an error in the subtensor operation (e.g. an index error)
157- # on either a or b must correspond to an error in the
158- # subtensor operation on their dot product.
159- copy_stack_trace (node .outputs [0 ], [a_sub , b_sub ])
178+ if batch_ndim :
179+ # Replace dummy inputs by the original batch ones
180+ r = vectorize_graph (r , replace = {a : x .owner .inputs [0 ], b : x .owner .inputs [1 ]})
181+ r = r [tuple (batch_idx_list )]
160182
161- # Copy over previous output stacktrace and previous dot product stacktrace,
162- # because an error here may correspond to an either in either the original
163- # dot product, or in the dot product after the subtensor operation.
164- r = dot (a_sub , b_sub )
165183 copy_stack_trace ([node .outputs [0 ], node .inputs [0 ]], r )
166184
167185 return [r ]
0 commit comments