Improve dot lift rewrites #1471

Merged (12 commits, Jul 23, 2025)

53 changes: 21 additions & 32 deletions pytensor/tensor/math.py
@@ -3025,6 +3025,11 @@ def make_node(self, *inputs):
)

sx, sy = (input.type.shape for input in inputs)
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
raise ValueError(
f"Incompatible shared dimension for dot product: {sx}, {sy}"
)

if len(sy) == 2:
sz = sx[:-1] + sy[-1:]
elif len(sy) == 1:
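
As an illustrative sketch (not part of this diff), and assuming the public pt.dot entry point routes 2D inputs through this make_node, the new check surfaces the mismatch at graph-construction time:

import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 3))
y = pt.matrix("y", shape=(4, 5))

# The shared dimension disagrees (3 vs 4), so graph construction fails
# immediately instead of deferring the error to runtime:
pt.dot(x, y)  # ValueError: Incompatible shared dimension for dot product: ...
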
@@ -3916,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))


# Predefine all batched variations of Dot
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)

_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)

_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)

_matrix_matrix_matmul = Blockwise(
_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1),
@@ -3988,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
out = vecmat(x1, x2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
out = matvec(x1, x2)
else:
out = _matrix_matrix_matmul(x1, x2)
out = _matmul(x1, x2)

if dtype is not None:
out = out.astype(dtype)
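
For illustration (a sketch, not from this diff), the 1D cases above now build graphs through the public helpers rather than manual padding and squeezing:

import pytensor.tensor as pt

A = pt.matrix("A", shape=(3, 4))
v3 = pt.vector("v3", shape=(3,))
v4 = pt.vector("v4", shape=(4,))

out_vm = pt.matmul(v3, A)  # dispatched to vecmat, static shape (4,)
out_mv = pt.matmul(A, v4)  # dispatched to matvec, static shape (3,)
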
@@ -4042,7 +4031,7 @@ def vecdot(
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
out = _inner_prod(x1, x2)
out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))

if dtype is not None:
out = out.astype(dtype)
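
The reformulation can be sanity-checked numerically with NumPy (an illustrative sketch, not from this diff): the inner product is expressed as a batched (1, k) @ (k, 1) matmul with both core dimensions squeezed away.

import numpy as np

x = np.random.normal(size=(5, 3))
y = np.random.normal(size=(5, 3))

expected = np.einsum("...k,...k->...", x, y)
via_matmul = np.matmul(x[..., None, :], y[..., :, None]).squeeze((-2, -1))
np.testing.assert_allclose(via_matmul, expected)
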
@@ -4091,7 +4080,7 @@ def matvec(
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
out = _matrix_vec_prod(x1, x2)
out = matmul(x1, x2[..., None]).squeeze(-1)

if dtype is not None:
out = out.astype(dtype)
@@ -4129,18 +4118,18 @@ def vecmat(
--------
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,)) # shape (3,)
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4)
>>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4)
>>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
out = _vec_matrix_prod(x1, x2)
out = matmul(x2.mT, x1[..., None]).squeeze(-1)

if dtype is not None:
out = out.astype(dtype)
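
The vecmat rewrite uses the same trick plus a matrix transpose: v @ A is computed as A.mT @ v per batch. A NumPy sketch (not from this diff) of the equivalence:

import numpy as np

v = np.random.normal(size=(2, 3))
A = np.random.normal(size=(2, 3, 4))

expected = np.einsum("...k,...kn->...n", v, A)
# np.swapaxes plays the role of pytensor's .mT (swap the last two axes)
via_matmul = np.matmul(np.swapaxes(A, -1, -2), v[..., None]).squeeze(-1)
np.testing.assert_allclose(via_matmul, expected)
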
@@ -4155,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_op = _inner_prod
batch_fn = vecdot
case (2, 1):
batch_op = _matrix_vec_prod
batch_fn = matvec
case (1, 2):
batch_op = _vec_matrix_prod
batch_fn = vecmat
case (2, 2):
batch_op = _matrix_matrix_matmul
batch_fn = matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_op(batched_x, batched_y).owner
return batch_fn(batched_x, batched_y).owner
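
For context, a sketch of what this dispatcher handles (assuming the vectorize_graph helper in pytensor.graph.replace; illustrative, not from this diff): vectorizing a core Dot node now emits the public vecdot/matvec/vecmat/matmul graphs instead of the removed private Blockwise ops.

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x", shape=(3,))
y = pt.vector("y", shape=(3,))
out = pt.dot(x, y)

# Replacing the 1D inputs with batched 2D ones triggers vectorize_node_dot,
# which now builds a vecdot graph for the (1, 1)-dim core case.
xb = pt.matrix("xb", shape=(5, 3))
yb = pt.matrix("yb", shape=(5, 3))
batched_out = vectorize_graph(out, replace={x: xb, y: yb})
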


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
13 changes: 9 additions & 4 deletions pytensor/tensor/rewriting/blas.py
@@ -98,7 +98,7 @@
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import (
Dot,
_matrix_matrix_matmul,
_matmul,
add,
mul,
neg,
@@ -758,7 +758,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
ignore_newtrees=False,
),
"fast_run",
position=15,
position=11,
)


@@ -903,19 +903,23 @@ def local_dot22_to_dot22scalar(fgraph, node):
"local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar),
"fast_run",
position=11,
position=12,
)


@register_specialize
@node_rewriter([_matrix_matrix_matmul])
@node_rewriter([_matmul])
def specialize_matmul_to_batched_dot(fgraph, node):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.

TODO: Do the same for Blockwise BatchedDot
"""
x, y = node.inputs

if x.type.ndim < 3:
# This doesn't actually have a batch dimension
return None

# BatchedDot does not allow implicit broadcasting of the batch dimensions
# We do not want to explicitly broadcast as it may result in huge arrays
if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
if len(x_shape) > 3:
# If we have more than one batch dim, ravel it
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
if len(y_shape) > 3:
y = y.reshape((-1, y_shape[-2], y_shape[-1]))

new_out = _batched_dot(x, y)
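
A sketch of which graphs the new early return excludes (illustrative, not from this diff): only matmuls that actually carry a batch dimension, with matching batch broadcastable patterns, remain candidates for the BatchedDot specialization.

import pytensor.tensor as pt

# Plain matrix @ matrix (ndim < 3): the rewrite now bails out early.
a = pt.matrix("a", shape=(2, 3))
b = pt.matrix("b", shape=(3, 4))
plain = pt.matmul(a, b)

# Matching batch dimensions: still eligible to become a BatchedDot.
xb = pt.tensor3("xb", shape=(5, 2, 3))
yb = pt.tensor3("yb", shape=(5, 3, 4))
batched = pt.matmul(xb, yb)
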
95 changes: 33 additions & 62 deletions pytensor/tensor/rewriting/elemwise.py
@@ -30,21 +30,17 @@
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
alloc,
cast,
constant,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable


@@ -346,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
@@ -434,66 +431,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):

"""
if len(node.outputs) > 1:
return
try:
shape_i = fgraph.shape_feature.shape_i
except AttributeError:
shape_i = None
if isinstance(node.op, Elemwise):
scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference
if getattr(scalar_op, "output_types_preference", None) in (
ps.upgrade_to_float,
ps.upcast_out,
):
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype
new_inputs = []
for i in node.inputs:
if i.type.dtype == output_dtype:
new_inputs.append(i)
else:
try:
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
if all(i.broadcastable):
new_inputs.append(
shape_padleft(cast(cval_i, output_dtype), i.ndim)
)
else:
if shape_i is None:
return
new_inputs.append(
alloc(
cast(cval_i, output_dtype),
*[shape_i(d)(i) for d in range(i.ndim)],
)
)
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in range(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, TensorConstant):
new_inputs.append(cast(i, output_dtype))
else:
new_inputs.append(i)
return None

if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None

if new_inputs != node.inputs:
rval = [node.op(*new_inputs)]
if not node.outputs[0].type.is_super(rval[0].type):
# This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[old_out] = node.outputs
output_dtype = old_out.type.dtype
new_inputs = list(node.inputs)
changed = False
for i, inp in enumerate(node.inputs):
if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
new_inputs[i] = constant(inp.data.astype(output_dtype))
changed = True

if not changed:
return None

# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return
rval = node.op(*new_inputs)
if not old_out.type.is_super(rval.type):
# This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return None

# Copy over output stacktrace from before upcasting
copy_stack_trace(node.outputs[0], rval)
return rval
# Copy over output stacktrace from before upcasting
copy_stack_trace(old_out, rval)
return [rval]
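
An illustrative case of what the simplified rewrite targets (a sketch, not from this diff; it assumes true_div's scalar op advertises one of the two upcasting preferences checked above): only constant inputs are rebuilt with the upcast output dtype, so more nodes can be merged, while non-constant inputs are left untouched.

import numpy as np
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
c = pt.constant(np.int8(2))

# The division upcasts its output to float64; the rewrite may then replace
# the int8 constant with an equal-valued float64 constant.
out = x / c
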


@node_rewriter([add, mul])
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
@@ -26,7 +26,7 @@
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
or (A.owner.op == _matmul)
)
):
return