@@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node):
138138 ]
139139
140140
141- @register_stabilize
142141@register_specialize
143142@node_rewriter ([Blockwise ])
144143def batched_vector_b_solve_to_matrix_b_solve (fgraph , node ):
145144 """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
146145
147146 `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
148- Only the last two dimensions of `b` and the output are swapped.
149147 """
150148 core_op = node .op .core_op
151149
@@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
175173 new_core_op = type (core_op )(** props )
176174 matrix_b_solve = Blockwise (new_core_op )
177175
176+ # Ravel any batched dims
177+ original_b_shape = tuple (b .shape )
178+ if len (original_b_shape ) > 2 :
179+ b = b .reshape ((- 1 , original_b_shape [- 1 ]))
180+
178181 # Apply the rewrite
179- new_solve = _T (matrix_b_solve (a , _T (b )))
182+ new_solve = matrix_b_solve (a , b .T ).T
183+
184+ # Unravel any batched dims
185+ if len (original_b_shape ) > 2 :
186+ new_solve = new_solve .reshape (original_b_shape )
180187
181188 old_solve = node .outputs [0 ]
182189 copy_stack_trace (old_solve , new_solve )
0 commit comments