Skip to content

Commit 828d7e0

Browse files
authored
Merge pull request #72 from Quansight-Labs/unxfail_dot
Add ndarray.dot
2 parents f667f8c + ff75f7a commit 828d7e0

File tree

6 files changed

+80
-82
lines changed

6 files changed

+80
-82
lines changed

torch_np/_detail/_ufunc_impl.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from . import _util
3+
from . import _dtypes_impl, _util
44

55

66
def deco_ufunc(torch_func):
@@ -70,7 +70,6 @@ def wrapped(
7070
logical_and = deco_ufunc(torch.logical_and)
7171
logical_or = deco_ufunc(torch.logical_or)
7272
logical_xor = deco_ufunc(torch.logical_xor)
73-
matmul = deco_ufunc(torch.matmul)
7473
maximum = deco_ufunc(torch.maximum)
7574
minimum = deco_ufunc(torch.minimum)
7675
remainder = deco_ufunc(torch.remainder)
@@ -144,6 +143,16 @@ def _absolute(x):
144143
return torch.absolute(x)
145144

146145

146+
def _matmul(x, y):
147+
# work around RuntimeError: expected scalar type Int but found Double
148+
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
149+
x = _util.cast_if_needed(x, dtype)
150+
y = _util.cast_if_needed(y, dtype)
151+
result = torch.matmul(x, y)
152+
return result
153+
154+
147155
cbrt = deco_ufunc(_cbrt)
148156
positive = deco_ufunc(_positive)
149157
absolute = deco_ufunc(_absolute)
158+
matmul = deco_ufunc(_matmul)

torch_np/_detail/implementations.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,17 +492,31 @@ def arange(start=None, stop=None, step=1, dtype=None):
492492
if start is None:
493493
start = 0
494494

495+
# the dtype of the result
495496
if dtype is None:
496-
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
497497
dtype = _dtypes_impl.default_int_dtype
498-
dt_list.append(dtype)
499-
dtype = _dtypes_impl.result_type_impl(dt_list)
498+
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
499+
dt_list.append(dtype)
500+
dtype = _dtypes_impl.result_type_impl(dt_list)
501+
502+
# work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
503+
if dtype.is_complex:
504+
work_dtype, target_dtype = torch.float64, dtype
505+
else:
506+
work_dtype, target_dtype = dtype, dtype
507+
508+
if (step > 0 and start > stop) or (step < 0 and start < stop):
509+
# empty range
510+
return torch.empty(0, dtype=target_dtype)
500511

501512
try:
502-
return torch.arange(start, stop, step, dtype=dtype)
513+
result = torch.arange(start, stop, step, dtype=work_dtype)
514+
result = _util.cast_if_needed(result, target_dtype)
503515
except RuntimeError:
504516
raise ValueError("Maximum allowed size exceeded")
505517

518+
return result
519+
506520

507521
# ### empty/full et al ###
508522

@@ -783,6 +797,10 @@ def vdot(t_a, t_b, /):
783797

784798

785799
def dot(t_a, t_b):
800+
dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
801+
t_a = _util.cast_if_needed(t_a, dtype)
802+
t_b = _util.cast_if_needed(t_b, dtype)
803+
786804
if t_a.ndim == 0 or t_b.ndim == 0:
787805
result = t_a * t_b
788806
elif t_a.ndim == 1 and t_b.ndim == 1:

torch_np/_funcs.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from . import _decorators, _helpers
4-
from ._detail import _flips, _util
4+
from ._detail import _dtypes_impl, _flips, _util
55
from ._detail import implementations as _impl
66

77

@@ -93,7 +93,17 @@ def fill_diagonal(a, val, wrap=False):
9393
return _helpers.array_from(result)
9494

9595

96-
# ### sorting ###
96+
def vdot(a, b, /):
97+
t_a, t_b = _helpers.to_tensors(a, b)
98+
result = _impl.vdot(t_a, t_b)
99+
return result.item()
100+
101+
102+
def dot(a, b, out=None):
103+
t_a, t_b = _helpers.to_tensors(a, b)
104+
result = _impl.dot(t_a, t_b)
105+
return _helpers.result_or_out(result, out)
106+
97107

98108
# ### sort and partition ###
99109

torch_np/_ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"):
362362

363363
diagonal = _funcs.diagonal
364364
trace = _funcs.trace
365+
dot = _funcs.dot
365366

366367
### sorting ###
367368

torch_np/_wrapper.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -421,18 +421,6 @@ def where(condition, x=None, y=None, /):
421421
return asarray(result)
422422

423423

424-
def vdot(a, b, /):
425-
t_a, t_b = _helpers.to_tensors(a, b)
426-
result = _impl.vdot(t_a, t_b)
427-
return result.item()
428-
429-
430-
def dot(a, b, out=None):
431-
t_a, t_b = _helpers.to_tensors(a, b)
432-
result = _impl.dot(t_a, t_b)
433-
return _helpers.result_or_out(result, out)
434-
435-
436424
###### module-level queries of object properties
437425

438426

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,7 +2386,6 @@ def test_flatten(self):
23862386
assert_equal(x1.flatten('F'), x1.T.flatten())
23872387

23882388

2389-
@pytest.mark.xfail(reason="TODO np.dot")
23902389
@pytest.mark.parametrize('func', (np.dot, np.matmul))
23912390
def test_arr_mult(self, func):
23922391
a = np.array([[1, 0], [0, 1]])
@@ -2428,7 +2427,27 @@ def test_arr_mult(self, func):
24282427
assert_equal(func(ebf.T, ebf), eaf)
24292428
assert_equal(func(ebf, ebf.T), eaf)
24302429
assert_equal(func(ebf.T, ebf.T), eaf)
2430+
# syrk - different shape
2431+
for et in [np.float32, np.float64, np.complex64, np.complex128]:
2432+
edf = d.astype(et)
2433+
eddtf = ddt.astype(et)
2434+
edtdf = dtd.astype(et)
2435+
assert_equal(func(edf, edf.T), eddtf)
2436+
assert_equal(func(edf.T, edf), edtdf)
2437+
2438+
assert_equal(
2439+
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
2440+
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
2441+
)
2442+
assert_equal(
2443+
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
2444+
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
2445+
)
24312446

2447+
2448+
@pytest.mark.skip(reason="dot/matmul with negative strides")
2449+
@pytest.mark.parametrize('func', (np.dot, np.matmul))
2450+
def test_arr_mult_2(self, func):
24322451
# syrk - different shape, stride, and view validations
24332452
for et in [np.float32, np.float64, np.complex64, np.complex128]:
24342453
edf = d.astype(et)
@@ -2448,24 +2467,7 @@ def test_arr_mult(self, func):
24482467
func(edf, edf[:, ::-1].T),
24492468
func(edf, edf[:, ::-1].T.copy())
24502469
)
2451-
assert_equal(
2452-
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
2453-
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
2454-
)
2455-
assert_equal(
2456-
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
2457-
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
2458-
)
24592470

