Deprecate batched blas helpers #1215

Merged
merged 2 commits into from Mar 11, 2025
95 changes: 53 additions & 42 deletions pytensor/tensor/blas.py
@@ -79,10 +79,14 @@
import logging
import os
import shlex
import warnings
from pathlib import Path

import numpy as np

from pytensor.graph import vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple


try:
import numpy.__config__
@@ -99,10 +103,10 @@
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import FunctionPrinter, pprint
from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.basic import as_tensor_variable, cast
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.math import dot, tensordot
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.type import DenseTensorType, tensor


@@ -153,11 +157,11 @@ def __str__(self):
return f"{self.__class__.__name__}{{no_inplace}}"

def make_node(self, y, alpha, A, x, beta):
y = ptb.as_tensor_variable(y)
x = ptb.as_tensor_variable(x)
A = ptb.as_tensor_variable(A)
alpha = ptb.as_tensor_variable(alpha)
beta = ptb.as_tensor_variable(beta)
y = as_tensor_variable(y)
x = as_tensor_variable(x)
A = as_tensor_variable(A)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError(
"Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -253,10 +257,10 @@ def __str__(self):
return f"{self.__class__.__name__}{{non-destructive}}"

def make_node(self, A, alpha, x, y):
A = ptb.as_tensor_variable(A)
y = ptb.as_tensor_variable(y)
x = ptb.as_tensor_variable(x)
alpha = ptb.as_tensor_variable(alpha)
A = as_tensor_variable(A)
y = as_tensor_variable(y)
x = as_tensor_variable(x)
alpha = as_tensor_variable(alpha)
if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
raise TypeError(
"ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
@@ -855,7 +859,7 @@ def __getstate__(self):
return rval

def make_node(self, *inputs):
inputs = list(map(ptb.as_tensor_variable, inputs))
inputs = list(map(as_tensor_variable, inputs))

if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
@@ -1125,8 +1129,8 @@ class Dot22(GemmRelated):
check_input = False

def make_node(self, x, y):
x = ptb.as_tensor_variable(x)
y = ptb.as_tensor_variable(y)
x = as_tensor_variable(x)
y = as_tensor_variable(y)

if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
raise NotImplementedError("Only dense tensor types are supported")
@@ -1318,8 +1322,8 @@ class BatchedDot(COp):
gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

def make_node(self, x, y):
x = ptb.as_tensor_variable(x)
y = ptb.as_tensor_variable(y)
x = as_tensor_variable(x)
y = as_tensor_variable(y)

if not (
isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
@@ -1353,7 +1357,7 @@ def extract_static_dim(dim_x, dim_y):

# Change dtype if needed
dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
x, y = ptb.cast(x, dtype), ptb.cast(y, dtype)
x, y = cast(x, dtype), cast(y, dtype)
out = tensor(dtype=dtype, shape=out_shape)
return Apply(self, [x, y], [out])

@@ -1604,8 +1608,8 @@ def grad(self, inp, grads):
x, y = inp
(gz,) = grads

xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)

# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
dot products in terms of batched matrix-matrix dot products, so
it may be possible to further optimize for performance.
"""
a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
warnings.warn(
"batched_dot is deprecated. "
"Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`",
FutureWarning,
)
a, b = as_tensor_variable(a), as_tensor_variable(b)

if a.ndim == 0:
raise TypeError("a must have at least one (batch) axis")
elif b.ndim == 0:
raise TypeError("b must have at least one (batch) axis")
elif a.ndim == 1:
return shape_padright(a, (b.ndim - 1)) * b
elif b.ndim == 1:
return a * shape_padright(b, (a.ndim - 1))
elif a.ndim > 3 or b.ndim > 3:
return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
else:
# If either a or b is a batched vector, expand dims and later squeeze them
expanded_axis = []
if a.ndim == 2:
a = expand_dims(a, axis=1)
expanded_axis.append(1)
if b.ndim == 2:
b = expand_dims(b, axis=2)
expanded_axis.append(2)
out = _batched_dot(a, b)
if expanded_axis:
out = out.squeeze(axis=expanded_axis)
return out

core_a = a[0].type()
core_b = b[0].type()
core_dot = dot(core_a, core_b)
return vectorize_graph(core_dot, replace={core_a: a, core_b: b})
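
For reference, a minimal sketch of the migration path the deprecation warning points to: build the graph for a single (unbatched) pair of operands with `dot`, then batch it with `vectorize_graph`. Variable names and shapes below are illustrative assumptions, not taken from the PR's tests.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Batched operands: a stack of 5 matrix pairs.
a = pt.dtensor3("a")  # shape (batch, m, k)
b = pt.dtensor3("b")  # shape (batch, k, n)

# Build the computation for a single pair of matrices...
core_a = pt.dmatrix("core_a")
core_b = pt.dmatrix("core_b")
core_out = pt.dot(core_a, core_b)

# ...then vectorize it over the leading batch dimension.
batched_out = vectorize_graph(core_out, replace={core_a: a, core_b: b})

f = pytensor.function([a, b], batched_out)
a_val = np.random.normal(size=(5, 2, 3))
b_val = np.random.normal(size=(5, 3, 4))
assert f(a_val, b_val).shape == (5, 2, 4)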


def batched_tensordot(x, y, axes=2):
@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
reshapes to reduce the tensor dot product to a matrix or vector
dot product. Finally, it calls batched_dot to compute the result.
"""
from pytensor.tensor.math import _tensordot_as_dot
warnings.warn(
"batched_tensordot is deprecated. "
"Use `tensordot` in conjuction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
FutureWarning,
)

if isinstance(axes, int):
core_axes = axes
else:
# Convert batched axes to core axes
core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
core_axes = [core_axes_a, core_axes]

core_x = x[0].type()
core_y = y[0].type()
core_tensordot = tensordot(core_x, core_y, axes=core_axes)

return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
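
Similarly, a hedged sketch of the batched_tensordot replacement: shift the batched axes down by one so they refer to a single batch element, build the core `tensordot` graph, and vectorize it over the batch dimension. Shapes and names here are illustrative assumptions.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.dtensor3("x")  # shape (batch, i, j)
y = pt.dtensor3("y")  # shape (batch, j, k)

# Batched axes [[2], [1]] on the 3d inputs become core axes [[1], [0]]
# on the 2d operands of a single batch element.
core_x = pt.dmatrix("core_x")
core_y = pt.dmatrix("core_y")
core_out = pt.tensordot(core_x, core_y, axes=[[1], [0]])

out = vectorize_graph(core_out, replace={core_x: x, core_y: y})

f = pytensor.function([x, y], out)
x_val = np.random.normal(size=(5, 2, 3))
y_val = np.random.normal(size=(5, 3, 4))
assert f(x_val, y_val).shape == (5, 2, 4)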
129 changes: 1 addition & 128 deletions pytensor/tensor/math.py
@@ -50,7 +50,7 @@
tensor,
uint_dtypes,
)
from pytensor.tensor.utils import as_list, normalize_reduce_axis
from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import (
TensorVariable,
_tensor_py_operators,
@@ -1927,133 +1927,6 @@ def dense_dot(a, b):
return _dot(a, b)


def _tensordot_as_dot(a, b, axes, dot, batched):
"""
Reduces a tensor dot product to a matrix or vector dot product. Based
on code from Tijmen Tieleman's gnumpy
(http://www.cs.toronto.edu/~tijmen/gnumpy.html).

Please see the documentation of tensordot for the meaning of the a, b
and axes arguments.

:param dot: a function that accepts two symbolic variables and computes
the appropriate dot product (e.g. dot, batched_dot)
:type dot: function

:param batched: whether to treat the first axis of a and b as a batch
axis. If so, this axis will be preserved in the output,
allowing this function to be used also for batched
tensor dot products.
:type batched: boolean

:returns: a tensor with shape equal to the concatenation of a's shape
(less any dimensions that were summed over) and b's shape
(less the first dimension and any dimensions that were summed
over).
:rtype: symbolic tensor
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)

if not np.isscalar(axes) and len(axes) != 2:
raise ValueError(
"Axes should be an integer or a "
f"list/tuple of len 2 ({axes} was provided)"
)

# if 'axes' is a number of axes to multiply and sum over (trailing axes
# of a, leading axes of b), we can just reshape and use dot.
elif np.isscalar(axes):
axes = int(axes)

for operand_name, operand in (("a", a), ("b", b)):
if axes > operand.ndim:
raise ValueError(
f"axes can not be larger than the dimension of {operand_name} "
f"({operand_name}.ndim={operand.ndim}, axes={axes})"
)
if batched and axes == operand.ndim:
raise ValueError(
"axes to sum over must not include the batch axis "
f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
)

batch_axes = 1 if batched else 0
a_outaxes = slice(0, a.ndim - axes)
b_outaxes = slice(batch_axes + axes, b.ndim)
outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
outndim = len(outbcast)

a_shape = [1] * 2
b_shape = [1] * 2

# compute total size of summed axes
for i in range(0, axes):
a_shape[1] *= a.shape[-(i + 1)]
b_shape[0] *= b.shape[batch_axes + i]
# compute total size of other axes
for i in range(0, a.ndim - axes - batch_axes):
a_shape[0] *= a.shape[batch_axes + i]
for i in range(0, b.ndim - axes - batch_axes):
b_shape[1] *= b.shape[-(i + 1)]

if batched:
a_shape.insert(0, a.shape[0])
b_shape.insert(0, b.shape[0])

a_reshaped = a.reshape(a_shape)
b_reshaped = b.reshape(b_shape)

out_reshaped = dot(a_reshaped, b_reshaped)
out = out_reshaped.reshape(outshape, ndim=outndim)
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
if out.type.broadcastable != outbcast:
out = specify_broadcastable(
out, *(ax for (ax, b) in enumerate(outbcast) if b)
)
return out

# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
else:
axes = [as_list(axes_) for axes_ in axes]

if len(axes[0]) != len(axes[1]):
raise ValueError("Axes elements must have the same length.")

for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
if len(axes[i]) > operand.ndim:
raise ValueError(
f"axes[{i}] should be array_like with length less than "
f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
)
if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
raise ValueError(
f"axes[{i}] contains dimensions greater than or equal "
f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
)
if batched and 0 in axes[i]:
raise ValueError(
"axes to sum over must not contain the batch axis "
f"(axes[{i}]={axes[i]})"
)

batch_axes = [0] if batched else []
other_axes = [
[x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
for i, operand in enumerate((a, b))
]

a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])

# now that a and b are in the right order, recur with integer axes
return _tensordot_as_dot(
a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
)
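
For context on what is being deleted: the helper reduced a tensordot with contiguous summed axes (trailing in a, leading in b) to a single matrix dot by collapsing axes with reshapes. A small numpy sketch of that trick, with illustrative shapes:

import numpy as np

a = np.random.normal(size=(2, 3, 4, 5))  # sum over the last two axes of a
b = np.random.normal(size=(4, 5, 6))     # ... and the first two axes of b

# Collapse the kept axes and the summed axes of each operand into one
# axis each, take a plain matrix dot, then reshape back.
a2 = a.reshape(2 * 3, 4 * 5)
b2 = b.reshape(4 * 5, 6)
out = (a2 @ b2).reshape(2, 3, 6)

np.testing.assert_allclose(out, np.tensordot(a, b, axes=2))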


def tensordot(
a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
) -> TensorVariable:
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/blas.py
@@ -84,9 +84,9 @@
from pytensor.tensor import basic as ptb
from pytensor.tensor.blas import (
Dot22,
_batched_dot,
_dot22,
_dot22scalar,
batched_dot,
gemm_inplace,
gemm_no_inplace,
gemv_inplace,
@@ -928,7 +928,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
y = y.reshape((-1, y_shape[-2], y_shape[-1]))

new_out = batched_dot(x, y)
new_out = _batched_dot(x, y)

if len(x_shape) > 3:
# And then unravel it
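
The rewrite still calls the BatchedDot COp, now through the renamed `_batched_dot` helper. A rough numpy-level sketch of the surrounding shape handling for matmuls with more than one leading batch dimension: collapse the leading dims into a single batch axis, do the 3-D batched product, then restore the original leading dims (shapes illustrative):

import numpy as np

x = np.random.normal(size=(2, 5, 3, 4))  # (*batch, m, k)
y = np.random.normal(size=(2, 5, 4, 6))  # (*batch, k, n)

# Ravel the leading batch dimensions into one...
x3 = x.reshape((-1, x.shape[-2], x.shape[-1]))  # (10, 3, 4)
y3 = y.reshape((-1, y.shape[-2], y.shape[-1]))  # (10, 4, 6)

out3 = np.matmul(x3, y3)                        # 3-D batched dot

# ...and then unravel them in the result.
out = out3.reshape((*x.shape[:-2], x.shape[-2], y.shape[-1]))

np.testing.assert_allclose(out, np.matmul(x, y))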
8 changes: 0 additions & 8 deletions pytensor/tensor/utils.py
@@ -107,14 +107,6 @@ def shape_of_variables(
return l


def as_list(x):
"""Convert x to a list if it is an iterable; otherwise, wrap it in a list."""
try:
return list(x)
except TypeError:
return [x]


def import_func_from_string(func_string: str): # -> Optional[Callable]:
func = getattr(np, func_string, None)
if func is not None: