From 6252a44ae0a811e4eaf070897634d0de2522b8e2 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 12 Dec 2022 01:35:31 -0800 Subject: [PATCH] Add complex number support to `tensordot` --- .../array_api/linear_algebra_functions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/spec/API_specification/array_api/linear_algebra_functions.py b/spec/API_specification/array_api/linear_algebra_functions.py index f8f15e5b0..35367b3e7 100644 --- a/spec/API_specification/array_api/linear_algebra_functions.py +++ b/spec/API_specification/array_api/linear_algebra_functions.py @@ -56,12 +56,15 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], """ Returns a tensor contraction of ``x1`` and ``x2`` over specific axes. + .. note:: + The ``tensordot`` function corresponds to the generalized matrix product. + Parameters ---------- x1: array - first input array. Should have a real-valued data type. + first input array. Should have a numeric data type. x2: array - second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal. + second input array. Should have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal. .. note:: Contracted axes (dimensions) must not be broadcasted. @@ -77,6 +80,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each sequence must consist of unique (nonnegative) integers that specify valid axes for each respective array. + + .. note:: + If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the generalized matrix product. + Returns ------- out: array