Skip to content

Commit e3fb498

Browse files
lucianopazricardoV94
authored and committed
Fix tensordot implementation
1 parent f799219 commit e3fb498

File tree

2 files changed

+214
-47
lines changed

2 files changed

+214
-47
lines changed

pytensor/tensor/math.py

Lines changed: 121 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import builtins
22
import warnings
3-
from typing import TYPE_CHECKING, Optional
3+
from collections.abc import Sequence
4+
from typing import TYPE_CHECKING, Optional, Union
45

56
import numpy as np
7+
from numpy.core.numeric import normalize_axis_tuple
68

79
from pytensor import config, printing
810
from pytensor import scalar as ps
@@ -15,7 +17,9 @@
1517
from pytensor.link.c.type import Generic
1618
from pytensor.misc.safe_asarray import _asarray
1719
from pytensor.printing import pprint
20+
from pytensor.raise_op import Assert
1821
from pytensor.scalar.basic import BinaryScalarOp
22+
from pytensor.tensor import TensorLike
1923
from pytensor.tensor.basic import (
2024
alloc,
2125
arange,
@@ -47,7 +51,11 @@
4751
)
4852
from pytensor.tensor.type_other import NoneConst
4953
from pytensor.tensor.utils import as_list
50-
from pytensor.tensor.variable import TensorConstant, _tensor_py_operators
54+
from pytensor.tensor.variable import (
55+
TensorConstant,
56+
TensorVariable,
57+
_tensor_py_operators,
58+
)
5159

5260

5361
if TYPE_CHECKING:
@@ -2266,57 +2274,47 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
22662274
)
22672275

22682276

