Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9b3e5ec

Browse files
committedMar 23, 2023
BUG: fix up matmul
1 parent 25d7efb commit 9b3e5ec

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed
 

‎torch_np/_binary_ufuncs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def matmul(
6868
extobj=None,
6969
axes=None,
7070
axis=None,
71-
) -> OutArray:
71+
):
7272
tensors = _helpers.ufunc_preprocess(
7373
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
7474
)
@@ -77,7 +77,7 @@ def matmul(
7777

7878
# NB: do not broadcast input tensors against the out=... array
7979
result = _binary_ufuncs.matmul(*tensors)
80-
return result, out
80+
return _helpers.result_or_out(result, out)
8181

8282

8383
#

‎torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5979,7 +5979,7 @@ def test_out_contiguous(self):
59795979
# test out non-contiguous
59805980
out = np.ones((5, 2, 2), dtype=float)
59815981
c = self.matmul(a, b, out=out[..., 0])
5982-
assert c._tensor._base is out._tensor # FIXME: self.tensor (no underscore)
5982+
assert c.tensor._base is out.tensor
59835983
assert_array_equal(c, tgt)
59845984
c = self.matmul(a, v, out=out[:, 0, 0])
59855985
assert_array_equal(c, tgt_mv)

0 commit comments

Comments
 (0)
Please sign in to comment.