1
1
import builtins
2
2
import warnings
3
- from typing import TYPE_CHECKING , Optional
3
+ from collections .abc import Sequence
4
+ from typing import TYPE_CHECKING , Optional , Union
4
5
5
6
import numpy as np
7
+ from numpy .core .numeric import normalize_axis_tuple
6
8
7
9
from pytensor import config , printing
8
10
from pytensor import scalar as ps
15
17
from pytensor .link .c .type import Generic
16
18
from pytensor .misc .safe_asarray import _asarray
17
19
from pytensor .printing import pprint
20
+ from pytensor .raise_op import Assert
18
21
from pytensor .scalar .basic import BinaryScalarOp
22
+ from pytensor .tensor import TensorLike
19
23
from pytensor .tensor .basic import (
20
24
alloc ,
21
25
arange ,
47
51
)
48
52
from pytensor .tensor .type_other import NoneConst
49
53
from pytensor .tensor .utils import as_list
50
- from pytensor .tensor .variable import TensorConstant , _tensor_py_operators
54
+ from pytensor .tensor .variable import (
55
+ TensorConstant ,
56
+ TensorVariable ,
57
+ _tensor_py_operators ,
58
+ )
51
59
52
60
53
61
if TYPE_CHECKING :
@@ -2266,57 +2274,47 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
2266
2274
)
2267
2275
2268
2276
2269
def tensordot(
    a: TensorLike, b: TensorLike, axes: Union[int, Sequence[Sequence[int]]] = 2
) -> TensorVariable:
    """
    Compute tensor dot product along specified axes.

    Implementation is mostly taken from numpy version 1.26.0.

    Given two tensors, `a` and `b`, and a sequence object containing
    two sequence objects, ``(a_axes, b_axes)``, sum the products of
    `a`'s and `b`'s elements (components) over the axes specified by
    ``a_axes`` and ``b_axes``. The third argument can be a single non-negative
    integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions
    of `a` and the first ``N`` dimensions of `b` are summed over.

    Parameters
    ----------
    a, b : tensor_like
        Tensors to "dot".

    axes : int or (2,) array_like
        * integer_like
          If an int N, sum over the last N axes of `a` and the first N axes
          of `b` in order. The sizes of the corresponding axes must match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to `a`,
          second to `b`. Both elements array_like must be of the same length.

    Returns
    -------
    output : TensorVariable
        The tensor dot product of the input.
        Its shape will be equal to the concatenation of `a` and `b` shapes
        (ignoring the dimensions that were summed over given in ``a_axes``
        and ``b_axes``)

    Examples
    --------
    It may be helpful to consider an example to see what tensordot does.
    PyTensor's implementation is identical to NumPy's. Here ``a`` has shape (2, 3, 4)
    and ``b`` has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
    note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
    are compatible. The resulting tensor will have shape (2, 5, 6) -- the
    dimensions that are not being summed.

    This specific implementation avoids a loop by transposing a and b such that
    the summed axes of ``a`` are last and the summed axes of ``b`` are first. The
    resulting arrays are reshaped to 2 dimensions and a matrix dot product is taken.
    The result is reshaped back to the required output dimensions.

    In an extreme case, no axes may be specified. The resulting tensor
    will have shape equal to the concatenation of the shapes of a and b.

    See the documentation of numpy.tensordot for more examples.

    """
    # Normalize `axes` into two explicit axis lists, mirroring np.tensordot:
    # an int N means "last N axes of `a` with first N axes of `b`".
    try:
        iter(axes)
    except TypeError:
        # `axes` is a scalar count, not a pair of sequences.
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        # A bare int axis is promoted to a one-element list.
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    runtime_shape_a = a.shape
    bcast_a = a.broadcastable
    static_shape_a = a.type.shape
    ndim_a = a.ndim
    runtime_shape_b = b.shape
    bcast_b = b.broadcastable
    static_shape_b = b.type.shape
    ndim_b = b.ndim

    if na != nb:
        raise ValueError(
            "The number of axes supplied for tensordot must be equal for each tensor. "
            f"Got {na} and {nb} respectively."
        )
    axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
    axes_b = list(normalize_axis_tuple(axes_b, ndim_b))

    # Validate each paired reduction axis. Pairs whose static shapes (or
    # broadcastable flags) are known and incompatible fail eagerly; pairs
    # with an unknown static shape get a runtime Assert instead.
    for ax_a, ax_b in zip(axes_a, axes_b):
        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
            static_shape_a[ax_a] is not None
            and static_shape_b[ax_b] is not None
            and static_shape_a[ax_a] != static_shape_b[ax_b]
        ):
            raise ValueError(
                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
                "that are to be reduced with tensordot."
            )
        elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
            # BUG FIX: the previous code guarded this Assert with an
            # `if must_assert_runtime:` flag that started False and was only
            # set True *after* the guard, so the first unknown-shape pair —
            # and therefore the common single-pair case — was never checked
            # at runtime. Every unknown-shape pair is now asserted.
            a = Assert(
                "Input array shape along reduced axes of tensordot are not equal"
            )(a, eq(a.shape[ax_a], b.shape[ax_b]))

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(ndim_a) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= runtime_shape_a[axis]
    newshape_a = (-1, N2)
    olda = [runtime_shape_a[axis] for axis in notin]

    notin = [k for k in range(ndim_b) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= runtime_shape_b[axis]
    newshape_b = (N2, -1)
    oldb = [runtime_shape_b[axis] for axis in notin]

    # Collapse everything to a single matrix-matrix product, then restore
    # the non-reduced dimensions of `a` followed by those of `b`.
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = _dot(at, bt)
    return res.reshape(olda + oldb)
2370
2445
2371
2446
2372
2447
def outer (x , y ):
0 commit comments