diff --git a/torch_np/_binary_ufuncs.py b/torch_np/_binary_ufuncs.py index f29604db..e34c8f44 100644 --- a/torch_np/_binary_ufuncs.py +++ b/torch_np/_binary_ufuncs.py @@ -1,5 +1,7 @@ from typing import Optional +import torch + from . import _helpers from ._detail import _binary_ufuncs from ._normalizations import ( @@ -12,7 +14,9 @@ ) __all__ = [ - name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch" + name + for name in dir(_binary_ufuncs) + if not name.startswith("_") and name not in ["torch", "matmul"] ] @@ -40,12 +44,49 @@ def wrapped( tensors = _helpers.ufunc_preprocess( (x1, x2), out, where, casting, order, dtype, subok, signature, extobj ) + # now broadcast input tensors against the out=... array + if out is not None: + # XXX: need to filter out noop broadcasts if t.shape == out.shape? + shape = out.shape + tensors = tuple(torch.broadcast_to(t, shape) for t in tensors) + result = torch_func(*tensors) return result, out return wrapped +# +# matmul is special in that its `out=...` array does not broadcast x1 and x2. +# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3). +# +@normalizer +def matmul( + x1: ArrayLike, + x2: ArrayLike, + /, + out: Optional[NDArray] = None, + *, + casting="same_kind", + order="K", + dtype: DTypeLike = None, + subok: SubokLike = False, + signature=None, + extobj=None, + axes=None, + axis=None, +) -> OutArray: + tensors = _helpers.ufunc_preprocess( + (x1, x2), out, True, casting, order, dtype, subok, signature, extobj + ) + if axis is not None or axes is not None: + raise NotImplementedError + + # NB: do not broadcast input tensors against the out=... array + result = _binary_ufuncs.matmul(*tensors) + return result, out + + # # For each torch ufunc implementation, decorate and attach the decorated name # to this module. Its contents is then exported to the public namespace in __init__.py @@ -111,4 +152,4 @@ def modf(x, /, *args, **kwds): return rem, quot -__all__ = __all__ + ["divmod", "modf"] +__all__ = __all__ + ["divmod", "modf", "matmul"] diff --git a/torch_np/_detail/_binary_ufuncs.py b/torch_np/_detail/_binary_ufuncs.py index 122871c4..c6f66421 100644 --- a/torch_np/_detail/_binary_ufuncs.py +++ b/torch_np/_detail/_binary_ufuncs.py @@ -45,9 +45,26 @@ # work around torch limitations w.r.t. numpy def matmul(x, y): - # work around RuntimeError: expected scalar type Int but found Double + # work around: + # - RuntimeError: expected scalar type Int but found Double + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool' + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) - x = _util.cast_if_needed(x, dtype) - y = _util.cast_if_needed(y, dtype) + is_bool = dtype == torch.bool + is_half = dtype == torch.float16 + + work_dtype = dtype + if is_bool: + work_dtype = torch.uint8 + if is_half: + work_dtype = torch.float32 + + x = _util.cast_if_needed(x, work_dtype) + y = _util.cast_if_needed(y, work_dtype) + result = torch.matmul(x, y) + + if work_dtype != dtype: + result = result.to(dtype) + return result diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py index 2e405089..e7a0474f 100644 --- a/torch_np/_helpers.py +++ b/torch_np/_helpers.py @@ -27,13 +27,6 @@ def ufunc_preprocess( if out_dtype: tensors = _util.typecast_tensors(tensors, out_dtype, casting) - - # now broadcast input tensors against the out=... array - if out is not None: - # XXX: need to filter out noop broadcasts if t.shape == out.shape? 
- shape = out.shape - tensors = tuple(torch.broadcast_to(t, shape) for t in tensors) - return tensors diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 8ad6b645..cf20027c 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -319,6 +319,14 @@ def __ilshift__(self, other): def __irshift__(self, other): return _binary_ufuncs.right_shift(self, other, out=self) + __matmul__ = _binary_ufuncs.matmul + + def __rmatmul__(self, other): + return _binary_ufuncs.matmul(other, self) + + def __imatmul__(self, other): + return _binary_ufuncs.matmul(self, other, out=self) + # unary ops __invert__ = _unary_ufuncs.invert __abs__ = _unary_ufuncs.absolute diff --git a/torch_np/_unary_ufuncs.py b/torch_np/_unary_ufuncs.py index 1a322482..23296c7e 100644 --- a/torch_np/_unary_ufuncs.py +++ b/torch_np/_unary_ufuncs.py @@ -4,6 +4,8 @@ from typing import Optional +import torch + from . import _helpers from ._detail import _unary_ufuncs from ._normalizations import ( @@ -43,6 +45,11 @@ def wrapped( tensors = _helpers.ufunc_preprocess( (x,), out, where, casting, order, dtype, subok, signature, extobj ) + # now broadcast the input tensor against the out=... array + if out is not None: + # XXX: need to filter out noop broadcasts if t.shape == out.shape? + shape = out.shape + tensors = tuple(torch.broadcast_to(t, shape) for t in tensors) result = torch_func(*tensors) return result, out diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index 49e8f01d..14303224 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -2605,7 +2605,6 @@ def test_diagonal_memleak(self): if HAS_REFCOUNT: assert_(sys.getrefcount(a) < 50) - @pytest.mark.xfail(reason="TODO: implement np.dot") def test_size_zero_memleak(self): # Regression test for issue 9615 # Exercises a special-case code path for dot products of length @@ -5708,7 +5707,7 @@ class MatmulCommon: """ # Should work with these types. 
Will want to add # "O" at some point - types = "?bhilqBefdFD" + types = "?bhilBefdFD" def test_exceptions(self): dims = [ @@ -5726,7 +5725,7 @@ def test_exceptions(self): for dt, (dm1, dm2) in itertools.product(self.types, dims): a = np.ones(dm1, dtype=dt) b = np.ones(dm2, dtype=dt) - assert_raises(ValueError, self.matmul, a, b) + assert_raises((RuntimeError, ValueError), self.matmul, a, b) def test_shapes(self): dims = [ @@ -5758,7 +5757,13 @@ def test_result_types(self): res = self.matmul(*arg) assert_(res.dtype == dt) - # vector vector returns scalars + @pytest.mark.xfail(reason='no scalars') + def test_result_types_2(self): + # in numpy, vector vector returns scalars + # we return a 0D array instead + + for dt in self.types: + v = np.ones((1,)).astype(dt) if dt != "O": res = self.matmul(v, v) assert_(type(res) is np.dtype(dt).type) @@ -5919,9 +5924,10 @@ def test_matrix_matrix_values(self): assert_equal(res, tgt12_21) -@pytest.mark.xfail(reason='TODO: matmul (ufunc wrapping goes south?)') class TestMatmul(MatmulCommon): - matmul = np.matmul + + def setup_method(self): + self.matmul = np.matmul def test_out_arg(self): a = np.ones((5, 2), dtype=float) @@ -5941,7 +5947,7 @@ def test_out_arg(self): assert_array_equal(out, tgt, err_msg=msg) # test out with not allowed type cast (safe casting) - msg = "Cannot cast ufunc .* output" + msg = "Cannot cast" out = np.zeros((5, 2), dtype=np.int32) assert_raises_regex(TypeError, msg, self.matmul, a, b, out=out) @@ -5949,9 +5955,9 @@ def test_out_arg(self): out = np.zeros((5, 2), dtype=np.complex128) c = self.matmul(a, b, out=out) assert_(c is out) - with suppress_warnings() as sup: - sup.filter(np.ComplexWarning, '') - c = c.astype(tgt.dtype) + # with suppress_warnings() as sup: + # sup.filter(np.ComplexWarning, '') + c = c.astype(tgt.dtype) assert_array_equal(c, tgt) def test_empty_out(self): @@ -5961,7 +5967,7 @@ def test_empty_out(self): out = np.ones((1, 1, 1)) assert self.matmul(arr, arr).shape == (0, 1, 1) - with pytest.raises(ValueError, match=r"non-broadcastable"): + with pytest.raises(ValueError, match="Bad size of the out array"): # match=r"non-broadcastable"): self.matmul(arr, arr, out=out) def test_out_contiguous(self): @@ -5974,7 +5980,7 @@ def test_out_contiguous(self): # test out non-contiguous out = np.ones((5, 2, 2), dtype=float) c = self.matmul(a, b, out=out[..., 0]) - assert c.base is out + assert c._tensor._base is out._tensor # FIXME: self.tensor (no underscore) assert_array_equal(c, tgt) c = self.matmul(a, v, out=out[:, 0, 0]) assert_array_equal(c, tgt_mv) @@ -6025,6 +6031,7 @@ def test_dot_equivalent(self, args): assert_equal(r1, r3) + @pytest.mark.skip(reason='object arrays') def test_matmul_exception_multiply(self): # test that matmul fails if `__mul__` is missing class add_not_multiply(): @@ -6034,6 +6041,7 @@ def __add__(self, other): with assert_raises(TypeError): b = np.matmul(a, a) + @pytest.mark.skip(reason='object arrays') def test_matmul_exception_add(self): # test that matmul fails if `__add__` is missing class multiply_not_add(): @@ -6043,6 +6051,7 @@ def __mul__(self, other): with assert_raises(TypeError): b = np.matmul(a, a) + @pytest.mark.xfail(reason="TODO: implement .view") def test_matmul_bool(self): # gh-14439 a = np.array([[1, 0],[1, 1]], dtype=bool) @@ -6062,11 +6071,11 @@ def test_matmul_bool(self): assert not np.any(c) -@pytest.mark.xfail(reason='TODO: @') class TestMatmulOperator(MatmulCommon): import operator matmul = operator.matmul + @pytest.mark.skip(reason="no __array_priority__") def 
test_array_priority_override(self): class A: @@ -6084,11 +6093,10 @@ def __rmatmul__(self, other): assert_equal(self.matmul(b, a), "A") def test_matmul_raises(self): - assert_raises(TypeError, self.matmul, np.int8(5), np.int8(5)) - assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc')) - assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc')) + assert_raises((RuntimeError, TypeError), self.matmul, np.int8(5), np.int8(5)) + -@pytest.mark.xfail(reason='TODO @') +@pytest.mark.xfail(reason="torch supports inplace matmul, and so do we") def test_matmul_inplace(): # It would be nice to support in-place matmul eventually, but for now # we don't have a working implementation, so better just to error out
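
A minimal usage sketch of what this patch wires up (illustrative only, not part of the diff; it assumes torch_np is importable as a NumPy-like namespace and uses only names that appear in the changes above):

    # Sketch: exercising the matmul support added in this patch.
    # Assumes torch_np is installed; shapes and values are made up.
    import torch_np as np

    a = np.ones((5, 2))
    b = np.ones((2, 3))

    # Functional form with out=...; unlike other binary ufuncs, the inputs
    # are not broadcast against out (out.shape is (5, 3), the result shape).
    out = np.zeros((5, 3))
    c = np.matmul(a, b, out=out)
    assert c is out

    # Operator forms routed through __matmul__ / __rmatmul__ / __imatmul__:
    d = a @ b                      # ndarray @ ndarray
    e = [[1.0, 2.0]] @ b           # list @ ndarray falls back to __rmatmul__
    f = np.ones((2, 2))
    f @= np.ones((2, 2))           # in-place matmul, implemented via out=self

    # Bool and float16 operands are computed in uint8 / float32 under the hood
    # (the _detail/_binary_ufuncs.matmul work-around) and cast back afterwards.
    g = np.array([[True, False], [True, True]]) @ np.array([[True], [False]])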