 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as at
+from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import add, mul, neg, sub
-from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
 from pytensor.utils import memoize

@@ -1637,48 +1638,53 @@ def c_code_cache_version(self):

 class BatchedDot(COp):
     """
-    Computes the batched dot product of two variables:
+    Computes a batch matrix-matrix dot with tensor3 variables

         batched_dot(a, b)[i] = dot(a[i], b[i])
     """

     __props__ = ()
+    gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

-    def make_node(self, *inputs):
-        inputs = list(map(at.as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = at.as_tensor_variable(x)
+        y = at.as_tensor_variable(y)

-        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+        if not (
+            isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
+        ):
             raise NotImplementedError("Only dense tensor types are supported")

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, but {len(inputs)} given.")
-        if inputs[0].ndim not in (2, 3):
+        if not (x.type.ndim == 3 and y.type.ndim == 3):
             raise TypeError(
-                "Input 0 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider"
-                " calling batched_dot instead."
-            )
-        if inputs[1].ndim not in (2, 3):
-            raise TypeError(
-                "Input 1 (0-indexed)"
-                f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider"
-                " calling batched_dot instead."
+                f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
+                "Consider calling batched_dot instead."
             )

-        dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs])
-        # upcast inputs to common dtype if needed
-        upcasted_inputs = [at.cast(input, dtype) for input in inputs]
-        out_shape = (
-            (
-                1
-                if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
-                else None,
-            )
-            + inputs[0].type.shape[1:-1]
-            + inputs[1].type.shape[2:]
-        )
-        out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
+        def extract_static_dim(dim_x, dim_y):
+            dims = {dim_x, dim_y} - {None}
+            if len(dims) > 1:
+                # BatchedDot doesn't allow broadcasting
+                raise ValueError(
+                    f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
+                )
+            elif not dims:
+                return None
+            else:
+                return dims.pop()
+
+        x_batch_dim, x_row_dim, x_sum_dim = x.type.shape
+        y_batch_dim, y_sum_dim, y_col_dim = y.type.shape
+        batch_dim = extract_static_dim(x_batch_dim, y_batch_dim)
+        # Raise if static sum dimensions do not match
+        _ = extract_static_dim(x_sum_dim, y_sum_dim)
+        out_shape = (batch_dim, x_row_dim, y_col_dim)
+
+        # Change dtype if needed
+        dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
+        x, y = at.cast(x, dtype), at.cast(y, dtype)
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [x, y], [out])

     def perform(self, node, inp, out):
         x, y = inp
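
An aside, not part of the commit: a NumPy sketch of the contract the hunks above and below implement. It spells out what the gufunc signature "(b,m,k),(b,k,n)->(b,m,n)" means and why perform can simply delegate to np.matmul, which batches over the leading axis exactly like BatchedDot.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2, 3))   # (b, m, k)
c = rng.normal(size=(4, 3, 5))   # (b, k, n)

# batched_dot(a, c)[i] = dot(a[i], c[i]), so the result has shape (b, m, n)
z = np.stack([a[i] @ c[i] for i in range(a.shape[0])])
assert z.shape == (4, 2, 5)
assert np.allclose(z, np.matmul(a, c))  # the new perform is this single call
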
@@ -1690,11 +1696,7 @@ def perform(self, node, inp, out):
                 f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
             )

-        shape = self.infer_shape(None, node, [i.shape for i in inp])[0]
-        dtype = node.outputs[0].dtype
-        z0 = z[0] = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
+        z[0] = np.matmul(x, y)

     def c_support_code(self, **kwargs):
         batch_gemm_defn = """
@@ -1792,14 +1794,6 @@ def c_lib_dirs(self, **kwargs):
     def c_header_dirs(self, **kwargs):
         return ldflags(libs=False, include_dir=True)

-    def c_code_cleanup(self, node, name, inputs, outputs, sub):
-        return """
-        // clean up views
-        Py_XDECREF(xs); xs = 0;
-        Py_XDECREF(ys); ys = 0;
-        Py_XDECREF(zs); zs = 0;
-        """
-
     def c_code(self, node, name, inp, out, sub):
         _x, _y = inp
         (_z,) = out
@@ -1832,12 +1826,11 @@ def contiguous(var, ndim):
         )

         # generate code to allocate output based on runtime input shapes
-        z_dims = [f"PyArray_DIMS({_x})[0]"]
-        if x_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_x})[1]")
-        if y_ndim == 3:
-            z_dims.append(f"PyArray_DIMS({_y})[2]")
-        assert len(z_dims) == z_ndim
+        z_dims = [
+            f"PyArray_DIMS({_x})[0]",
+            f"PyArray_DIMS({_x})[1]",
+            f"PyArray_DIMS({_y})[2]",
+        ]

         z_shape_correct = " && ".join(
             "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims)
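
An illustration, not from the commit: with z_dims fixed to the three tensor3 dimensions, the join above expands to the following C condition (the names x, y, z stand in for the generated array variable names):

z_dims = ["PyArray_DIMS(x)[0]", "PyArray_DIMS(x)[1]", "PyArray_DIMS(y)[2]"]
z_shape_correct = " && ".join(
    "PyArray_DIMS(%s)[%i] == %s" % ("z", i, dim) for i, dim in enumerate(z_dims)
)
# -> PyArray_DIMS(z)[0] == PyArray_DIMS(x)[0]
#    && PyArray_DIMS(z)[1] == PyArray_DIMS(x)[1]
#    && PyArray_DIMS(z)[2] == PyArray_DIMS(y)[2]
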
@@ -1880,76 +1873,26 @@ def contiguous(var, ndim):
             )
         contiguate = "\n".join(contiguate)

-        def c_dimshuffle(newname, oldname, shape):
-            _fail = fail
-            _shape = ", ".join(
-                "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
-                for axis in shape
-            )
-            return (
-                """{
-            npy_intp dims[3] = {%(_shape)s};
-            PyArray_Dims newshape = {dims, 3};
-            %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
-            if (!%(newname)s)
-                %(_fail)s
-            // make sure we didn't accidentally copy
-            assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
-            }"""
-                % locals()
-            )
-
-        # create tensor3 views for any of x, y, z that are not tensor3, so that
-        # we only need to implement the tensor3-tensor3 batched dot product.
-        # xs, ys and zs will point to these views, or to the original array if
-        # it was already tensor3.
-        # in the latter case, we artificially increase the reference count of
-        # the original array so that the c_code_cleanup method can decref them
-        # all indiscriminately.
-        upcast = []
-        if x_ndim == 3:
-            upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
-        elif x_ndim == 2:
-            upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
-        if y_ndim == 3:
-            upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
-        elif y_ndim == 2:
-            upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
-        if z_ndim == 3:
-            upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
-        else:
-            upcast.append(
-                c_dimshuffle(
-                    "zs",
-                    _z,
-                    (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1),
-                )
-            )
-        upcast = "\n".join(upcast) % locals()
-
         return (
             """
         int type_num = PyArray_DESCR(%(_x)s)->type_num;
         int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes

-        // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
-        PyArrayObject *xs = 0, *ys = 0, *zs = 0;
-
-        if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
+        if (PyArray_NDIM(%(_x)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(x) != %(x_ndim)s. rank(x) is %%d.",
+                         "rank(x) != 3. rank(x) is %%d.",
                          PyArray_NDIM(%(_x)s));
             %(fail)s;
         }
-        if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
+        if (PyArray_NDIM(%(_y)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(y) != %(y_ndim)s. rank(y) is %%d.",
+                         "rank(y) != 3. rank(y) is %%d.",
                          PyArray_NDIM(%(_y)s));
             %(fail)s;
         }
-        if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
+        if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) {
             PyErr_Format(PyExc_NotImplementedError,
-                         "rank(z) != %(z_ndim)s. rank(z) is %%d.",
+                         "rank(z) != 3. rank(z) is %%d.",
                          PyArray_NDIM(%(_z)s));
             %(fail)s;
         }
@@ -1958,36 +1901,32 @@ def c_dimshuffle(newname, oldname, shape):
         %(allocate)s
         // reallocate any noncontiguous arrays or arrays with invalid strides
         %(contiguate)s
-        // add dims to make sure everything is tensor3
-        %(upcast)s
-        // from here on, use xs, ys and zs as they are tensor3 and share memory
-        // with the original %(_x)s, %(_y)s and %(_z)s arrays.

-        if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
-            && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
+        if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE)
+            && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT))
         {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

-        if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
-            ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
+        if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
+            ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num))
         { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }

         switch (type_num)
         {
             case NPY_FLOAT:
-                if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
             case NPY_DOUBLE:
-                if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
+                if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
                     %(fail)s;
                 }
                 break;
@@ -1999,32 +1938,14 @@ def c_dimshuffle(newname, oldname, shape):
     def c_code_cache_version(self):
         from pytensor.tensor.blas_headers import blas_header_version

-        return (4, blas_header_version())
+        return (5, blas_header_version())

     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim

-        # grad is a vector, so x is a matrix and y is a matrix
-        if gdim == 1:
-            xgrad = gz.dimshuffle(0, "x") * y
-            ygrad = gz.dimshuffle(0, "x") * x
-
-        # x is a matrix, y is a tensor3, grad is a matrix
-        elif xdim == 2 and ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1)
-
-        # x is a tensor3, y is a matrix, grad is a matrix
-        elif xdim == 3 and ydim == 2:
-            xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1)
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
-
-        # x is a tensor3, y is a tensor3, grad is a tensor3
-        elif xdim == ydim == 3:
-            xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-            ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
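
The single pair of formulas above is enough because make_node now guarantees tensor3 inputs. As a sanity sketch (not part of the commit), the per-batch gradients dL/dx[i] = g[i] @ y[i].T and dL/dy[i] = x[i].T @ g[i] are what the two batched_dot calls compute; dimshuffle(0, 2, 1) is the batched transpose.

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 3, 4))   # (b, m, k)
y = rng.normal(size=(2, 4, 5))   # (b, k, n)
g = rng.normal(size=(2, 3, 5))   # upstream gradient, same shape as the output

xgrad = np.matmul(g, np.swapaxes(y, 1, 2))   # batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = np.matmul(np.swapaxes(x, 1, 2), g)   # batched_dot(x.dimshuffle(0, 2, 1), gz)
assert xgrad.shape == x.shape and ygrad.shape == y.shape
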
@@ -2105,6 +2026,7 @@ def R_op(self, inputs, eval_points):
                         + " to BatchedDot.R_op should have the same shape, but "
                         f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
                     )
+
         if eval_points[0]:
             t1 = self(eval_points[0], inputs[1])
         if eval_points[1]:
@@ -2118,9 +2040,6 @@ def R_op(self, inputs, eval_points):
             return [t2]

     def infer_shape(self, fgraph, node, shapes):
-        for shape_ in shapes:
-            if len(shape_) not in (2, 3):
-                raise NotImplementedError()
         xshp, yshp = shapes
         return [xshp[:-1] + yshp[2:]]

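
For the record (not in the diff): with both inputs guaranteed to be tensor3, the shape arithmetic above is just (b, m, k)[:-1] + (b, k, n)[2:] == (b, m, n), so the removed ndim check was redundant.

xshp, yshp = (8, 4, 6), (8, 6, 3)          # hypothetical static shapes
assert xshp[:-1] + yshp[2:] == (8, 4, 3)   # (b, m, k), (b, k, n) -> (b, m, n)
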
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
     elif a.ndim == 1:
-        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
+        return shape_padright(a, (b.ndim - 1)) * b
     elif b.ndim == 1:
-        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
+        return a * shape_padright(b, (a.ndim - 1))
     elif a.ndim > 3 or b.ndim > 3:
         return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
     else:
-        # avoid circular import
-        return _batched_dot(a, b)
+        # If either a or b is a batched vector, expand dims and later squeeze them
+        expanded_axis = []
+        if a.ndim == 2:
+            a = expand_dims(a, axis=1)
+            expanded_axis.append(1)
+        if b.ndim == 2:
+            b = expand_dims(b, axis=2)
+            expanded_axis.append(2)
+        out = _batched_dot(a, b)
+        if expanded_axis:
+            out = out.squeeze(axis=expanded_axis)
+        return out


 def batched_tensordot(x, y, axes=2):
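
A small usage sketch of the public batched_dot helper after this change (assuming the usual pytensor.tensor import path); a rank-2 second argument is treated as a batch of vectors via the expand_dims/squeeze path added above:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")            # (batch, m, k)
v = pt.matrix("v")             # (batch, k), i.e. one vector per batch entry
out = pt.batched_dot(x, v)     # internally expand_dims -> BatchedDot -> squeeze

f = pytensor.function([x, v], out)
xv = np.ones((5, 3, 4), dtype=pytensor.config.floatX)
vv = np.ones((5, 4), dtype=pytensor.config.floatX)
assert f(xv, vv).shape == (5, 3)  # one matrix-vector product per batch entry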