Skip to content

Commit 25d7efb

Browse files
committed
ENH: add ndarray.__matmul__ (also __imatmul__)
Note that NumPy does not support in-place __imatmul__, but PyTorch does. So do we then.
1 parent 53b865a commit 25d7efb

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

torch_np/_ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,14 @@ def __ilshift__(self, other):
313313
def __irshift__(self, other):
314314
return _binary_ufuncs.right_shift(self, other, out=self)
315315

316+
__matmul__ = _binary_ufuncs.matmul
317+
318+
def __rmatmul__(self, other):
319+
return _binary_ufuncs.matmul(other, self)
320+
321+
def __imatmul__(self, other):
322+
return _binary_ufuncs.matmul(self, other, out=self)
323+
316324
# unary ops
317325
__invert__ = _unary_ufuncs.invert
318326
__abs__ = _unary_ufuncs.absolute

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,7 +2605,6 @@ def test_diagonal_memleak(self):
26052605
if HAS_REFCOUNT:
26062606
assert_(sys.getrefcount(a) < 50)
26072607

2608-
@pytest.mark.xfail(reason="TODO: implement np.dot")
26092608
def test_size_zero_memleak(self):
26102609
# Regression test for issue 9615
26112610
# Exercises a special-case code path for dot products of length
@@ -6071,11 +6070,11 @@ def test_matmul_bool(self):
60716070
assert not np.any(c)
60726071

60736072

6074-
@pytest.mark.xfail(reason='TODO: @')
60756073
class TestMatmulOperator(MatmulCommon):
60766074
import operator
60776075
matmul = operator.matmul
60786076

6077+
@pytest.mark.skip(reason="no __array_priority__")
60796078
def test_array_priority_override(self):
60806079

60816080
class A:
@@ -6093,11 +6092,10 @@ def __rmatmul__(self, other):
60936092
assert_equal(self.matmul(b, a), "A")
60946093

60956094
def test_matmul_raises(self):
6096-
assert_raises(TypeError, self.matmul, np.int8(5), np.int8(5))
6097-
assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc'))
6098-
assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc'))
6095+
assert_raises((RuntimeError, TypeError), self.matmul, np.int8(5), np.int8(5))
60996096

6100-
@pytest.mark.xfail(reason='TODO @')
6097+
6098+
@pytest.mark.xfail(reason="torch supports inplace matmul, and so do we")
61016099
def test_matmul_inplace():
61026100
# It would be nice to support in-place matmul eventually, but for now
61036101
# we don't have a working implementation, so better just to error out

0 commit comments

Comments
 (0)