Skip to content

Add ndarray.dot #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions torch_np/_detail/_ufunc_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from . import _util
from . import _dtypes_impl, _util


def deco_ufunc(torch_func):
Expand Down Expand Up @@ -70,7 +70,6 @@ def wrapped(
logical_and = deco_ufunc(torch.logical_and)
logical_or = deco_ufunc(torch.logical_or)
logical_xor = deco_ufunc(torch.logical_xor)
matmul = deco_ufunc(torch.matmul)
maximum = deco_ufunc(torch.maximum)
minimum = deco_ufunc(torch.minimum)
remainder = deco_ufunc(torch.remainder)
Expand Down Expand Up @@ -144,6 +143,16 @@ def _absolute(x):
return torch.absolute(x)


def _matmul(x, y):
# work around RuntimeError: expected scalar type Int but found Double
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
x = _util.cast_if_needed(x, dtype)
y = _util.cast_if_needed(y, dtype)
result = torch.matmul(x, y)
return result


cbrt = deco_ufunc(_cbrt)
positive = deco_ufunc(_positive)
absolute = deco_ufunc(_absolute)
matmul = deco_ufunc(_matmul)
26 changes: 22 additions & 4 deletions torch_np/_detail/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,17 +492,31 @@ def arange(start=None, stop=None, step=1, dtype=None):
if start is None:
start = 0

# the dtype of the result
if dtype is None:
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
dtype = _dtypes_impl.default_int_dtype
dt_list.append(dtype)
dtype = _dtypes_impl.result_type_impl(dt_list)
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
dt_list.append(dtype)
dtype = _dtypes_impl.result_type_impl(dt_list)

# work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
if dtype.is_complex:
work_dtype, target_dtype = torch.float64, dtype
else:
work_dtype, target_dtype = dtype, dtype

if (step > 0 and start > stop) or (step < 0 and start < stop):
# empty range
return torch.empty(0, dtype=target_dtype)

try:
return torch.arange(start, stop, step, dtype=dtype)
result = torch.arange(start, stop, step, dtype=work_dtype)
result = _util.cast_if_needed(result, target_dtype)
except RuntimeError:
raise ValueError("Maximum allowed size exceeded")

return result


# ### empty/full et al ###

Expand Down Expand Up @@ -783,6 +797,10 @@ def vdot(t_a, t_b, /):


def dot(t_a, t_b):
dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
t_a = _util.cast_if_needed(t_a, dtype)
t_b = _util.cast_if_needed(t_b, dtype)

if t_a.ndim == 0 or t_b.ndim == 0:
result = t_a * t_b
elif t_a.ndim == 1 and t_b.ndim == 1:
Expand Down
14 changes: 12 additions & 2 deletions torch_np/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from . import _decorators, _helpers
from ._detail import _flips, _util
from ._detail import _dtypes_impl, _flips, _util
from ._detail import implementations as _impl


Expand Down Expand Up @@ -93,7 +93,17 @@ def fill_diagonal(a, val, wrap=False):
return _helpers.array_from(result)


# ### sorting ###
def vdot(a, b, /):
t_a, t_b = _helpers.to_tensors(a, b)
result = _impl.vdot(t_a, t_b)
return result.item()


def dot(a, b, out=None):
t_a, t_b = _helpers.to_tensors(a, b)
result = _impl.dot(t_a, t_b)
return _helpers.result_or_out(result, out)


# ### sort and partition ###

Expand Down
1 change: 1 addition & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"):

diagonal = _funcs.diagonal
trace = _funcs.trace
dot = _funcs.dot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing also vdot?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing in numpy, yes :-)

In [23]: hasattr(np.array([1, 2, 3]), 'vdot')
Out[23]: False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol. Let's still add it to be forward-looking.

Copy link
Collaborator Author

@ev-br ev-br Feb 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, can do.
If we're starting to do this however, what's the guiding principle, where do we stop. E.g. do we want feature parity for the main namespace and ndarray methods?

In [24]: len(dir(np))
Out[24]: 595

