BUG: Fix matmul with out= arrays #90

Closed · wants to merge 2 commits

45 changes: 43 additions & 2 deletions torch_np/_binary_ufuncs.py
@@ -1,5 +1,7 @@
from typing import Optional

import torch

from . import _helpers
from ._detail import _binary_ufuncs
from ._normalizations import (
@@ -12,7 +14,9 @@
)

__all__ = [
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
name
for name in dir(_binary_ufuncs)
if not name.startswith("_") and name not in ["torch", "matmul"]
]


@@ -40,12 +44,49 @@ def wrapped(
tensors = _helpers.ufunc_preprocess(
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
)
# now broadcast input tensors against the out=... array
if out is not None:
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
shape = out.shape
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)

result = torch_func(*tensors)
return result, out

return wrapped


#
# matmul is special in that its `out=...` array does not broadcast x1 and x2.
# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
#
@normalizer
def matmul(
x1: ArrayLike,
x2: ArrayLike,
/,
out: Optional[NDArray] = None,
*,
casting="same_kind",
order="K",
dtype: DTypeLike = None,
subok: SubokLike = False,
signature=None,
extobj=None,
axes=None,
axis=None,
) -> OutArray:
tensors = _helpers.ufunc_preprocess(
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
)
if axis is not None or axes is not None:
raise NotImplementedError

# NB: do not broadcast input tensors against the out=... array
result = _binary_ufuncs.matmul(*tensors)
return result, out


#
# For each torch ufunc implementation, decorate and attach the decorated name
# to this module. Its contents is then exported to the public namespace in __init__.py
@@ -111,4 +152,4 @@ def modf(x, /, *args, **kwds):
return rem, quot


__all__ = __all__ + ["divmod", "modf"]
__all__ = __all__ + ["divmod", "modf", "matmul"]
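
The comment in the new `matmul` wrapper above is the crux of the fix: for elementwise ufuncs the inputs are broadcast against the `out=` array, while for `matmul` the `out` shape is the result shape, not a broadcast target. A minimal sketch of the shape difference in plain NumPy (not this repo's API):

```python
import numpy as np

# matmul: out has the result shape (5, 3); x1 and x2 keep their own shapes
x1 = np.ones((5, 2))
x2 = np.ones((2, 3))
out = np.empty((5, 3))
np.matmul(x1, x2, out=out)            # fine

# elementwise ufunc: the inputs *are* broadcast against out
a = np.ones((5, 1))
b = np.ones((1, 3))
np.add(a, b, out=np.empty((5, 3)))    # a and b broadcast to (5, 3)

# broadcasting x1 to out.shape, as the shared helper used to do, cannot work:
try:
    np.broadcast_to(x1, out.shape)    # (5, 2) -> (5, 3) is not broadcastable
except ValueError as e:
    print(e)
```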
23 changes: 20 additions & 3 deletions torch_np/_detail/_binary_ufuncs.py
@@ -45,9 +45,26 @@

# work around torch limitations w.r.t. numpy
def matmul(x, y):
# work around RuntimeError: expected scalar type Int but found Double
# work around:
# - RuntimeError: expected scalar type Int but found Double
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
x = _util.cast_if_needed(x, dtype)
y = _util.cast_if_needed(y, dtype)
is_bool = dtype == torch.bool
is_half = dtype == torch.float16

work_dtype = dtype
if is_bool:
work_dtype = torch.uint8
if is_half:
work_dtype = torch.float32

x = _util.cast_if_needed(x, work_dtype)
y = _util.cast_if_needed(y, work_dtype)

result = torch.matmul(x, y)

if work_dtype != dtype:
result = result.to(dtype)

return result
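
For reference, the same work-dtype trick as a self-contained sketch against stock `torch`, outside this repo's helpers (`torch.result_type` stands in for `_dtypes_impl.result_type_impl`; which dtypes actually fail can vary across torch versions):

```python
import torch

def matmul_workaround(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    dtype = torch.result_type(x, y)

    # CPU matmul raises for bool/half ("addmm_impl_cpu_" not implemented),
    # so compute in a wider dtype and cast the result back.
    work_dtype = dtype
    if dtype == torch.bool:
        work_dtype = torch.uint8
    elif dtype == torch.float16:
        work_dtype = torch.float32

    result = torch.matmul(x.to(work_dtype), y.to(work_dtype))
    return result.to(dtype) if work_dtype != dtype else result

a = torch.ones(2, 3, dtype=torch.float16)
b = torch.ones(3, 4, dtype=torch.float16)
print(matmul_workaround(a, b).dtype)  # torch.float16
```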
7 changes: 0 additions & 7 deletions torch_np/_helpers.py
@@ -27,13 +27,6 @@ def ufunc_preprocess(

if out_dtype:
tensors = _util.typecast_tensors(tensors, out_dtype, casting)

# now broadcast input tensors against the out=... array
if out is not None:
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
shape = out.shape
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)

return tensors


8 changes: 8 additions & 0 deletions torch_np/_ndarray.py
@@ -319,6 +319,14 @@ def __ilshift__(self, other):
def __irshift__(self, other):
return _binary_ufuncs.right_shift(self, other, out=self)

__matmul__ = _binary_ufuncs.matmul

def __rmatmul__(self, other):
return _binary_ufuncs.matmul(other, self)

def __imatmul__(self, other):
return _binary_ufuncs.matmul(self, other, out=self)

# unary ops
__invert__ = _unary_ufuncs.invert
__abs__ = _unary_ufuncs.absolute
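A usage sketch for the three dunders (assuming the package is imported as `torch_np`): `@`, a reflected `@`, and `@=` all route through `_binary_ufuncs.matmul`.

```python
import torch_np as np

a = np.ones((5, 2))
b = np.ones((2, 2))

c = a @ b      # __matmul__  -> matmul(a, b)
a @= b         # __imatmul__ -> matmul(a, b, out=a); fine since the result shape is (5, 2)

# __rmatmul__ handles a non-ndarray left operand:
d = [[1.0, 0.0], [0.0, 1.0]] @ b   # -> matmul([[1.0, 0.0], [0.0, 1.0]], b)
```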
7 changes: 7 additions & 0 deletions torch_np/_unary_ufuncs.py
@@ -4,6 +4,8 @@

from typing import Optional

import torch

from . import _helpers
from ._detail import _unary_ufuncs
from ._normalizations import (
@@ -43,6 +45,11 @@ def wrapped(
tensors = _helpers.ufunc_preprocess(
(x,), out, where, casting, order, dtype, subok, signature, extobj
)
# now broadcast the input tensor against the out=... array
if out is not None:
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
shape = out.shape
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
result = torch_func(*tensors)
return result, out

42 changes: 25 additions & 17 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -2605,7 +2605,6 @@ def test_diagonal_memleak(self):
if HAS_REFCOUNT:
assert_(sys.getrefcount(a) < 50)

@pytest.mark.xfail(reason="TODO: implement np.dot")
def test_size_zero_memleak(self):
# Regression test for issue 9615
# Exercises a special-case code path for dot products of length
@@ -5708,7 +5707,7 @@ class MatmulCommon:
"""
# Should work with these types. Will want to add
# "O" at some point
types = "?bhilqBefdFD"
types = "?bhilBefdFD"

def test_exceptions(self):
dims = [
@@ -5726,7 +5725,7 @@ def test_exceptions(self):
for dt, (dm1, dm2) in itertools.product(self.types, dims):
a = np.ones(dm1, dtype=dt)
b = np.ones(dm2, dtype=dt)
assert_raises(ValueError, self.matmul, a, b)
assert_raises((RuntimeError, ValueError), self.matmul, a, b)

def test_shapes(self):
dims = [
@@ -5758,7 +5757,13 @@ def test_result_types(self):
res = self.matmul(*arg)
assert_(res.dtype == dt)

# vector vector returns scalars
@pytest.mark.xfail(reason='no scalars')
def test_result_types_2(self):
# in numpy, vector vector returns scalars
# we return a 0D array instead

for dt in self.types:
v = np.ones((1,)).astype(dt)
if dt != "O":
res = self.matmul(v, v)
assert_(type(res) is np.dtype(dt).type)
@@ -5919,9 +5924,10 @@ def test_matrix_matrix_values(self):
assert_equal(res, tgt12_21)


@pytest.mark.xfail(reason='TODO: matmul (ufunc wrapping goes south?)')
class TestMatmul(MatmulCommon):
matmul = np.matmul

def setup_method(self):
self.matmul = np.matmul
Collaborator (author) comment:

This change is a bit curious. Storing ufuncs as class attributes seems to work fine in scripts and the interactive interpreter, but breaks down in combination with pytest:

the argument processing/normalization machinery picks up the extra self-like argument, which leads to

>           tensor = torch.as_tensor(obj)
E           RuntimeError: Could not infer dtype of TestMatmul

so there's some spooky action at a distance between the meta-level things pytest does and what we do here.
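
A toy reproduction of the mechanism (not this repo's code): a plain function stored as a class attribute becomes a bound method on instances, so the instance itself arrives as the first positional argument, which the normalizer then tries to convert with `torch.as_tensor`:

```python
import torch

def matmul(x1, x2, /, out=None):
    # stand-in for the argument normalizer: coerce the inputs to tensors
    return torch.as_tensor(x1) @ torch.as_tensor(x2)

class TestMatmul:
    matmul = matmul   # class attribute -> bound method on instances

t = TestMatmul()
a = torch.ones(2, 2)
# The call actually passes (t, a, a), i.e. x1 = t, x2 = a, out = a.
t.matmul(a, a)        # RuntimeError: Could not infer dtype of TestMatmul
```

Assigning `self.matmul = np.matmul` in `setup_method` sidesteps this, since functions stored as instance attributes are not run through the descriptor protocol and so do not get bound.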


def test_out_arg(self):
a = np.ones((5, 2), dtype=float)
@@ -5941,17 +5947,17 @@ def test_out_arg(self):
assert_array_equal(out, tgt, err_msg=msg)

# test out with not allowed type cast (safe casting)
msg = "Cannot cast ufunc .* output"
msg = "Cannot cast"
out = np.zeros((5, 2), dtype=np.int32)
assert_raises_regex(TypeError, msg, self.matmul, a, b, out=out)

# test out with type upcast to complex
out = np.zeros((5, 2), dtype=np.complex128)
c = self.matmul(a, b, out=out)
assert_(c is out)
with suppress_warnings() as sup:
sup.filter(np.ComplexWarning, '')
c = c.astype(tgt.dtype)
# with suppress_warnings() as sup:
# sup.filter(np.ComplexWarning, '')
c = c.astype(tgt.dtype)
assert_array_equal(c, tgt)

def test_empty_out(self):
Expand All @@ -5961,7 +5967,7 @@ def test_empty_out(self):
out = np.ones((1, 1, 1))
assert self.matmul(arr, arr).shape == (0, 1, 1)

with pytest.raises(ValueError, match=r"non-broadcastable"):
with pytest.raises(ValueError, match="Bad size of the out array"): # match=r"non-broadcastable"):
self.matmul(arr, arr, out=out)

def test_out_contiguous(self):
@@ -5974,7 +5980,7 @@ def test_out_contiguous(self):
# test out non-contiguous
out = np.ones((5, 2, 2), dtype=float)
c = self.matmul(a, b, out=out[..., 0])
assert c.base is out
assert c._tensor._base is out._tensor # FIXME: self.tensor (no underscore)
Collaborator (author) comment:

This will get cleaned up before merging, after the base of the stack gets in.

assert_array_equal(c, tgt)
c = self.matmul(a, v, out=out[:, 0, 0])
assert_array_equal(c, tgt_mv)
@@ -6025,6 +6031,7 @@ def test_dot_equivalent(self, args):
assert_equal(r1, r3)


@pytest.mark.skip(reason='object arrays')
def test_matmul_exception_multiply(self):
# test that matmul fails if `__mul__` is missing
class add_not_multiply():
@@ -6034,6 +6041,7 @@ def __add__(self, other):
with assert_raises(TypeError):
b = np.matmul(a, a)

@pytest.mark.skip(reason='object arrays')
def test_matmul_exception_add(self):
# test that matmul fails if `__add__` is missing
class multiply_not_add():
@@ -6043,6 +6051,7 @@ def __mul__(self, other):
with assert_raises(TypeError):
b = np.matmul(a, a)

@pytest.mark.xfail(reason="TODO: implement .view")
def test_matmul_bool(self):
# gh-14439
a = np.array([[1, 0],[1, 1]], dtype=bool)
@@ -6062,11 +6071,11 @@ def test_matmul_bool(self):
assert not np.any(c)


@pytest.mark.xfail(reason='TODO: @')
class TestMatmulOperator(MatmulCommon):
import operator
matmul = operator.matmul

@pytest.mark.skip(reason="no __array_priority__")
def test_array_priority_override(self):

class A:
Expand All @@ -6084,11 +6093,10 @@ def __rmatmul__(self, other):
assert_equal(self.matmul(b, a), "A")

def test_matmul_raises(self):
assert_raises(TypeError, self.matmul, np.int8(5), np.int8(5))
assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc'))
assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc'))
assert_raises((RuntimeError, TypeError), self.matmul, np.int8(5), np.int8(5))


@pytest.mark.xfail(reason='TODO @')
@pytest.mark.xfail(reason="torch supports inplace matmul, and so do we")
def test_matmul_inplace():
# It would be nice to support in-place matmul eventually, but for now
# we don't have a working implementation, so better just to error out
Expand Down