
Commit d450473

Deprecate BLAS batch helper functions
1 parent 65b96c1 · commit d450473

File tree: 5 files changed, +51 -199 lines


pytensor/tensor/blas.py

Lines changed: 36 additions & 25 deletions
@@ -79,9 +79,13 @@
 import logging
 import os
 import shlex
+import warnings
 from pathlib import Path

 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
+
+from pytensor.graph import vectorize_graph


 try:
@@ -100,9 +104,9 @@
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.shape import shape_padright, specify_broadcastable
+from pytensor.tensor.math import dot, tensordot
+from pytensor.tensor.shape import specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, tensor


@@ -1604,8 +1608,8 @@ def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads

-        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
     dot products in terms of batched matrix-matrix dot products, so
     it may be possible to further optimize for performance.
     """
+    warnings.warn(
+        "batched_dot is deprecated. "
+        "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
     a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)

     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
-    elif a.ndim == 1:
-        return shape_padright(a, (b.ndim - 1)) * b
-    elif b.ndim == 1:
-        return a * shape_padright(b, (a.ndim - 1))
-    elif a.ndim > 3 or b.ndim > 3:
-        return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
-    else:
-        # If either a or b is a batched vector, expand dims and later squeeze them
-        expanded_axis = []
-        if a.ndim == 2:
-            a = expand_dims(a, axis=1)
-            expanded_axis.append(1)
-        if b.ndim == 2:
-            b = expand_dims(b, axis=2)
-            expanded_axis.append(2)
-        out = _batched_dot(a, b)
-        if expanded_axis:
-            out = out.squeeze(axis=expanded_axis)
-        return out
+
+    core_a = a[0].type()
+    core_b = b[0].type()
+    core_dot = dot(core_a, core_b)
+    return vectorize_graph(core_dot, replace={core_a: a, core_b: b})


 def batched_tensordot(x, y, axes=2):
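
Note (not part of the diff): for downstream code that still calls batched_dot, the migration the new FutureWarning points at can be sketched as follows. Variable names and shapes are illustrative; the pattern simply mirrors the vectorize_graph call the deprecated helper now performs internally.

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Batched operands: the leading axis is the batch dimension.
a = pt.tensor3("a")  # (batch, m, k)
b = pt.tensor3("b")  # (batch, k, n)

# Build the core (unbatched) graph once...
core_a = pt.matrix("core_a")
core_b = pt.matrix("core_b")
core_out = pt.dot(core_a, core_b)

# ...then vectorize it over the batch axis instead of calling batched_dot(a, b).
out = vectorize_graph(core_out, replace={core_a: a, core_b: b})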
@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
     reshapes to reduce the tensor dot product to a matrix or vector
     dot product. Finally, it calls batched_dot to compute the result.
     """
-    from pytensor.tensor.math import _tensordot_as_dot
+    warnings.warn(
+        "batched_tensordot is deprecated. "
+        "Use `tensordot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
+
+    if isinstance(axes, int):
+        core_axes = axes
+    else:
+        # Convert batched axes to core axes
+        core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
+        core_axes_b = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
+        core_axes = [core_axes_a, core_axes_b]
+
+    core_x = x[0].type()
+    core_y = y[0].type()
+    core_tensordot = tensordot(core_x, core_y, axes=core_axes)

-    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
+    return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
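
Note (illustrative, not part of the diff): the axes bookkeeping above only shifts batched axes down by one so they refer to the unbatched core operands. Assuming hypothetical 4-D/3-D inputs, a user-level equivalent of batched_tensordot(x, y, axes=[[3], [1]]) would look like:

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.tensor4("x")  # (batch, i, j, k)
y = pt.tensor3("y")  # (batch, k, l)

# Dropping the batch axis turns the batched axes [[3], [1]] into core axes [[2], [0]].
core_x = pt.tensor3("core_x")  # (i, j, k)
core_y = pt.matrix("core_y")   # (k, l)
core_out = pt.tensordot(core_x, core_y, axes=[[2], [0]])

out = vectorize_graph(core_out, replace={core_x: x, core_y: y})  # (batch, i, j, l)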

pytensor/tensor/math.py

Lines changed: 1 addition & 128 deletions
@@ -46,7 +46,7 @@
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.utils import as_list, normalize_reduce_axis
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import (
     TensorVariable,
     _tensor_py_operators,
@@ -1919,133 +1919,6 @@ def dense_dot(a, b):
     return _dot(a, b)


-def _tensordot_as_dot(a, b, axes, dot, batched):
-    """
-    Reduces a tensor dot product to a matrix or vector dot product. Based
-    on code from Tijmen Tieleman's gnumpy
-    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
-
-    Please see the documentation of tensordot for the meaning of the a, b
-    and axes arguments.
-
-    :param dot: a function that accepts two symbolic variables and computes
-                the appropriate dot product (e.g. dot, batched_dot)
-    :type dot: function
-
-    :param batched: whether to treat the first axis of a and b as a batch
-                    axis. If so, this axis will be preserved in the output,
-                    allowing this function to be used also for batched
-                    tensor dot products.
-    :type batched: boolean
-
-    :returns: a tensor with shape equal to the concatenation of a's shape
-              (less any dimensions that were summed over) and b's shape
-              (less the first dimension and any dimensions that were summed
-              over).
-    :rtype: symbolic tensor
-    """
-    a, b = as_tensor_variable(a), as_tensor_variable(b)
-
-    if not np.isscalar(axes) and len(axes) != 2:
-        raise ValueError(
-            "Axes should be an integer or a "
-            f"list/tuple of len 2 ({axes} was provided)"
-        )
-
-    # if 'axes' is a number of axes to multiply and sum over (trailing axes
-    # of a, leading axes of b), we can just reshape and use dot.
-    elif np.isscalar(axes):
-        axes = int(axes)
-
-        for operand_name, operand in (("a", a), ("b", b)):
-            if axes > operand.ndim:
-                raise ValueError(
-                    f"axes can not be larger than the dimension of {operand_name} "
-                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-            if batched and axes == operand.ndim:
-                raise ValueError(
-                    "axes to sum over must not include the batch axis "
-                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-
-        batch_axes = 1 if batched else 0
-        a_outaxes = slice(0, a.ndim - axes)
-        b_outaxes = slice(batch_axes + axes, b.ndim)
-        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
-        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
-        outndim = len(outbcast)
-
-        a_shape = [1] * 2
-        b_shape = [1] * 2
-
-        # compute total size of summed axes
-        for i in range(0, axes):
-            a_shape[1] *= a.shape[-(i + 1)]
-            b_shape[0] *= b.shape[batch_axes + i]
-        # compute total size of other axes
-        for i in range(0, a.ndim - axes - batch_axes):
-            a_shape[0] *= a.shape[batch_axes + i]
-        for i in range(0, b.ndim - axes - batch_axes):
-            b_shape[1] *= b.shape[-(i + 1)]
-
-        if batched:
-            a_shape.insert(0, a.shape[0])
-            b_shape.insert(0, b.shape[0])
-
-        a_reshaped = a.reshape(a_shape)
-        b_reshaped = b.reshape(b_shape)
-
-        out_reshaped = dot(a_reshaped, b_reshaped)
-        out = out_reshaped.reshape(outshape, ndim=outndim)
-        # Make sure the broadcastable pattern of the result is correct,
-        # since some shape information can be lost in the reshapes.
-        if out.type.broadcastable != outbcast:
-            out = specify_broadcastable(
-                out, *(ax for (ax, b) in enumerate(outbcast) if b)
-            )
-        return out
-
-    # if 'axes' is a list, transpose a and b such that the summed axes of a
-    # are last and the summed axes of b are first.
-    else:
-        axes = [as_list(axes_) for axes_ in axes]
-
-        if len(axes[0]) != len(axes[1]):
-            raise ValueError("Axes elements must have the same length.")
-
-        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
-            if len(axes[i]) > operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] should be array_like with length less than "
-                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
-                )
-            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] contains dimensions greater than or equal "
-                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
-                )
-            if batched and 0 in axes[i]:
-                raise ValueError(
-                    "axes to sum over must not contain the batch axis "
-                    f"(axes[{i}]={axes[i]})"
-                )
-
-        batch_axes = [0] if batched else []
-        other_axes = [
-            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
-            for i, operand in enumerate((a, b))
-        ]
-
-        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
-        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
-
-        # now that a and b are in the right order, recur with integer axes
-        return _tensordot_as_dot(
-            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
-        )
-
-
 def tensordot(
     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
 ) -> TensorVariable:
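
Aside (not part of the diff): the removed helper's core idea was that a tensordot over the trailing axes of a and the leading axes of b reduces to a single matrix product between reshaped operands. A small NumPy check of that equivalence, with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4, 5))
b = rng.normal(size=(4, 5, 6))

# Sum over the last two axes of a and the first two axes of b (axes=2).
a2 = a.reshape(2 * 3, 4 * 5)      # (remaining axes of a, summed axes)
b2 = b.reshape(4 * 5, 6)          # (summed axes, remaining axes of b)
out = (a2 @ b2).reshape(2, 3, 6)  # concatenation of the remaining shapes

np.testing.assert_allclose(out, np.tensordot(a, b, axes=2))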

pytensor/tensor/rewriting/blas.py

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.blas import (
     Dot22,
+    _batched_dot,
     _dot22,
     _dot22scalar,
-    batched_dot,
     gemm_inplace,
     gemm_no_inplace,
     gemv_inplace,
@@ -926,7 +926,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))

-    new_out = batched_dot(x, y)
+    new_out = _batched_dot(x, y)

     if len(x_shape) > 3:
         # And then unravel it
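
Aside (illustrative only): the context lines above collapse any extra leading batch dimensions into a single axis before the 3-D batched dot, then restore them afterwards. A NumPy sketch of that ravel/unravel round trip, with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 5))  # two leading batch dimensions
y = rng.normal(size=(2, 3, 5, 6))

# Collapse the leading batch dims to one axis and do the 3-D batched product...
x3 = x.reshape(-1, 4, 5)
y3 = y.reshape(-1, 5, 6)
out = x3 @ y3

# ...and then unravel it back to the original batch shape.
out = out.reshape(2, 3, 4, 6)

np.testing.assert_allclose(out, x @ y)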

pytensor/tensor/utils.py

Lines changed: 0 additions & 8 deletions
@@ -107,14 +107,6 @@ def shape_of_variables(
     return l


-def as_list(x):
-    """Convert x to a list if it is an iterable; otherwise, wrap it in a list."""
-    try:
-        return list(x)
-    except TypeError:
-        return [x]
-
-
 def import_func_from_string(func_string: str):  # -> Optional[Callable]:
     func = getattr(np, func_string, None)
     if func is not None:
