  46 |   46 |     tensor,
  47 |   47 |     uint_dtypes,
  48 |   48 | )
  49 |      | -from pytensor.tensor.utils import as_list, normalize_reduce_axis
     |   49 | +from pytensor.tensor.utils import normalize_reduce_axis
  50 |   50 | from pytensor.tensor.variable import (
  51 |   51 |     TensorVariable,
  52 |   52 |     _tensor_py_operators,

@@ -1919,133 +1919,6 @@ def dense_dot(a, b):
1919 | 1919 |     return _dot(a, b)
1920 | 1920 |
1921 | 1921 |
1922 |      | -def _tensordot_as_dot(a, b, axes, dot, batched):
1923 |      | -    """
1924 |      | -    Reduces a tensor dot product to a matrix or vector dot product. Based
1925 |      | -    on code from Tijmen Tieleman's gnumpy
1926 |      | -    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
1927 |      | -
1928 |      | -    Please see the documentation of tensordot for the meaning of the a, b
1929 |      | -    and axes arguments.
1930 |      | -
1931 |      | -    :param dot: a function that accepts two symbolic variables and computes
1932 |      | -                the appropriate dot product (e.g. dot, batched_dot)
1933 |      | -    :type dot: function
1934 |      | -
1935 |      | -    :param batched: whether to treat the first axis of a and b as a batch
1936 |      | -                    axis. If so, this axis will be preserved in the output,
1937 |      | -                    allowing this function to be used also for batched
1938 |      | -                    tensor dot products.
1939 |      | -    :type batched: boolean
1940 |      | -
1941 |      | -    :returns: a tensor with shape equal to the concatenation of a's shape
1942 |      | -              (less any dimensions that were summed over) and b's shape
1943 |      | -              (less the first dimension and any dimensions that were summed
1944 |      | -              over).
1945 |      | -    :rtype: symbolic tensor
1946 |      | -    """
1947 |      | -    a, b = as_tensor_variable(a), as_tensor_variable(b)
1948 |      | -
1949 |      | -    if not np.isscalar(axes) and len(axes) != 2:
1950 |      | -        raise ValueError(
1951 |      | -            "Axes should be an integer or a "
1952 |      | -            f"list/tuple of len 2 ({axes} was provided)"
1953 |      | -        )
1954 |      | -
1955 |      | -    # if 'axes' is a number of axes to multiply and sum over (trailing axes
1956 |      | -    # of a, leading axes of b), we can just reshape and use dot.
1957 |      | -    elif np.isscalar(axes):
1958 |      | -        axes = int(axes)
1959 |      | -
1960 |      | -        for operand_name, operand in (("a", a), ("b", b)):
1961 |      | -            if axes > operand.ndim:
1962 |      | -                raise ValueError(
1963 |      | -                    f"axes can not be larger than the dimension of {operand_name} "
1964 |      | -                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
1965 |      | -                )
1966 |      | -            if batched and axes == operand.ndim:
1967 |      | -                raise ValueError(
1968 |      | -                    "axes to sum over must not include the batch axis "
1969 |      | -                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
1970 |      | -                )
1971 |      | -
1972 |      | -        batch_axes = 1 if batched else 0
1973 |      | -        a_outaxes = slice(0, a.ndim - axes)
1974 |      | -        b_outaxes = slice(batch_axes + axes, b.ndim)
1975 |      | -        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
1976 |      | -        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
1977 |      | -        outndim = len(outbcast)
1978 |      | -
1979 |      | -        a_shape = [1] * 2
1980 |      | -        b_shape = [1] * 2
1981 |      | -
1982 |      | -        # compute total size of summed axes
1983 |      | -        for i in range(0, axes):
1984 |      | -            a_shape[1] *= a.shape[-(i + 1)]
1985 |      | -            b_shape[0] *= b.shape[batch_axes + i]
1986 |      | -        # compute total size of other axes
1987 |      | -        for i in range(0, a.ndim - axes - batch_axes):
1988 |      | -            a_shape[0] *= a.shape[batch_axes + i]
1989 |      | -        for i in range(0, b.ndim - axes - batch_axes):
1990 |      | -            b_shape[1] *= b.shape[-(i + 1)]
1991 |      | -
1992 |      | -        if batched:
1993 |      | -            a_shape.insert(0, a.shape[0])
1994 |      | -            b_shape.insert(0, b.shape[0])
1995 |      | -
1996 |      | -        a_reshaped = a.reshape(a_shape)
1997 |      | -        b_reshaped = b.reshape(b_shape)
1998 |      | -
1999 |      | -        out_reshaped = dot(a_reshaped, b_reshaped)
2000 |      | -        out = out_reshaped.reshape(outshape, ndim=outndim)
2001 |      | -        # Make sure the broadcastable pattern of the result is correct,
2002 |      | -        # since some shape information can be lost in the reshapes.
2003 |      | -        if out.type.broadcastable != outbcast:
2004 |      | -            out = specify_broadcastable(
2005 |      | -                out, *(ax for (ax, b) in enumerate(outbcast) if b)
2006 |      | -            )
2007 |      | -        return out
2008 |      | -
2009 |      | -    # if 'axes' is a list, transpose a and b such that the summed axes of a
2010 |      | -    # are last and the summed axes of b are first.
2011 |      | -    else:
2012 |      | -        axes = [as_list(axes_) for axes_ in axes]
2013 |      | -
2014 |      | -        if len(axes[0]) != len(axes[1]):
2015 |      | -            raise ValueError("Axes elements must have the same length.")
2016 |      | -
2017 |      | -        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
2018 |      | -            if len(axes[i]) > operand.ndim:
2019 |      | -                raise ValueError(
2020 |      | -                    f"axes[{i}] should be array_like with length less than "
2021 |      | -                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
2022 |      | -                )
2023 |      | -            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
2024 |      | -                raise ValueError(
2025 |      | -                    f"axes[{i}] contains dimensions greater than or equal "
2026 |      | -                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
2027 |      | -                )
2028 |      | -            if batched and 0 in axes[i]:
2029 |      | -                raise ValueError(
2030 |      | -                    "axes to sum over must not contain the batch axis "
2031 |      | -                    f"(axes[{i}]={axes[i]})"
2032 |      | -                )
2033 |      | -
2034 |      | -        batch_axes = [0] if batched else []
2035 |      | -        other_axes = [
2036 |      | -            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
2037 |      | -            for i, operand in enumerate((a, b))
2038 |      | -        ]
2039 |      | -
2040 |      | -        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
2041 |      | -        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
2042 |      | -
2043 |      | -        # now that a and b are in the right order, recur with integer axes
2044 |      | -        return _tensordot_as_dot(
2045 |      | -            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
2046 |      | -        )
2047 |      | -
2048 |      | -
2049 | 1922 | def tensordot(
2050 | 1923 |     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
2051 | 1924 |     ) -> TensorVariable: