Skip to content

Commit 0768fc8

Browse files
authored
Add specification for computing a tensor contraction (linalg: tensordot) (#136)
* Add tensordot specification * Update data type requirements * Update dtype requirements * Fix missing header
1 parent 36fd440 commit 0768fc8

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

spec/API_specification/linear_algebra_functions.md

+33
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,39 @@ TODO
332332

333333
TODO
334334

335+
(function-tensordot)=
336+
### tensordot(x1, x2, /, *, axes=2)
337+
338+
Returns a tensor contraction of `x1` and `x2` over specific axes.
339+
340+
#### Parameters
341+
342+
- **x1**: _<array>_
343+
344+
- first input array. Should have a numeric data type.
345+
346+
- **x2**: _<array>_
347+
348+
- second input array. Must be compatible with `x1` (see {ref}`broadcasting`). Should have a numeric data type.
349+
350+
- **axes**: _Union\[ int, Tuple\[ Sequence\[ int ], Sequence\[ int ] ] ]_
351+
352+
- number of axes (dimensions) to contract or explicit sequences of axes (dimensions) for `x1` and `x2`, respectively.
353+
354+
If `axes` is an `int` equal to `N`, then contraction must be performed over the last `N` axes of `x1` and the first `N` axes of `x2` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative.
355+
356+
- If `N` equals `0`, the result is the tensor (outer) product.
357+
- If `N` equals `1`, the result is the tensor dot product.
358+
- If `N` equals `2`, the result is the tensor double contraction (default).
359+
360+
If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the first sequence must apply to `x` and the second sequence to `x2`. Both sequences must have the same length. Each axis (dimension) `x1_axes[i]` for `x1` must have the same size as the respective axis (dimension) `x2_axes[i]` for `x2`. Each sequence must consist of unique (nonnegative) integers that specify valid axes for each respective array.
361+
362+
#### Returns
363+
364+
- **out**: _<array>_
365+
366+
- an array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array `x1`, followed by the non-contracted axes (dimensions) of the second array `x2`. The returned array must have a data type determined by {ref}`type-promotion`.
367+
335368
(function-trace)=
336369
### trace(x, /, *, axis1=0, axis2=1, offset=0)
337370

0 commit comments

Comments
 (0)