2269-
def tensordot(a, b, axes=2):
2277+
def tensordot(
2278+
a: TensorLike, b: TensorLike, axes: Union[int, Sequence[Sequence[int]]] = 2
2279+
) -> TensorVariable:
22702280
"""
2271-
Compute a generalized dot product over provided axes.
2281+
Compute tensor dot product along specified axes.
2282+
2283+
Implementation is mostly taken from numpy version 1.26.0
22722284
2273-
Given two tensors a and b, tensordot computes a generalized dot product over
2274-
the provided axes. PyTensor's implementation reduces all expressions to
2275-
matrix or vector dot products and is based on code from Tijmen Tieleman's
2276-
gnumpy (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
2285+
Given two tensors, `a` and `b`, and a sequence object containing
2286+
two sequence objects, ``(a_axes, b_axes)``, sum the products of
2287+
`a`'s and `b`'s elements (components) over the axes specified by
2288+
``a_axes`` and ``b_axes``. The third argument can be a single non-negative
2289+
integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions
2290+
of `a` and the first ``N`` dimensions of `b` are summed over.
22772291
22782292
Parameters
22792293
----------
2280-
a: symbolic tensor
2281-
The first tensor variable.
2282-
b: symbolic tensor
2283-
The second tensor variable
2284-
axes: int or array-like of length 2
2285-
If an integer, the number of axes to sum over.
2286-
If an array, it must have two array elements containing the axes
2287-
to sum over in each tensor.
2288-
2289-
Note that the default value of 2 is not guaranteed to work
2290-
for all values of a and b, and an error will be raised if
2291-
that is the case. The reason for keeping the default is to
2292-
maintain the same signature as numpy's tensordot function
2293-
(and np.tensordot raises analogous errors for non-compatible
2294-
inputs).
2295-
2296-
If an integer i, it is converted to an array containing
2297-
the last i dimensions of the first tensor and the first
2298-
i dimensions of the second tensor:
2299-
axes = [list(range(a.ndim - i, b.ndim)), list(range(i))]
2300-
2301-
If an array, its two elements must contain compatible axes
2302-
of the two tensors. For example, [[1, 2], [2, 0]] means sum
2303-
over the 2nd and 3rd axes of a and the 3rd and 1st axes of b.
2304-
(Remember axes are zero-indexed!) The 2nd axis of a and the
2305-
3rd axis of b must have the same shape; the same is true for
2306-
the 3rd axis of a and the 1st axis of b.
2294+
a, b : tensor_like
2295+
Tensors to "dot".
2296+
2297+
axes : int or (2,) array_like
2298+
* integer_like
2299+
If an int N, sum over the last N axes of `a` and the first N axes
2300+
of `b` in order. The sizes of the corresponding axes must match.
2301+
* (2,) array_like
2302+
Or, a list of axes to be summed over, first sequence applying to `a`,
2303+
second to `b`. Both elements array_like must be of the same length.
23072304
23082305
Returns
23092306
-------
2310-
symbolic tensor
2311-
A tensor with shape equal to the concatenation of a's shape
2312-
(less any dimensions that were summed over) and b's shape
2313-
(less any dimensions that were summed over).
2307+
output : TensorVariable
2308+
The tensor dot product of the input.
2309+
Its shape will be equal to the concatenation of `a` and `b` shapes
2310+
(ignoring the dimensions that were summed over given in ``a_axes``
2311+
and ``b_axes``)
23142312
23152313
Examples
23162314
--------
23172315
It may be helpful to consider an example to see what tensordot does.
2318-
PyTensor's implementation is identical to NumPy's. Here a has shape (2, 3, 4)
2319-
and b has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
2316+
PyTensor's implementation is identical to NumPy's. Here ``a`` has shape (2, 3, 4)
2317+
and ``b`` has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
23202318
note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
23212319
are compatible. The resulting tensor will have shape (2, 5, 6) -- the
23222320
dimensions that are not being summed:
@@ -2347,10 +2345,9 @@ def tensordot(a, b, axes=2):
23472345
true
23482346
23492347
This specific implementation avoids a loop by transposing a and b such that
2350-
the summed axes of a are last and the summed axes of b are first. The
2351-
resulting arrays are reshaped to 2 dimensions (or left as vectors, if
2352-
appropriate) and a matrix or vector dot product is taken. The result is
2353-
reshaped back to the required output dimensions.
2348+
the summed axes of ``a`` are last and the summed axes of ``b`` are first. The
2349+
resulting arrays are reshaped to 2 dimensions and a matrix dot product is taken.
2350+
The result is reshaped back to the required output dimensions.
23542351
23552352
In an extreme case, no axes may be specified. The resulting tensor
23562353
will have shape equal to the concatenation of the shapes of a and b:
@@ -2366,7 +2363,85 @@ def tensordot(a, b, axes=2):
23662363
See the documentation of numpy.tensordot for more examples.
23672364
23682365
"""
2369-
return _tensordot_as_dot(a, b, axes, dot=dot, batched=False)
2366+
try:
2367+
iter(axes)
2368+
except Exception:
2369+
axes_a = list(range(-axes, 0))
2370+
axes_b = list(range(0, axes))
2371+
else:
2372+
axes_a, axes_b = axes
2373+
try:
2374+
na = len(axes_a)
2375+
axes_a = list(axes_a)
2376+
except TypeError:
2377+
axes_a = [axes_a]
2378+
na = 1
2379+
try:
2380+
nb = len(axes_b)
2381+
axes_b = list(axes_b)
2382+
except TypeError:
2383+
axes_b = [axes_b]
2384+
nb = 1
2385+
2386+
a = as_tensor_variable(a)
2387+
b = as_tensor_variable(b)
2388+
runtime_shape_a = a.shape
2389+
bcast_a = a.broadcastable
2390+
static_shape_a = a.type.shape
2391+
ndim_a = a.ndim
2392+
runtime_shape_b = b.shape
2393+
bcast_b = b.broadcastable
2394+
static_shape_b = b.type.shape
2395+
ndim_b = b.ndim
2396+
if na != nb:
2397+
raise ValueError(
2398+
"The number of axes supplied for tensordot must be equal for each tensor. "
2399+
f"Got {na} and {nb} respectively."
2400+
)
2401+
axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
2402+
axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
2403+
must_assert_runtime = False
2404+
for k in range(na):
2405+
ax_a = axes_a[k]
2406+
ax_b = axes_b[k]
2407+
if (bcast_a[ax_a] != bcast_b[ax_b]) or (
2408+
static_shape_a[ax_a] is not None
2409+
and static_shape_b[ax_b] is not None
2410+
and static_shape_a[ax_a] != static_shape_b[ax_b]
2411+
):
2412+
raise ValueError(
2413+
"Input arrays have inconsistent broadcastable pattern or type shape along the axes "
2414+
"that are to be reduced with tensordot."
2415+
)
2416+
elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
2417+
if must_assert_runtime:
2418+
a = Assert(
2419+
"Input array shape along reduced axes of tensordot are not equal"
2420+
)(a, eq(a.shape[ax_a], b.shape[ax_b]))
2421+
must_assert_runtime = True
2422+
2423+
# Move the axes to sum over to the end of "a"
2424+
# and to the front of "b"
2425+
notin = [k for k in range(ndim_a) if k not in axes_a]
2426+
newaxes_a = notin + axes_a
2427+
N2 = 1
2428+
for axis in axes_a:
2429+
N2 *= runtime_shape_a[axis]
2430+
newshape_a = (-1, N2)
2431+
olda = [runtime_shape_a[axis] for axis in notin]
2432+
2433+
notin = [k for k in range(ndim_b) if k not in axes_b]
2434+
newaxes_b = axes_b + notin
2435+
N2 = 1
2436+
for axis in axes_b:
2437+
N2 *= runtime_shape_b[axis]
2438+
newshape_b = (N2, -1)
2439+
oldb = [runtime_shape_b[axis] for axis in notin]
2440+
2441+
at = a.transpose(newaxes_a).reshape(newshape_a)
2442+
bt = b.transpose(newaxes_b).reshape(newshape_b)
2443+
res = _dot(at, bt)
2444+
return res.reshape(olda + oldb)
23702445