2460-
# syrk - different shape
2461-
for et in [np.float32, np.float64, np.complex64, np.complex128]:
2462-
edf = d.astype(et)
2463-
eddtf = ddt.astype(et)
2464-
edtdf = dtd.astype(et)
2465-
assert_equal(func(edf, edf.T), eddtf)
2466-
assert_equal(func(edf.T, edf), edtdf)
2467-
2468-
@pytest.mark.xfail(reason="TODO np.dot")
24692471
@pytest.mark.parametrize('func', (np.dot, np.matmul))
24702472
@pytest.mark.parametrize('dtype', 'ifdFD')
24712473
def test_no_dgemv(self, func, dtype):
@@ -2481,6 +2483,11 @@ def test_no_dgemv(self, func, dtype):
24812483
ret2 = func(b.T.copy(), a.T)
24822484
assert_equal(ret1, ret2)
24832485

2486+
2487+
@pytest.mark.skip(reason="__array_interface__")
2488+
@pytest.mark.parametrize('func', (np.dot, np.matmul))
2489+
@pytest.mark.parametrize('dtype', 'ifdFD')
2490+
def test_no_dgemv_2(self, func, dtype):
24842491
# check for unaligned data
24852492
dt = np.dtype(dtype)
24862493
a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype)
@@ -2496,7 +2503,6 @@ def test_no_dgemv(self, func, dtype):
24962503
ret2 = func(b.T.copy(), a.T.copy())
24972504
assert_equal(ret1, ret2)
24982505

2499-
@pytest.mark.xfail(reason="TODO np.dot")
25002506
def test_dot(self):
25012507
a = np.array([[1, 0], [0, 1]])
25022508
b = np.array([[0, 1], [1, 0]])
@@ -2515,15 +2521,8 @@ def test_dot(self):
25152521
a.dot(b=b, out=c)
25162522
assert_equal(c, np.dot(a, b))
25172523

2518-
@pytest.mark.xfail(reason="TODO np.dot")
2519-
def test_dot_type_mismatch(self):
2520-
c = 1.
2521-
A = np.array((1,1), dtype='i,i')
25222524

2523-
assert_raises(TypeError, np.dot, c, A)
2524-
assert_raises(TypeError, np.dot, A, c)
2525-
2526-
@pytest.mark.xfail(reason="TODO np.dot")
2525+
@pytest.mark.xfail(reason="_aligned_zeros")
25272526
def test_dot_out_mem_overlap(self):
25282527
np.random.seed(1)
25292528