In [25]: len(dir(np.ndarray))
Out[25]: 165

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can do it "for free" it may be a fine thing to do? WDYT @rgommers?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no guiding principle, it's an accident of history. It's highly unlikely though that we'll add more method to numpy.ndarray, there are already too many. So I wouldn't add anything that's not already present.


### sorting ###

Expand Down
12 changes: 0 additions & 12 deletions torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,18 +421,6 @@ def where(condition, x=None, y=None, /):
return asarray(result)


def vdot(a, b, /):
t_a, t_b = _helpers.to_tensors(a, b)
result = _impl.vdot(t_a, t_b)
return result.item()


def dot(a, b, out=None):
t_a, t_b = _helpers.to_tensors(a, b)
result = _impl.dot(t_a, t_b)
return _helpers.result_or_out(result, out)


###### module-level queries of object properties


Expand Down
96 changes: 34 additions & 62 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,7 +2386,6 @@ def test_flatten(self):
assert_equal(x1.flatten('F'), x1.T.flatten())


@pytest.mark.xfail(reason="TODO np.dot")
@pytest.mark.parametrize('func', (np.dot, np.matmul))
def test_arr_mult(self, func):
a = np.array([[1, 0], [0, 1]])
Expand Down Expand Up @@ -2428,7 +2427,27 @@ def test_arr_mult(self, func):
assert_equal(func(ebf.T, ebf), eaf)
assert_equal(func(ebf, ebf.T), eaf)
assert_equal(func(ebf.T, ebf.T), eaf)
# syrk - different shape
for et in [np.float32, np.float64, np.complex64, np.complex128]:
edf = d.astype(et)
eddtf = ddt.astype(et)
edtdf = dtd.astype(et)
assert_equal(func(edf, edf.T), eddtf)
assert_equal(func(edf.T, edf), edtdf)

assert_equal(
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
)
assert_equal(
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
)


