diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 592a4ba27c..3124428016 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -79,10 +79,14 @@ import logging import os import shlex +import warnings from pathlib import Path import numpy as np +from pytensor.graph import vectorize_graph +from pytensor.npy_2_compat import normalize_axis_tuple + try: import numpy.__config__ @@ -99,10 +103,10 @@ from pytensor.link.c.params_type import ParamsType from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t -from pytensor.tensor import basic as ptb -from pytensor.tensor.basic import expand_dims +from pytensor.tensor.basic import as_tensor_variable, cast from pytensor.tensor.blas_headers import blas_header_text, blas_header_version -from pytensor.tensor.shape import shape_padright, specify_broadcastable +from pytensor.tensor.math import dot, tensordot +from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import DenseTensorType, tensor @@ -153,11 +157,11 @@ def __str__(self): return f"{self.__class__.__name__}{{no_inplace}}" def make_node(self, y, alpha, A, x, beta): - y = ptb.as_tensor_variable(y) - x = ptb.as_tensor_variable(x) - A = ptb.as_tensor_variable(A) - alpha = ptb.as_tensor_variable(alpha) - beta = ptb.as_tensor_variable(beta) + y = as_tensor_variable(y) + x = as_tensor_variable(x) + A = as_tensor_variable(A) + alpha = as_tensor_variable(alpha) + beta = as_tensor_variable(beta) if y.dtype != A.dtype or y.dtype != x.dtype: raise TypeError( "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype) @@ -253,10 +257,10 @@ def __str__(self): return f"{self.__class__.__name__}{{non-destructive}}" def make_node(self, A, alpha, x, y): - A = ptb.as_tensor_variable(A) - y = ptb.as_tensor_variable(y) - x = ptb.as_tensor_variable(x) - alpha = ptb.as_tensor_variable(alpha) + A = as_tensor_variable(A) + y = as_tensor_variable(y) + x = as_tensor_variable(x) + alpha = as_tensor_variable(alpha) if not (A.dtype == x.dtype == y.dtype == alpha.dtype): raise TypeError( "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype) @@ -855,7 +859,7 @@ def __getstate__(self): return rval def make_node(self, *inputs): - inputs = list(map(ptb.as_tensor_variable, inputs)) + inputs = list(map(as_tensor_variable, inputs)) if any(not isinstance(i.type, DenseTensorType) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") @@ -1125,8 +1129,8 @@ class Dot22(GemmRelated): check_input = False def make_node(self, x, y): - x = ptb.as_tensor_variable(x) - y = ptb.as_tensor_variable(y) + x = as_tensor_variable(x) + y = as_tensor_variable(y) if any(not isinstance(i.type, DenseTensorType) for i in (x, y)): raise NotImplementedError("Only dense tensor types are supported") @@ -1318,8 +1322,8 @@ class BatchedDot(COp): gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" def make_node(self, x, y): - x = ptb.as_tensor_variable(x) - y = ptb.as_tensor_variable(y) + x = as_tensor_variable(x) + y = as_tensor_variable(y) if not ( isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) @@ -1353,7 +1357,7 @@ def extract_static_dim(dim_x, dim_y): # Change dtype if needed dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) - x, y = ptb.cast(x, dtype), ptb.cast(y, dtype) + x, y = cast(x, dtype), cast(y, dtype) out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [x, y], [out]) @@ -1604,8 +1608,8 @@ def grad(self, inp, grads): x, y = inp (gz,) = grads - xgrad = 
batched_dot(gz, y.dimshuffle(0, 2, 1))
-        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)
 
         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
     dot products in terms of batched matrix-matrix dot products, so it may be
     possible to further optimize for performance.
     """
-    a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
+    warnings.warn(
+        "batched_dot is deprecated. "
+        "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
+    a, b = as_tensor_variable(a), as_tensor_variable(b)
 
     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
-    elif a.ndim == 1:
-        return shape_padright(a, (b.ndim - 1)) * b
-    elif b.ndim == 1:
-        return a * shape_padright(b, (a.ndim - 1))
-    elif a.ndim > 3 or b.ndim > 3:
-        return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
-    else:
-        # If either a or b is a batched vector, expand dims and later squeeze them
-        expanded_axis = []
-        if a.ndim == 2:
-            a = expand_dims(a, axis=1)
-            expanded_axis.append(1)
-        if b.ndim == 2:
-            b = expand_dims(b, axis=2)
-            expanded_axis.append(2)
-        out = _batched_dot(a, b)
-        if expanded_axis:
-            out = out.squeeze(axis=expanded_axis)
-        return out
+
+    core_a = a[0].type()
+    core_b = b[0].type()
+    core_dot = dot(core_a, core_b)
+    return vectorize_graph(core_dot, replace={core_a: a, core_b: b})
 
 
 def batched_tensordot(x, y, axes=2):
@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
     reshapes to reduce the tensor dot product to a matrix or vector dot
     product. Finally, it calls batched_dot to compute the result.
     """
-    from pytensor.tensor.math import _tensordot_as_dot
+    warnings.warn(
+        "batched_tensordot is deprecated. "
+        "Use `tensordot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
+
+    if isinstance(axes, int):
+        core_axes = axes
+    else:
+        # Convert batched axes to core axes
+        core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
+        core_axes_b = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
+        core_axes = [core_axes_a, core_axes_b]
+
+    core_x = x[0].type()
+    core_y = y[0].type()
+    core_tensordot = tensordot(core_x, core_y, axes=core_axes)
 
-    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
+    return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 2aa6ad2381..5d0a6fdb2b 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -50,7 +50,7 @@
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.utils import as_list, normalize_reduce_axis
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import (
     TensorVariable,
     _tensor_py_operators,
@@ -1927,133 +1927,6 @@ def dense_dot(a, b):
         return _dot(a, b)
 
 
-def _tensordot_as_dot(a, b, axes, dot, batched):
-    """
-    Reduces a tensor dot product to a matrix or vector dot product. Based
-    on code from Tijmen Tieleman's gnumpy
-    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
-
-    Please see the documentation of tensordot for the meaning of the a, b
-    and axes arguments.
-
-    :param dot: a function that accepts two symbolic variables and computes
-        the appropriate dot product (e.g.
dot, batched_dot) - :type dot: function - - :param batched: whether to treat the first axis of a and b as a batch - axis. If so, this axis will be preserved in the output, - allowing this function to be used also for batched - tensor dot products. - :type batched: boolean - - :returns: a tensor with shape equal to the concatenation of a's shape - (less any dimensions that were summed over) and b's shape - (less the first dimension and any dimensions that were summed - over). - :rtype: symbolic tensor - """ - a, b = as_tensor_variable(a), as_tensor_variable(b) - - if not np.isscalar(axes) and len(axes) != 2: - raise ValueError( - "Axes should be an integer or a " - f"list/tuple of len 2 ({axes} was provided)" - ) - - # if 'axes' is a number of axes to multiply and sum over (trailing axes - # of a, leading axes of b), we can just reshape and use dot. - elif np.isscalar(axes): - axes = int(axes) - - for operand_name, operand in (("a", a), ("b", b)): - if axes > operand.ndim: - raise ValueError( - f"axes can not be larger than the dimension of {operand_name} " - f"({operand_name}.ndim={operand.ndim}, axes={axes})" - ) - if batched and axes == operand.ndim: - raise ValueError( - "axes to sum over must not include the batch axis " - f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})" - ) - - batch_axes = 1 if batched else 0 - a_outaxes = slice(0, a.ndim - axes) - b_outaxes = slice(batch_axes + axes, b.ndim) - outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]]) - outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes] - outndim = len(outbcast) - - a_shape = [1] * 2 - b_shape = [1] * 2 - - # compute total size of summed axes - for i in range(0, axes): - a_shape[1] *= a.shape[-(i + 1)] - b_shape[0] *= b.shape[batch_axes + i] - # compute total size of other axes - for i in range(0, a.ndim - axes - batch_axes): - a_shape[0] *= a.shape[batch_axes + i] - for i in range(0, b.ndim - axes - batch_axes): - b_shape[1] *= b.shape[-(i + 1)] - - if batched: - a_shape.insert(0, a.shape[0]) - b_shape.insert(0, b.shape[0]) - - a_reshaped = a.reshape(a_shape) - b_reshaped = b.reshape(b_shape) - - out_reshaped = dot(a_reshaped, b_reshaped) - out = out_reshaped.reshape(outshape, ndim=outndim) - # Make sure the broadcastable pattern of the result is correct, - # since some shape information can be lost in the reshapes. - if out.type.broadcastable != outbcast: - out = specify_broadcastable( - out, *(ax for (ax, b) in enumerate(outbcast) if b) - ) - return out - - # if 'axes' is a list, transpose a and b such that the summed axes of a - # are last and the summed axes of b are first. - else: - axes = [as_list(axes_) for axes_ in axes] - - if len(axes[0]) != len(axes[1]): - raise ValueError("Axes elements must have the same length.") - - for i, (operand_name, operand) in enumerate((("a", a), ("b", b))): - if len(axes[i]) > operand.ndim: - raise ValueError( - f"axes[{i}] should be array_like with length less than " - f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})." - ) - if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim: - raise ValueError( - f"axes[{i}] contains dimensions greater than or equal " - f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})." 
- ) - if batched and 0 in axes[i]: - raise ValueError( - "axes to sum over must not contain the batch axis " - f"(axes[{i}]={axes[i]})" - ) - - batch_axes = [0] if batched else [] - other_axes = [ - [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes] - for i, operand in enumerate((a, b)) - ] - - a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0]) - b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1]) - - # now that a and b are in the right order, recur with integer axes - return _tensordot_as_dot( - a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched - ) - - def tensordot( a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2 ) -> TensorVariable: diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 0bf2733f10..b5c2564481 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -84,9 +84,9 @@ from pytensor.tensor import basic as ptb from pytensor.tensor.blas import ( Dot22, + _batched_dot, _dot22, _dot22scalar, - batched_dot, gemm_inplace, gemm_no_inplace, gemv_inplace, @@ -928,7 +928,7 @@ def specialize_matmul_to_batched_dot(fgraph, node): x = x.reshape((-1, x_shape[-2], x_shape[-1])) y = y.reshape((-1, y_shape[-2], y_shape[-1])) - new_out = batched_dot(x, y) + new_out = _batched_dot(x, y) if len(x_shape) > 3: # And then unravel it diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 9ce12296cd..0ebb2e5434 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -107,14 +107,6 @@ def shape_of_variables( return l -def as_list(x): - """Convert x to a list if it is an iterable; otherwise, wrap it in a list.""" - try: - return list(x) - except TypeError: - return [x] - - def import_func_from_string(func_string: str): # -> Optional[Callable]: func = getattr(np, func_string, None) if func is not None: diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 1e4afb8928..37e2c380b9 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -27,6 +27,7 @@ Gemm, Gemv, Ger, + _batched_dot, _dot22, _dot22scalar, batched_dot, @@ -2446,7 +2447,7 @@ def test_ger(self): rng = np.random.default_rng(unittest_tools.fetch_seed()) TestBatchedDot = makeTester( name="BatchedDotTester", - op=batched_dot, + op=_batched_dot, expected=( lambda xs, ys: np.asarray( [ @@ -2460,34 +2461,10 @@ def test_ger(self): grad=dict( correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)), correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)), - correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)), - correct4=(random(3, 5), random(3, 5, 7, rng=rng)), - correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)), - correct6=(random(3, 5, rng=rng), random(3, rng=rng)), - correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)), - correct8=(random(3, rng=rng), random(3, rng=rng)), - correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)), - correct10=(random(3, 2, 6, 5, rng=rng), random(3, 5, rng=rng)), - correct11=(random(3, 2, 6, 5, rng=rng), random(3, 5, 7, rng=rng)), - correct12=(random(3, 2, 6, 5, rng=rng), random(3, 7, 5, 8, rng=rng)), - mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)), - mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)), ), good=dict( correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)), correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)), - correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)), - 
correct4=(random(3, 5, rng=rng), random(3, 5, 7, rng=rng)), - correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)), - correct6=(random(3, 5, rng=rng), random(3, rng=rng)), - correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)), - correct8=(random(3, rng=rng), random(3, rng=rng)), - correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)), - correct10=(random(3, 7, 11, 5, rng=rng), random(3, 5, rng=rng)), - correct11=(random(3, 7, 11, 5, rng=rng), random(3, 5, 13, rng=rng)), - correct12=(random(3, 7, 11, 5, rng=rng), random(3, 13, 5, 17, rng=rng)), - mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)), - mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)), ), bad_build=dict( no_batch_axis2=(random(rng=rng), random(3, 5, rng=rng)), @@ -2496,13 +2473,8 @@ def test_ger(self): bad_runtime=dict( batch_dim_mismatch1=(random(2, 5, 7, rng=rng), random(3, 7, 9, rng=rng)), batch_dim_mismatch2=(random(3, 5, 7, rng=rng), random(2, 7, 9, rng=rng)), - batch_dim_mismatch3=(random(3, rng=rng), random(5, rng=rng)), bad_dim1=(random(3, 5, 7, rng=rng), random(3, 5, 7, rng=rng)), bad_dim2=(random(3, 5, 7, rng=rng), random(3, 8, 3, rng=rng)), - bad_dim3=(random(3, 5, rng=rng), random(3, 7, rng=rng)), - bad_dim4=(random(3, 5, 7, 11, rng=rng), random(3, 5, rng=rng)), - bad_dim5=(random(3, 5, 7, 11, rng=rng), random(3, 5, 13, rng=rng)), - bad_dim6=(random(3, 5, 7, 11, rng=rng), random(3, 13, 5, 17, rng=rng)), ), ) @@ -2511,7 +2483,8 @@ def test_batched_dot(): rng = np.random.default_rng(unittest_tools.fetch_seed()) first = tensor3("first") second = tensor3("second") - output = batched_dot(first, second) + with pytest.warns(FutureWarning): + output = batched_dot(first, second) first_val = rng.random((10, 10, 20)).astype(config.floatX) second_val = rng.random((10, 20, 5)).astype(config.floatX) result_fn = function([first, second], output) @@ -2522,7 +2495,8 @@ def test_batched_dot(): first_mat = dmatrix("first") second_mat = dmatrix("second") - output = batched_dot(first_mat, second_mat) + with pytest.warns(FutureWarning): + output = batched_dot(first_mat, second_mat) first_mat_val = rng.random((10, 10)).astype(config.floatX) second_mat_val = rng.random((10, 10)).astype(config.floatX) result_fn = function([first_mat, second_mat], output) @@ -2540,7 +2514,7 @@ def np_genarray(*_shape): X = tensor3() W = tensor3() - Z = batched_dot(X, W) + Z = _batched_dot(X, W) f = function([X, W], Z) w = np_genarray(30, 10, 5) @@ -2568,7 +2542,7 @@ def test_batched_dot_blas_flags(): x = tensor("x", shape=(2, 5, 3)) y = tensor("y", shape=(2, 3, 1)) - out = batched_dot(x, y) + out = _batched_dot(x, y) assert isinstance(out.owner.op, BatchedDot) x_test = rng.normal(size=x.type.shape).astype(x.type.dtype) y_test = rng.normal(size=y.type.shape).astype(y.type.dtype) @@ -2590,7 +2564,8 @@ def test_batched_tensordot(): first = tensor4("first") second = tensor4("second") axes = [[1, 2], [3, 1]] - output = batched_tensordot(first, second, axes) + with pytest.warns(FutureWarning): + output = batched_tensordot(first, second, axes) first_val = rng.random((8, 10, 20, 3)).astype(config.floatX) second_val = rng.random((8, 20, 5, 10)).astype(config.floatX) result_fn = function([first, second], output) @@ -2602,7 +2577,8 @@ def test_batched_tensordot(): first_mat = dmatrix("first") second_mat = dmatrix("second") axes = 1 - output = batched_tensordot(first_mat, second_mat, axes) + with pytest.warns(FutureWarning): + output = batched_tensordot(first_mat, second_mat, axes) first_mat_val = 
rng.random((10, 4)).astype(config.floatX) second_mat_val = rng.random((10, 4)).astype(config.floatX) result_fn = function([first_mat, second_mat], output)
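
Migration note (illustrative, not part of the diff above): the new FutureWarnings steer callers toward building the core `dot`/`tensordot` graph and vectorizing it over the batch axis, which mirrors what the rewritten helpers now do internally. Below is a minimal sketch of that pattern for a former `batched_dot(x, y)` call; the variable names are ours, and the `batched_tensordot` case is analogous with `tensordot` and the batch axis dropped from `axes`.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Batched operands: the leading axis is the batch axis.
x = pt.tensor3("x")  # (batch, m, k)
y = pt.tensor3("y")  # (batch, k, n)

# Build the core (unbatched) graph on dummy matrices ...
core_x = pt.matrix("core_x")
core_y = pt.matrix("core_y")
core_out = pt.dot(core_x, core_y)

# ... then vectorize it over the batch dimension by replacing the dummies
# with the batched inputs.
out = vectorize_graph(core_out, replace={core_x: x, core_y: y})

fn = pytensor.function([x, y], out)
rng = np.random.default_rng(0)
x_val = rng.normal(size=(4, 3, 5)).astype(x.dtype)
y_val = rng.normal(size=(4, 5, 2)).astype(y.dtype)
assert fn(x_val, y_val).shape == (4, 3, 2)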