@@ -5627,7 +5626,7 @@ def test_dot_array_order(self):
56275626
assert_equal(np.dot(b, a), res)
56285627
assert_equal(np.dot(b, b), res)
56295628

5630-
@pytest.mark.skip(reason='TODO: nbytes, view')
5629+
@pytest.mark.skip(reason='TODO: nbytes, view, __array_interface__')
56315630
def test_accelerate_framework_sgemv_fix(self):
56325631

56335632
def aligned_array(shape, align, dtype, order='C'):
@@ -7878,7 +7877,6 @@ def test_view_discard_refcount(self):
78787877
assert_equal(arr, orig)
78797878

78807879

7881-
@pytest.mark.xfail(reason='TODO')
78827880
class TestArange:
78837881
def test_infinite(self):
78847882
assert_raises_regex(
@@ -7887,8 +7885,8 @@ def test_infinite(self):
78877885
)
78887886

78897887
def test_nan_step(self):
7890-
assert_raises_regex(
7891-
ValueError, "cannot compute length",
7888+
assert_raises(
7889+
ValueError, # "cannot compute length",
78927890
np.arange, 0, 1, np.nan
78937891
)
78947892

@@ -7904,6 +7902,9 @@ def test_require_range(self):
79047902
assert_raises(TypeError, np.arange)
79057903
assert_raises(TypeError, np.arange, step=3)
79067904
assert_raises(TypeError, np.arange, dtype='int64')
7905+
7906+
@pytest.mark.xfail(reason="weird arange signature (optionals before required args)")
7907+
def test_require_range_2(self):
79077908
assert_raises(TypeError, np.arange, start=4)
79087909

79097910
def test_start_stop_kwarg(self):
@@ -7916,6 +7917,7 @@ def test_start_stop_kwarg(self):
79167917
assert len(keyword_start_stop) == 6
79177918
assert_array_equal(keyword_stop, keyword_zerotostop)
79187919

7920+
@pytest.mark.skip(reason="arange for booleans: numpy maybe deprecates?")
79197921
def test_arange_booleans(self):
79207922
# Arange makes some sense for booleans and works up to length 2.
79217923
# But it is weird since `arange(2, 4, dtype=bool)` works.
@@ -7936,28 +7938,6 @@ def test_arange_booleans(self):
79367938
with pytest.raises(TypeError):
79377939
np.arange(3, dtype="bool")
79387940

7939-
@pytest.mark.parametrize("dtype", ["S3", "U", "5i"])
7940-
def test_rejects_bad_dtypes(self, dtype):
7941-
dtype = np.dtype(dtype)
7942-
DType_name = re.escape(str(type(dtype)))
7943-
with pytest.raises(TypeError,
7944-
match=rf"arange\(\) not supported for inputs .* {DType_name}"):
7945-
np.arange(2, dtype=dtype)
7946-
7947-
def test_rejects_strings(self):
7948-
# Explicitly test error for strings which may call "b" - "a":
7949-
DType_name = re.escape(str(type(np.array("a").dtype)))
7950-
with pytest.raises(TypeError,
7951-
match=rf"arange\(\) not supported for inputs .* {DType_name}"):
7952-
np.arange("a", "b")
7953-
7954-
def test_byteswapped(self):
7955-
res_be = np.arange(1, 1000, dtype=">i4")
7956-
res_le = np.arange(1, 1000, dtype="<i4")
7957-
assert res_be.dtype == ">i4"
7958-
assert res_le.dtype == "<i4"
7959-
assert_array_equal(res_le, res_be)
7960-
79617941
@pytest.mark.parametrize("which", [0, 1, 2])
79627942
def test_error_paths_and_promotion(self, which):
79637943
args = [0, 1, 2] # start, stop, and step
@@ -7967,20 +7947,12 @@ def test_error_paths_and_promotion(self, which):
79677947

79687948
# Cover stranger error path, test only to achieve code coverage!
79697949
args[which] = [None, []]
7970-
with pytest.raises(ValueError):
7950+
with pytest.raises((ValueError, RuntimeError)):
79717951
# Fails discovering start dtype
79727952
np.arange(*args)
79737953

79747954

79757955

7976-
7977-
7978-
7979-
7980-
7981-
7982-
7983-
79847956
@pytest.mark.xfail(reason='comparison: builtin.bools or...?')
79857957
def test_richcompare_scalar_boolean_singleton_return():
79867958
# These are currently guaranteed to be the boolean singletons, but maybe

0 commit comments

Comments
 (0)