@pytest.mark.skip(reason="dot/matmul with negative strides")
@pytest.mark.parametrize('func', (np.dot, np.matmul))
def test_arr_mult_2(self, func):
# syrk - different shape, stride, and view validations
for et in [np.float32, np.float64, np.complex64, np.complex128]:
edf = d.astype(et)
Expand All @@ -2448,24 +2467,7 @@ def test_arr_mult(self, func):
func(edf, edf[:, ::-1].T),
func(edf, edf[:, ::-1].T.copy())
)
assert_equal(
func(edf[:edf.shape[0] // 2, :], edf[::2, :].T),
func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy())
)
assert_equal(
func(edf[::2, :], edf[:edf.shape[0] // 2, :].T),
func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy())
)

# syrk - different shape
for et in [np.float32, np.float64, np.complex64, np.complex128]:
edf = d.astype(et)
eddtf = ddt.astype(et)
edtdf = dtd.astype(et)
assert_equal(func(edf, edf.T), eddtf)
assert_equal(func(edf.T, edf), edtdf)

@pytest.mark.xfail(reason="TODO np.dot")
@pytest.mark.parametrize('func', (np.dot, np.matmul))
@pytest.mark.parametrize('dtype', 'ifdFD')
def test_no_dgemv(self, func, dtype):
Expand All @@ -2481,6 +2483,11 @@ def test_no_dgemv(self, func, dtype):
ret2 = func(b.T.copy(), a.T)
assert_equal(ret1, ret2)


@pytest.mark.skip(reason="__array_interface__")
@pytest.mark.parametrize('func', (np.dot, np.matmul))
@pytest.mark.parametrize('dtype', 'ifdFD')
def test_no_dgemv_2(self, func, dtype):
# check for unaligned data
dt = np.dtype(dtype)
a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype)
Expand All @@ -2496,7 +2503,6 @@ def test_no_dgemv(self, func, dtype):
ret2 = func(b.T.copy(), a.T.copy())
assert_equal(ret1, ret2)

@pytest.mark.xfail(reason="TODO np.dot")
def test_dot(self):
a = np.array([[1, 0], [0, 1]])
b = np.array([[0, 1], [1, 0]])
Expand All @@ -2515,15 +2521,8 @@ def test_dot(self):
a.dot(b=b, out=c)
assert_equal(c, np.dot(a, b))

@pytest.mark.xfail(reason="TODO np.dot")
def test_dot_type_mismatch(self):
c = 1.
A = np.array((1,1), dtype='i,i')

assert_raises(TypeError, np.dot, c, A)
assert_raises(TypeError, np.dot, A, c)

@pytest.mark.xfail(reason="TODO np.dot")
@pytest.mark.xfail(reason="_aligned_zeros")
def test_dot_out_mem_overlap(self):
np.random.seed(1)

Expand Down Expand Up @@ -5627,7 +5626,7 @@ def test_dot_array_order(self):
assert_equal(np.dot(b, a), res)
assert_equal(np.dot(b, b), res)

@pytest.mark.skip(reason='TODO: nbytes, view')
@pytest.mark.skip(reason='TODO: nbytes, view, __array_interface__')
def test_accelerate_framework_sgemv_fix(self):

def aligned_array(shape, align, dtype, order='C'):
Expand Down Expand Up @@ -7878,7 +7877,6 @@ def test_view_discard_refcount(self):
assert_equal(arr, orig)


@pytest.mark.xfail(reason='TODO')
class TestArange:
def test_infinite(self):
assert_raises_regex(
Expand All @@ -7887,8 +7885,8 @@ def test_infinite(self):
)

def test_nan_step(self):
assert_raises_regex(
ValueError, "cannot compute length",
assert_raises(
ValueError, # "cannot compute length",
np.arange, 0, 1, np.nan
)

Expand All @@ -7904,6 +7902,9 @@ def test_require_range(self):
assert_raises(TypeError, np.arange)
assert_raises(TypeError, np.arange, step=3)
assert_raises(TypeError, np.arange, dtype='int64')

@pytest.mark.xfail(reason="weird arange signature (optionals before required args)")
def test_require_range_2(self):
assert_raises(TypeError, np.arange, start=4)

def test_start_stop_kwarg(self):
Expand All @@ -7916,6 +7917,7 @@ def test_start_stop_kwarg(self):
assert len(keyword_start_stop) == 6
assert_array_equal(keyword_stop, keyword_zerotostop)

@pytest.mark.skip(reason="arange for booleans: numpy maybe deprecates?")
def test_arange_booleans(self):
# Arange makes some sense for booleans and works up to length 2.
# But it is weird since `arange(2, 4, dtype=bool)` works.
Expand All @@ -7936,28 +7938,6 @@ def test_arange_booleans(self):
with pytest.raises(TypeError):
np.arange(3, dtype="bool")

@pytest.mark.parametrize("dtype", ["S3", "U", "5i"])
def test_rejects_bad_dtypes(self, dtype):
dtype = np.dtype(dtype)
DType_name = re.escape(str(type(dtype)))
with pytest.raises(TypeError,
match=rf"arange\(\) not supported for inputs .* {DType_name}"):
np.arange(2, dtype=dtype)

def test_rejects_strings(self):
# Explicitly test error for strings which may call "b" - "a":
DType_name = re.escape(str(type(np.array("a").dtype)))
with pytest.raises(TypeError,
match=rf"arange\(\) not supported for inputs .* {DType_name}"):
np.arange("a", "b")

def test_byteswapped(self):
res_be = np.arange(1, 1000, dtype=">i4")
res_le = np.arange(1, 1000, dtype="<i4")
assert res_be.dtype == ">i4"
assert res_le.dtype == "<i4"
assert_array_equal(res_le, res_be)

@pytest.mark.parametrize("which", [0, 1, 2])
def test_error_paths_and_promotion(self, which):
args = [0, 1, 2] # start, stop, and step
Expand All @@ -7967,20 +7947,12 @@ def test_error_paths_and_promotion(self, which):

# Cover stranger error path, test only to achieve code coverage!
args[which] = [None, []]
with pytest.raises(ValueError):
with pytest.raises((ValueError, RuntimeError)):
# Fails discovering start dtype
np.arange(*args)











@pytest.mark.xfail(reason='comparison: builtin.bools or...?')
def test_richcompare_scalar_boolean_singleton_return():
# These are currently guaranteed to be the boolean singletons, but maybe
Expand Down