Commit 32e9844

BUG: fix up matmult/dot
1 parent e1ad9d8 commit 32e9844

4 files changed: +29 -9 lines changed


torch_np/_detail/_ufunc_impl.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch

 from . import _util
+from . import _dtypes_impl


 def deco_ufunc(torch_func):
@@ -70,7 +71,6 @@ def wrapped(
 logical_and = deco_ufunc(torch.logical_and)
 logical_or = deco_ufunc(torch.logical_or)
 logical_xor = deco_ufunc(torch.logical_xor)
-matmul = deco_ufunc(torch.matmul)
 maximum = deco_ufunc(torch.maximum)
 minimum = deco_ufunc(torch.minimum)
 remainder = deco_ufunc(torch.remainder)
@@ -143,7 +143,15 @@ def _absolute(x):
         return x
     return torch.absolute(x)

+def _matmul(x, y):
+    # work around RuntimeError: expected scalar type Int but found Double
+    dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
+    x = x.to(dtype)
+    y = y.to(dtype)
+    result = torch.matmul(x, y)
+    return result

 cbrt = deco_ufunc(_cbrt)
 positive = deco_ufunc(_positive)
 absolute = deco_ufunc(_absolute)
+matmul = deco_ufunc(_matmul)
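
For context on the _matmul wrapper: unlike NumPy, torch.matmul does not type-promote mixed-dtype inputs, so pairing an integer tensor with a floating one raises at runtime. A minimal sketch of the failure and of the promote-then-multiply fix, using torch.result_type as a stand-in for this repo's _dtypes_impl.result_type_impl:

    import torch

    x = torch.arange(3)                     # int64
    y = torch.ones(3, dtype=torch.float64)  # float64
    # torch.matmul(x, y) raises RuntimeError (mixed int64/float64 dtypes)
    dtype = torch.result_type(x, y)         # torch.float64
    print(torch.matmul(x.to(dtype), y.to(dtype)))  # tensor(3., dtype=torch.float64)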

torch_np/_detail/implementations.py

Lines changed: 16 additions & 6 deletions
@@ -178,6 +178,8 @@ def trace(tensor, offset=0, axis1=0, axis2=1, dtype=None, out=None):
 def diagonal(tensor, offset=0, axis1=0, axis2=1):
     axis1 = _util.normalize_axis_index(axis1, tensor.ndim)
     axis2 = _util.normalize_axis_index(axis2, tensor.ndim)
+    if axis1 == axis2:
+        raise ValueError("axis1 and axis2 cannot be the same")
     result = torch.diagonal(tensor, offset, axis1, axis2)
     return result

@@ -492,17 +494,25 @@ def arange(start=None, stop=None, step=1, dtype=None):
     if start is None:
         start = 0

-    if dtype is None:
-        dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
-        dtype = _dtypes_impl.default_int_dtype
-        dt_list.append(dtype)
-        dtype = _dtypes_impl.result_type_impl(dt_list)
+    dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
+    if dtype is None:
+        dtype = _dtypes_impl.default_int_dtype
+    dt_list.append(dtype)
+    dtype = _dtypes_impl.result_type_impl(dt_list)

+    # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
+    orig_dtype = dtype
+    is_complex = dtype is not None and dtype.is_complex
     try:
-        return torch.arange(start, stop, step, dtype=dtype)
+        if is_complex:
+            dtype = torch.float64
+        result = torch.arange(start, stop, step, dtype=dtype)
+        if is_complex:
+            result = result.to(orig_dtype)
     except RuntimeError:
         raise ValueError("Maximum allowed size exceeded")

+    return result

 # ### empty/full et al ###
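
For context on the arange change: torch.arange has no complex kernel (RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'), so the code above generates the range in float64 and casts to the requested complex dtype afterwards. A minimal standalone sketch of the same idea, with a hypothetical complex_arange helper:

    import torch

    def complex_arange(start, stop, step, dtype=torch.complex128):
        # torch.arange raises for complex dtypes, so compute the range
        # in float64 first, then cast to the target complex dtype.
        return torch.arange(start, stop, step, dtype=torch.float64).to(dtype)

    print(complex_arange(0, 4, 1))
    # tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], dtype=torch.complex128)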

torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch

 from . import _decorators, _helpers
-from ._detail import _flips, _util
+from ._detail import _flips, _util, _dtypes_impl
 from ._detail import implementations as _impl


@@ -101,6 +101,9 @@ def vdot(a, b, /):

 def dot(a, b, out=None):
     t_a, t_b = _helpers.to_tensors(a, b)
+    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
+    t_a = t_a.to(dtype)
+    t_b = t_b.to(dtype)
     result = _impl.dot(t_a, t_b)
     return _helpers.result_or_out(result, out)
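
With this change, np.dot promotes both inputs to a common dtype before dispatching, matching NumPy's promotion behavior for mixed int/float arguments. A small usage sketch, assuming torch_np is installed and used as the NumPy replacement:

    import torch_np as np

    a = np.arange(3)  # integer array
    b = np.ones(3)    # floating-point array
    print(np.dot(a, b))  # -> 3.0 (previously raised a mixed-dtype RuntimeError)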

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 0 additions & 1 deletion
@@ -2468,7 +2468,6 @@ def test_arr_mult_2(self, func):
             func(edf, edf[:, ::-1].T.copy())
         )

-    @pytest.mark.xfail(reason="TODO np.dot")
     @pytest.mark.parametrize('func', (np.dot, np.matmul))
     @pytest.mark.parametrize('dtype', 'ifdFD')
     def test_no_dgemv(self, func, dtype):