23712446

23722447
def outer(x, y):

tests/tensor/test_math.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@
1818
from pytensor.compile.sharedvalue import shared
1919
from pytensor.configdefaults import config
2020
from pytensor.gradient import NullTypeGradError, grad, numeric_grad
21-
from pytensor.graph.basic import Variable, applys_between
21+
from pytensor.graph.basic import Variable, ancestors, applys_between
2222
from pytensor.graph.fg import FunctionGraph
2323
from pytensor.graph.replace import vectorize_node
2424
from pytensor.link.c.basic import DualLinker
2525
from pytensor.misc.safe_asarray import _asarray
2626
from pytensor.printing import pprint
27+
from pytensor.raise_op import Assert
2728
from pytensor.tensor import blas, blas_c
2829
from pytensor.tensor.basic import (
2930
as_tensor_variable,
3031
constant,
3132
eye,
3233
get_underlying_scalar_constant_value,
34+
ones,
3335
switch,
3436
)
3537
from pytensor.tensor.blas import Dot22
@@ -2208,6 +2210,96 @@ def test_broadcastable2(self):
22082210
zv = f(xv, yv)
22092211
assert np.allclose(np.tensordot(xv, yv, axes=axes), zv)
22102212

2213+
def test_type_shape(self):
2214+
x = ones(shape=(7, 3, 2))
2215+
y = ones(
2216+
shape=(
2217+
10,
2218+
2,
2219+
)
2220+
)
2221+
xv = x.eval()
2222+
yv = y.eval()
2223+
sy = tensor("sy", shape=(None, 2))
2224+
axes = [[-1], [-1]]
2225+
z = tensordot(x, y, axes=axes)
2226+
sz = tensordot(x, sy, axes=axes)
2227+
2228+
assert (
2229+
len(
2230+
{
2231+
node
2232+
for node in ancestors([z])
2233+
if node.owner and isinstance(node.owner.op, Assert)
2234+
}
2235+
)
2236+
== 0
2237+
)
2238+
assert z.type.shape == (7, 3, 10)
2239+
assert z.broadcastable == (False, False, False)
2240+
assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval())
2241+
2242+
assert (
2243+
len(
2244+
{
2245+
node
2246+
for node in ancestors([sz])
2247+
if node.owner and isinstance(node.owner.op, Assert)
2248+
}
2249+
)
2250+
== 0
2251+
)
2252+
assert sz.type.shape == (7, 3, None)
2253+
assert z.broadcastable == (False, False, False)
2254+
assert np.allclose(np.tensordot(xv, yv, axes=axes), sz.eval({sy: yv}))
2255+
2256+
with pytest.raises(
2257+
ValueError,
2258+
match="Input arrays have inconsistent broadcastable pattern or type shape",
2259+
):
2260+
tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
2261+
2262+
@pytest.mark.parametrize(
2263+
["axes", "has_assert", "values", "expected_fail"],
2264+
[
2265+
([[1], [2]], False, (np.ones((7, 3, 2)), np.ones((7, 2, 3))), False),
2266+
([[0, 2], [0, 1]], True, (np.ones((7, 3, 2)), np.ones((7, 2, 3))), False),
2267+
([[0], [0]], False, (np.ones((7, 3, 1)), np.ones((100, 1, 3))), True),
2268+
([[1, 2], [1, 2]], True, (np.ones((7, 3, 2)), np.ones((7, 2, 3))), True),
2269+
],
2270+
)
2271+
def test_shape_assert(self, axes, has_assert, values, expected_fail):
2272+
x = tensor(shape=(7, 3, None))
2273+
y = tensor(shape=(None, None, 3))
2274+
2275+
xv, yv = values
2276+
xv = xv.astype(x.dtype)
2277+
yv = yv.astype(x.dtype)
2278+
2279+
z = tensordot(x, y, axes=axes)
2280+
2281+
found_asserts = {
2282+
node
2283+
for node in ancestors([z])
2284+
if node.owner and isinstance(node.owner.op, Assert)
2285+
}
2286+
if has_assert:
2287+
assert found_asserts
2288+
else:
2289+
assert not found_asserts
2290+
if expected_fail:
2291+
if has_assert:
2292+
with pytest.raises(
2293+
AssertionError,
2294+
match="Input array shape along reduced axes of tensordot are not equal",
2295+
):
2296+
z.eval({x: xv, y: yv})
2297+
else:
2298+
with pytest.raises(ValueError):
2299+
z.eval({x: xv, y: yv})
2300+
else:
2301+
assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
2302+
22112303

22122304
def test_smallest():
22132305
x = dvector()

0 commit comments

Comments (0)