Skip to content

Commit 6666ed0

Browse files
committed
BUG: fix matmul with out=... array
matmul(x1, x2, out) does not broadcast x1, x2 against out, like other ufuncs do.
1 parent 0bc1374 commit 6666ed0

File tree

5 files changed

+91
-23
lines changed

5 files changed

+91
-23
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional
22

3+
import torch
4+
35
from . import _helpers
46
from ._detail import _binary_ufuncs
57
from ._normalizations import (
@@ -12,7 +14,9 @@
1214
)
1315

1416
__all__ = [
15-
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
17+
name
18+
for name in dir(_binary_ufuncs)
19+
if not name.startswith("_") and name not in ["torch", "matmul"]
1620
]
1721

1822

@@ -40,12 +44,49 @@ def wrapped(
4044
tensors = _helpers.ufunc_preprocess(
4145
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
4246
)
47+
# now broadcast input tensors against the out=... array
48+
if out is not None:
49+
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
50+
shape = out.shape
51+
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
52+
4353
result = torch_func(*tensors)
4454
return result, out
4555

4656
return wrapped
4757

4858

59+
#
60+
# matmul is special in that its `out=...` array does not broadcast x1 and x2.
61+
# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
62+
#
63+
@normalizer
64+
def matmul(
65+
x1: ArrayLike,
66+
x2: ArrayLike,
67+
/,
68+
out: Optional[NDArray] = None,
69+
*,
70+
casting="same_kind",
71+
order="K",
72+
dtype: DTypeLike = None,
73+
subok: SubokLike = False,
74+
signature=None,
75+
extobj=None,
76+
axes=None,
77+
axis=None,
78+
) -> OutArray:
79+
tensors = _helpers.ufunc_preprocess(
80+
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
81+
)
82+
if axis is not None or axes is not None:
83+
raise NotImplementedError
84+
85+
# NB: do not broadcast input tensors against the out=... array
86+
result = _binary_ufuncs.matmul(*tensors)
87+
return result, out
88+
89+
4990
#
5091
# For each torch ufunc implementation, decorate and attach the decorated name
5192
# to this module. Its contents is then exported to the public namespace in __init__.py
@@ -111,4 +152,4 @@ def modf(x, /, *args, **kwds):
111152
return rem, quot
112153

113154

114-
__all__ = __all__ + ["divmod", "modf"]
155+
__all__ = __all__ + ["divmod", "modf", "matmul"]

torch_np/_detail/_binary_ufuncs.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,26 @@
4545

4646
# work around torch limitations w.r.t. numpy
4747
def matmul(x, y):
48-
# work around RuntimeError: expected scalar type Int but found Double
48+
# work around:
49+
# - RuntimeError: expected scalar type Int but found Double
50+
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
51+
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
4952
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
50-
x = _util.cast_if_needed(x, dtype)
51-
y = _util.cast_if_needed(y, dtype)
53+
is_bool = dtype == torch.bool
54+
is_half = dtype == torch.float16
55+
56+
work_dtype = dtype
57+
if is_bool:
58+
work_dtype = torch.uint8
59+
if is_half:
60+
work_dtype = torch.float32
61+
62+
x = _util.cast_if_needed(x, work_dtype)
63+
y = _util.cast_if_needed(y, work_dtype)
64+
5265
result = torch.matmul(x, y)
66+
67+
if work_dtype != dtype:
68+
result = result.to(dtype)
69+
5370
return result

torch_np/_helpers.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ def ufunc_preprocess(
2727

2828
if out_dtype:
2929
tensors = _util.typecast_tensors(tensors, out_dtype, casting)
30-
31-
# now broadcast input tensors against the out=... array
32-
if out is not None:
33-
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
34-
shape = out.shape
35-
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
36-
3730
return tensors
3831

3932

torch_np/_unary_ufuncs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from typing import Optional
66

7+
import torch
8+
79
from . import _helpers
810
from ._detail import _unary_ufuncs
911
from ._normalizations import (
@@ -43,6 +45,11 @@ def wrapped(
4345
tensors = _helpers.ufunc_preprocess(
4446
(x,), out, where, casting, order, dtype, subok, signature, extobj
4547
)
48+
# now broadcast the input tensor against the out=... array
49+
if out is not None:
50+
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
51+
shape = out.shape
52+
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
4653
result = torch_func(*tensors)
4754
return result, out
4855

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5708,7 +5708,7 @@ class MatmulCommon:
57085708
"""
57095709
# Should work with these types. Will want to add
57105710
# "O" at some point
5711-
types = "?bhilqBefdFD"
5711+
types = "?bhilBefdFD"
57125712

57135713
def test_exceptions(self):
57145714
dims = [
@@ -5726,7 +5726,7 @@ def test_exceptions(self):
57265726
for dt, (dm1, dm2) in itertools.product(self.types, dims):
57275727
a = np.ones(dm1, dtype=dt)
57285728
b = np.ones(dm2, dtype=dt)
5729-
assert_raises(ValueError, self.matmul, a, b)
5729+
assert_raises((RuntimeError, ValueError), self.matmul, a, b)
57305730

57315731
def test_shapes(self):
57325732
dims = [
@@ -5758,7 +5758,13 @@ def test_result_types(self):
57585758
res = self.matmul(*arg)
57595759
assert_(res.dtype == dt)
57605760

5761-
# vector vector returns scalars
5761+
@pytest.mark.xfail(reason='no scalars')
5762+
def test_result_types_2(self):
5763+
# in numpy, vector vector returns scalars
5764+
# we return a 0D array instead
5765+
5766+
for dt in self.types:
5767+
v = np.ones((1,)).astype(dt)
57625768
if dt != "O":
57635769
res = self.matmul(v, v)
57645770
assert_(type(res) is np.dtype(dt).type)
@@ -5919,9 +5925,10 @@ def test_matrix_matrix_values(self):
59195925
assert_equal(res, tgt12_21)
59205926

59215927

5922-
@pytest.mark.xfail(reason='TODO: matmul (ufunc wrapping goes south?)')
59235928
class TestMatmul(MatmulCommon):
5924-
matmul = np.matmul
5929+
5930+
def setup_method(self):
5931+
self.matmul = np.matmul
59255932

59265933
def test_out_arg(self):
59275934
a = np.ones((5, 2), dtype=float)
@@ -5941,17 +5948,17 @@ def test_out_arg(self):
59415948
assert_array_equal(out, tgt, err_msg=msg)
59425949

59435950
# test out with not allowed type cast (safe casting)
5944-
msg = "Cannot cast ufunc .* output"
5951+
msg = "Cannot cast"
59455952
out = np.zeros((5, 2), dtype=np.int32)
59465953
assert_raises_regex(TypeError, msg, self.matmul, a, b, out=out)
59475954

59485955
# test out with type upcast to complex
59495956
out = np.zeros((5, 2), dtype=np.complex128)
59505957
c = self.matmul(a, b, out=out)
59515958
assert_(c is out)
5952-
with suppress_warnings() as sup:
5953-
sup.filter(np.ComplexWarning, '')
5954-
c = c.astype(tgt.dtype)
5959+
# with suppress_warnings() as sup:
5960+
# sup.filter(np.ComplexWarning, '')
5961+
c = c.astype(tgt.dtype)
59555962
assert_array_equal(c, tgt)
59565963

59575964
def test_empty_out(self):
@@ -5961,7 +5968,7 @@ def test_empty_out(self):
59615968
out = np.ones((1, 1, 1))
59625969
assert self.matmul(arr, arr).shape == (0, 1, 1)
59635970

5964-
with pytest.raises(ValueError, match=r"non-broadcastable"):
5971+
with pytest.raises(ValueError, match="Bad size of the out array"): # match=r"non-broadcastable"):
59655972
self.matmul(arr, arr, out=out)
59665973

59675974
def test_out_contiguous(self):
@@ -5974,7 +5981,7 @@ def test_out_contiguous(self):
59745981
# test out non-contiguous
59755982
out = np.ones((5, 2, 2), dtype=float)
59765983
c = self.matmul(a, b, out=out[..., 0])
5977-
assert c.base is out
5984+
assert c._tensor._base is out._tensor # FIXME: self.tensor (no underscore)
59785985
assert_array_equal(c, tgt)
59795986
c = self.matmul(a, v, out=out[:, 0, 0])
59805987
assert_array_equal(c, tgt_mv)
@@ -6025,6 +6032,7 @@ def test_dot_equivalent(self, args):
60256032
assert_equal(r1, r3)
60266033

60276034

6035+
@pytest.mark.skip(reason='object arrays')
60286036
def test_matmul_exception_multiply(self):
60296037
# test that matmul fails if `__mul__` is missing
60306038
class add_not_multiply():
@@ -6034,6 +6042,7 @@ def __add__(self, other):
60346042
with assert_raises(TypeError):
60356043
b = np.matmul(a, a)
60366044

6045+
@pytest.mark.skip(reason='object arrays')
60376046
def test_matmul_exception_add(self):
60386047
# test that matmul fails if `__add__` is missing
60396048
class multiply_not_add():
@@ -6043,6 +6052,7 @@ def __mul__(self, other):
60436052
with assert_raises(TypeError):
60446053
b = np.matmul(a, a)
60456054

6055+
@pytest.mark.xfail(reason="TODO: implement .view")
60466056
def test_matmul_bool(self):
60476057
# gh-14439
60486058
a = np.array([[1, 0],[1, 1]], dtype=bool)

0 commit comments

Comments
 (0)