Skip to content

Commit e1ad9d8

Browse files
committed
add ndarray.dot
1 parent f667f8c commit e1ad9d8

File tree

4 files changed

+39
-38
lines changed

4 files changed

+39
-38
lines changed

torch_np/_funcs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,18 @@ def fill_diagonal(a, val, wrap=False):
9393
return _helpers.array_from(result)
9494

9595

96-
# ### sorting ###
96+
def vdot(a, b, /):
97+
t_a, t_b = _helpers.to_tensors(a, b)
98+
result = _impl.vdot(t_a, t_b)
99+
return result.item()
100+
101+
102+
def dot(a, b, out=None):
103+
t_a, t_b = _helpers.to_tensors(a, b)
104+
result = _impl.dot(t_a, t_b)
105+
return _helpers.result_or_out(result, out)
106+
107+
97108

98109
# ### sort and partition ###
99110

torch_np/_ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"):
362362

363363
diagonal = _funcs.diagonal
364364
trace = _funcs.trace
365+
dot = _funcs.dot
365366

366367
### sorting ###
367368

torch_np/_wrapper.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -421,17 +421,6 @@ def where(condition, x=None, y=None, /):
421421
return asarray(result)
422422

423423

424-
def vdot(a, b, /):
425-
t_a, t_b = _helpers.to_tensors(a, b)
426-
result = _impl.vdot(t_a, t_b)
427-
return result.item()
428-
429-
430-
def dot(a, b, out=None):
431-
t_a, t_b = _helpers.to_tensors(a, b)
432-
result = _impl.dot(t_a, t_b)
433-
return _helpers.result_or_out(result, out)
434-
435424

436425
###### module-level queries of object properties
437426

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,7 +2386,6 @@ def test_flatten(self):
23862386
assert_equal(x1.flatten('F'), x1.T.flatten())
23872387

23882388

2389-
@pytest.mark.xfail(reason="TODO np.dot")
23902389
@pytest.mark.parametrize('func', (np.dot, np.matmul))
23912390
def test_arr_mult(self, func):
23922391
a = np.array([[1, 0], [0, 1]])
@@ -2428,7 +2427,27 @@ def test_arr_mult(self, func):
24282427
assert_equal(func(ebf.T, ebf), eaf)
24292428
assert_equal(func(ebf, ebf.T), eaf)
24302429
assert_equal(func(ebf.T, ebf.T), eaf)
2430+
# syrk - different shape
2431+
for et in [np.float32, np.float64, np.complex64, np.complex128]:
2432+
edf = d.astype(et)
2433+
eddtf = ddt.astype(et)
2434+
edtdf = dtd.astype(et)
2435+
assert_equal(func(edf, edf.T), eddtf)
2436+
assert_equal(func(edf.T, edf), edtdf)
24312437

2438+
assert_equal(
2439+
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
2440+
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
2441+
)
2442+
assert_equal(
2443+
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
2444+
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
2445+
)
2446+
2447+
2448+
@pytest.mark.skip(reason="dot/matmul with negative strides")
2449+
@pytest.mark.parametrize('func', (np.dot, np.matmul))
2450+
def test_arr_mult_2(self, func):
24322451
# syrk - different shape, stride, and view validations
24332452
for et in [np.float32, np.float64, np.complex64, np.complex128]:
24342453
edf = d.astype(et)
@@ -2448,22 +2467,6 @@ def test_arr_mult(self, func):
24482467
func(edf, edf[:, ::-1].T),
24492468
func(edf, edf[:, ::-1].T.copy())
24502469
)
2451-
assert_equal(
2452-
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
2453-
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
2454-
)
2455-
assert_equal(
2456-
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
2457-
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
2458-
)
2459-
2460-
# syrk - different shape
2461-
for et in [np.float32, np.float64, np.complex64, np.complex128]:
2462-
edf = d.astype(et)
2463-
eddtf = ddt.astype(et)
2464-
edtdf = dtd.astype(et)
2465-
assert_equal(func(edf, edf.T), eddtf)
2466-
assert_equal(func(edf.T, edf), edtdf)
24672470

24682471
@pytest.mark.xfail(reason="TODO np.dot")
24692472
@pytest.mark.parametrize('func', (np.dot, np.matmul))
@@ -2481,6 +2484,11 @@ def test_no_dgemv(self, func, dtype):
24812484
ret2 = func(b.T.copy(), a.T)
24822485
assert_equal(ret1, ret2)
24832486

2487+
2488+
@pytest.mark.skip(reason="__array_interface__")
2489+
@pytest.mark.parametrize('func', (np.dot, np.matmul))
2490+
@pytest.mark.parametrize('dtype', 'ifdFD')
2491+
def test_no_dgemv_2(self, func, dtype):
24842492
# check for unaligned data
24852493
dt = np.dtype(dtype)
24862494
a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype)
@@ -2496,7 +2504,6 @@ def test_no_dgemv(self, func, dtype):
24962504
ret2 = func(b.T.copy(), a.T.copy())
24972505
assert_equal(ret1, ret2)
24982506

2499-
@pytest.mark.xfail(reason="TODO np.dot")
25002507
def test_dot(self):
25012508
a = np.array([[1, 0], [0, 1]])
25022509
b = np.array([[0, 1], [1, 0]])
@@ -2515,15 +2522,8 @@ def test_dot(self):
25152522
a.dot(b=b, out=c)
25162523
assert_equal(c, np.dot(a, b))
25172524

2518-
@pytest.mark.xfail(reason="TODO np.dot")
2519-
def test_dot_type_mismatch(self):
2520-
c = 1.
2521-
A = np.array((1,1), dtype='i,i')
25222525

2523-
assert_raises(TypeError, np.dot, c, A)
2524-
assert_raises(TypeError, np.dot, A, c)
2525-
2526-
@pytest.mark.xfail(reason="TODO np.dot")
2526+
@pytest.mark.xfail(reason="_aligned_zeros")
25272527
def test_dot_out_mem_overlap(self):
25282528
np.random.seed(1)
25292529

0 commit comments

Comments
 (0)