     get_normalized_batch_axes,
     scalar_elemwise,
 )
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
 from pytensor.tensor.type import (
     DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
+    float_dtypes,
     int_dtypes,
     tensor,
     uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):

 class Dot(Op):
     """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the matrix-matrix dot product of two variables.

     Notes
     -----
@@ -3001,97 +3000,57 @@ class Dot(Op):

     """

+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("matmul", 2, 1)
     __props__ = ()

-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given")
-        if inputs[0].ndim not in (1, 2):
+        if x.type.ndim != 2:
             raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
             )
-        if inputs[1].ndim not in (1, 2):
+        if y.type.ndim != 2:
             raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
             )

-        sx, sy = (input.type.shape for input in inputs)
+        sx, sy = x.type.shape, y.type.shape
         if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
             raise ValueError(
                 f"Incompatible shared dimension for dot product: {sx}, {sy}"
             )
+        sz = sx[:-1] + sy[-1:]
+        outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)]
+        return Apply(self, [x, y], outputs)

-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)

     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)

-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
-
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
         # above code don't always return the right broadcast pattern.
         # This cause problem down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)

-        rval = xgrad, ygrad
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")

-        for elem in rval:
-            assert elem.dtype.find("float") != -1
-
-        return rval
+        return xgrad, ygrad

     def R_op(self, inputs, eval_points):
         # R_op for a \dot b evaluated at c for a and d for b is
@@ -3116,24 +3075,7 @@ def R_op(self, inputs, eval_points):

     def infer_shape(self, fgraph, node, shapes):
         xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]


 _dot = Dot()
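
Aside (not part of the diff): a minimal sketch of how the now strictly 2D Dot Op is reached through the public pytensor.tensor.dot helper; the shapes and variable names below are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")  # shape (m, n)
y = pt.matrix("y")  # shape (n, p)
z = pt.dot(x, y)    # for two matrix operands this lowers to the core Dot Op

f = pytensor.function([x, y], z)
print(f(np.ones((2, 3)), np.ones((3, 4))).shape)  # (2, 4)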
@@ -3215,7 +3157,24 @@ def dense_dot(a, b):
     elif a.ndim > 2 or b.ndim > 2:
         return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
     else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+        return out


 def tensordot(
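
Aside (not part of the diff): with the promotion scheme added to dense_dot above, a 1D operand is temporarily viewed as a row or column matrix for the 2D-only Dot Op and the inserted axis is squeezed from the result, so vector cases keep their previous output shapes. A small sketch with illustrative shapes, going through the public pytensor.tensor.dot entry point:

import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.vector("v")  # shape (n,)
M = pt.matrix("M")  # shape (n, p)
out = pt.dot(v, M)  # v is promoted to (1, n) internally; the result is squeezed to (p,)

f = pytensor.function([v, M], out)
print(f(np.ones(3), np.ones((3, 4))).shape)  # (4,)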
@@ -3921,11 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="Matmul")


 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
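
Aside (not part of the diff): since Dot now declares gufunc_signature and gufunc_spec itself, Blockwise(_dot) no longer needs them passed explicitly, and pt.matmul batches the 2D product over leading dimensions. An illustrative sketch, not a definitive usage of the new _matmul:

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.tensor3("a")  # shape (batch, m, k)
b = pt.tensor3("b")  # shape (batch, k, n)
c = pt.matmul(a, b)  # Blockwise(Dot) applied over the leading batch axis

f = pytensor.function([a, b], c)
print(f(np.ones((5, 2, 3)), np.ones((5, 3, 4))).shape)  # (5, 2, 4)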