Skip to content

Commit 53b865a

Browse files
committed
BUG: fix matmul with out=... array
matmul(x1, x2, out) does not broadcast x1, x2 against out, like other ufuncs do.
1 parent 4be6de4 commit 53b865a

File tree

5 files changed

+91
-23
lines changed

5 files changed

+91
-23
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from typing import Optional
22

3+
import torch
4+
35
from . import _helpers
46
from ._detail import _binary_ufuncs
57
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
68

79
__all__ = [
8-
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
10+
name
11+
for name in dir(_binary_ufuncs)
12+
if not name.startswith("_") and name not in ["torch", "matmul"]
913
]
1014

1115

@@ -33,12 +37,49 @@ def wrapped(
3337
tensors = _helpers.ufunc_preprocess(
3438
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
3539
)
40+
# now broadcast input tensors against the out=... array
41+
if out is not None:
42+
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
43+
shape = out.shape
44+
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
45+
3646
result = torch_func(*tensors)
3747
return _helpers.result_or_out(result, out)
3848

3949
return wrapped
4050

4151

52+
#
53+
# matmul is special in that its `out=...` array does not broadcast x1 and x2.
54+
# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
55+
#
56+
@normalizer
57+
def matmul(
58+
x1: ArrayLike,
59+
x2: ArrayLike,
60+
/,
61+
out: Optional[NDArray] = None,
62+
*,
63+
casting="same_kind",
64+
order="K",
65+
dtype: DTypeLike = None,
66+
subok: SubokLike = False,
67+
signature=None,
68+
extobj=None,
69+
axes=None,
70+
axis=None,
71+
) -> OutArray:
72+
tensors = _helpers.ufunc_preprocess(
73+
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
74+
)
75+
if axis is not None or axes is not None:
76+
raise NotImplementedError
77+
78+
# NB: do not broadcast input tensors against the out=... array
79+
result = _binary_ufuncs.matmul(*tensors)
80+
return result, out
81+
82+
4283
#
4384
# For each torch ufunc implementation, decorate and attach the decorated name
4485
# to this module. Its contents is then exported to the public namespace in __init__.py
@@ -104,4 +145,4 @@ def modf(x, /, *args, **kwds):
104145
return rem, quot
105146

106147

107-
__all__ = __all__ + ["divmod", "modf"]
148+
__all__ = __all__ + ["divmod", "modf", "matmul"]

torch_np/_detail/_binary_ufuncs.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,26 @@
4747

4848
# work around torch limitations w.r.t. numpy
4949
def matmul(x, y):
50-
# work around RuntimeError: expected scalar type Int but found Double
50+
# work around:
51+
# - RuntimeError: expected scalar type Int but found Double
52+
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
53+
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
5154
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
52-
x = _util.cast_if_needed(x, dtype)
53-
y = _util.cast_if_needed(y, dtype)
55+
is_bool = dtype == torch.bool
56+
is_half = dtype == torch.float16
57+
58+
work_dtype = dtype
59+
if is_bool:
60+
work_dtype = torch.uint8
61+
if is_half:
62+
work_dtype = torch.float32
63+
64+
x = _util.cast_if_needed(x, work_dtype)
65+
y = _util.cast_if_needed(y, work_dtype)
66+
5467
result = torch.matmul(x, y)
68+
69+
if work_dtype != dtype:
70+
result = result.to(dtype)
71+
5572
return result

torch_np/_helpers.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ def ufunc_preprocess(
2727

2828
if out_dtype:
2929
tensors = _util.typecast_tensors(tensors, out_dtype, casting)
30-
31-
# now broadcast input tensors against the out=... array
32-
if out is not None:
33-
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
34-
shape = out.shape
35-
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
36-
3730
return tensors
3831

3932

torch_np/_unary_ufuncs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from typing import Optional
66

7+
import torch
8+
79
from . import _helpers
810
from ._detail import _unary_ufuncs
911
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
@@ -36,6 +38,11 @@ def wrapped(
3638
tensors = _helpers.ufunc_preprocess(
3739
(x,), out, where, casting, order, dtype, subok, signature, extobj
3840
)
41+
# now broadcast the input tensor against the out=... array
42+
if out is not None:
43+
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
44+
shape = out.shape
45+
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
3946
result = torch_func(*tensors)
4047
return _helpers.result_or_out(result, out)
4148

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5707,7 +5707,7 @@ class MatmulCommon:
57075707
"""
57085708
# Should work with these types. Will want to add
57095709
# "O" at some point
5710-
types = "?bhilqBefdFD"
5710+
types = "?bhilBefdFD"
57115711

57125712
def test_exceptions(self):
57135713
dims = [
@@ -5725,7 +5725,7 @@ def test_exceptions(self):
57255725
for dt, (dm1, dm2) in itertools.product(self.types, dims):
57265726
a = np.ones(dm1, dtype=dt)
57275727
b = np.ones(dm2, dtype=dt)
5728-
assert_raises(ValueError, self.matmul, a, b)
5728+
assert_raises((RuntimeError, ValueError), self.matmul, a, b)
57295729

57305730
def test_shapes(self):
57315731
dims = [
@@ -5757,7 +5757,13 @@ def test_result_types(self):
57575757
res = self.matmul(*arg)
57585758
assert_(res.dtype == dt)
57595759

5760-
# vector vector returns scalars
5760+
@pytest.mark.xfail(reason='no scalars')
5761+
def test_result_types_2(self):
5762+
# in numpy, vector vector returns scalars
5763+
# we return a 0D array instead
5764+
5765+
for dt in self.types:
5766+
v = np.ones((1,)).astype(dt)
57615767
if dt != "O":
57625768
res = self.matmul(v, v)
57635769
assert_(type(res) is np.dtype(dt).type)
@@ -5918,9 +5924,10 @@ def test_matrix_matrix_values(self):
59185924
assert_equal(res, tgt12_21)
59195925

59205926

5921-
@pytest.mark.xfail(reason='TODO: matmul (ufunc wrapping goes south?)')
59225927
class TestMatmul(MatmulCommon):
5923-
matmul = np.matmul
5928+
5929+
def setup_method(self):
5930+
self.matmul = np.matmul
59245931

59255932
def test_out_arg(self):
59265933
a = np.ones((5, 2), dtype=float)
@@ -5940,17 +5947,17 @@ def test_out_arg(self):
59405947
assert_array_equal(out, tgt, err_msg=msg)
59415948

59425949
# test out with not allowed type cast (safe casting)
5943-
msg = "Cannot cast ufunc .* output"
5950+
msg = "Cannot cast"
59445951
out = np.zeros((5, 2), dtype=np.int32)
59455952
assert_raises_regex(TypeError, msg, self.matmul, a, b, out=out)
59465953

59475954
# test out with type upcast to complex
59485955
out = np.zeros((5, 2), dtype=np.complex128)
59495956
c = self.matmul(a, b, out=out)
59505957
assert_(c is out)
5951-
with suppress_warnings() as sup:
5952-
sup.filter(np.ComplexWarning, '')
5953-
c = c.astype(tgt.dtype)
5958+
# with suppress_warnings() as sup:
5959+
# sup.filter(np.ComplexWarning, '')
5960+
c = c.astype(tgt.dtype)
59545961
assert_array_equal(c, tgt)
59555962

59565963
def test_empty_out(self):
@@ -5960,7 +5967,7 @@ def test_empty_out(self):
59605967
out = np.ones((1, 1, 1))
59615968
assert self.matmul(arr, arr).shape == (0, 1, 1)
59625969

5963-
with pytest.raises(ValueError, match=r"non-broadcastable"):
5970+
with pytest.raises(ValueError, match="Bad size of the out array"): # match=r"non-broadcastable"):
59645971
self.matmul(arr, arr, out=out)
59655972

59665973
def test_out_contiguous(self):
@@ -5973,7 +5980,7 @@ def test_out_contiguous(self):
59735980
# test out non-contiguous
59745981
out = np.ones((5, 2, 2), dtype=float)
59755982
c = self.matmul(a, b, out=out[..., 0])
5976-
assert c.base is out
5983+
assert c._tensor._base is out._tensor # FIXME: self.tensor (no underscore)
59775984
assert_array_equal(c, tgt)
59785985
c = self.matmul(a, v, out=out[:, 0, 0])
59795986
assert_array_equal(c, tgt_mv)
@@ -6024,6 +6031,7 @@ def test_dot_equivalent(self, args):
60246031
assert_equal(r1, r3)
60256032

60266033

6034+
@pytest.mark.skip(reason='object arrays')
60276035
def test_matmul_exception_multiply(self):
60286036
# test that matmul fails if `__mul__` is missing
60296037
class add_not_multiply():
@@ -6033,6 +6041,7 @@ def __add__(self, other):
60336041
with assert_raises(TypeError):
60346042
b = np.matmul(a, a)
60356043

6044+
@pytest.mark.skip(reason='object arrays')
60366045
def test_matmul_exception_add(self):
60376046
# test that matmul fails if `__add__` is missing
60386047
class multiply_not_add():
@@ -6042,6 +6051,7 @@ def __mul__(self, other):
60426051
with assert_raises(TypeError):
60436052
b = np.matmul(a, a)
60446053

6054+
@pytest.mark.xfail(reason="TODO: implement .view")
60456055
def test_matmul_bool(self):
60466056
# gh-14439
60476057
a = np.array([[1, 0],[1, 1]], dtype=bool)

0 commit comments

Comments
 (0)