diff --git a/.gitignore b/.gitignore index cfd46377..56438c6d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,5 @@ -__pycache__/* -autogen/__pycache__ -torch_np/__pycache__/* -torch_np/tests/__pycache__/* -torch_np/tests/numpy_tests/core/__pycache__/* -torch_np/testing/__pycache__/* +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] .coverage diff --git a/autogen/gen_ufuncs_2.py b/autogen/gen_ufuncs_2.py new file mode 100644 index 00000000..8047a697 --- /dev/null +++ b/autogen/gen_ufuncs_2.py @@ -0,0 +1,213 @@ +from dump_namespace import grab_namespace, get_signature + +import numpy as np + +namespace = np + +dct = grab_namespace(namespace) + + +# SKIP these (need special handling) +skip = {np.frexp, np.modf, # non-standard unary ufunc signatures + np.isnat, + np.invert, # bitwise NOT operator + np.spacing, # niche, does not have a direct equivalent +} + +# np functions where torch names differ +torch_names = {np.radians : "deg2rad", + np.degrees : "rad2deg", + np.conjugate : "conj_physical", + np.fabs : "absolute", # FIXME: np.fabs raises form complex + np.rint : "round", + np.left_shift: "bitwise_left_shift", + np.right_shift: "bitwise_right_shift", +} + + +# np functions which do not have a torch equivalent +default_stanza = "torch.{torch_name}(x, out=out)" + +stanzas = {np.cbrt : "torch.pow(x, 1/3, out=out)", + + # XXX what on earth is np.positive + np.positive: "+x", + + # these three do not have an out arg + np.isinf: "torch.isinf(x)", + np.isnan: "torch.isnan(x)", + np.isfinite: "torch.isfinite(x)", +} + + +# for these np functions, pytorch analog does not have the out= arg +needs_out = {np.isinf, np.isnan, np.isfinite, np.positive} +add_out_stanza = """ + if out is not None: + out[...] = result +""" + + +header = """\ +# this file is autogenerated via gen_ufuncs.py +# do not edit manually! 
+ +import torch + +import _util +from _ndarray import asarray_replacer + +""" + +test_header = header + """\ +import numpy as np +import torch + +from _unary_ufuncs import * +from testing import assert_allclose +""" + + +template = """ """ + +test_template = """ + +def test_{np_name}(): + assert_allclose(np.{np_name}(0.5), + {np_name}(0.5), atol=1e-14, check_dtype=False) + +""" + + +###### UNARY UFUNCS ################################### + +_all_list = [] +main_text = header +test_text = test_header + +_impl_list = [] +_ufunc_list = [] + +for ufunc in dct['ufunc']: + if ufunc in skip: + continue + + if ufunc.nin == 1: +# print(get_signature(ufunc)) + + torch_name = torch_names.get(ufunc) + if torch_name is None: + torch_name = ufunc.__name__ + +# print(ufunc.__name__, ' -- ', torch_name) + + _impl_stanza = "{np_name} = deco_unary_ufunc(torch.{torch_name})" + _impl_stanza = _impl_stanza.format(np_name=ufunc.__name__, + torch_name=torch_name,) + _impl_list.append(_impl_stanza) + + continue + + torch_stanza = stanzas.get(ufunc) + if torch_stanza is None: + torch_stanza = default_stanza.format(torch_name=torch_name) + + out_stanza= add_out_stanza if ufunc in needs_out else "" + + main_text += template.format(np_name=ufunc.__name__, + torch_stanza=torch_stanza, + out_stanza=out_stanza) + test_text += test_template.format(np_name=ufunc.__name__) + + _all_list.append(ufunc.__name__) + + +print("\n".join(_impl_list)) +print("\n\n-----\n\n") + +''' +main_text += "\n\n__all__ = %s" % _all_list + + +with open("_unary_ufuncs.py", "w") as f: + f.write(main_text) + +with open("test_unary_ufuncs.py", "w") as f: + f.write(test_text) +''' + +###### BINARY UFUNCS ################################### + + + +test_header = header + """\ +import numpy as np +import torch + +from _binary_ufuncs import * +from testing import assert_allclose +""" + + +template = """ + + +""" + +test_template = """ + +def test_{np_name}(): + assert_allclose(np.{np_name}(0.5, 0.6), + {np_name}(0.5, 0.6), atol=1e-7, check_dtype=False) + +""" + + + +skip = {np.divmod, # two outputs +} + + +torch_names = {np.power: "pow", + np.equal: "eq", +} + + +_all_list = [] +main_text = header +test_text = test_header + +_impl_list = [] +_ufunc_list = [] + + +for ufunc in dct['ufunc']: + + if ufunc in skip: + continue + + if ufunc.nin == 2: + ## print(get_signature(ufunc)) + + torch_name = torch_names.get(ufunc) + if torch_name is None: + torch_name = ufunc.__name__ + + _impl_stanza = "{np_name} = deco_binary_ufunc(torch.{torch_name})" + _impl_stanza = _impl_stanza.format(np_name=ufunc.__name__, + torch_name=torch_name,) + _impl_list.append(_impl_stanza) + + _ufunc_stanza = "{np_name} = deco_ufunc_from_impl(_ufunc_impl.{np_name})" + _ufunc_stanza = _ufunc_stanza.format(np_name=ufunc.__name__) + _ufunc_list.append(_ufunc_stanza) + + +print("\n".join(_impl_list)) + +print("\n\n") +print("\n".join(_ufunc_list)) + + + + diff --git a/torch_np/__init__.py b/torch_np/__init__.py index 5060b762..5b2eee7f 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -6,10 +6,15 @@ from ._unary_ufuncs import * from ._binary_ufuncs import * from ._ndarray import can_cast, result_type, newaxis -from ._util import AxisError +from ._util import AxisError, UFuncTypeError from ._getlimits import iinfo, finfo from ._getlimits import errstate inf = float('inf') nan = float('nan') + +#### HACK HACK HACK #### +import torch +torch.set_default_dtype(torch.float64) +del torch diff --git a/torch_np/_binary_ufuncs.py b/torch_np/_binary_ufuncs.py index b46b1059..3f1ae960 
100644 --- a/torch_np/_binary_ufuncs.py +++ b/torch_np/_binary_ufuncs.py @@ -1,678 +1,79 @@ -# this file is autogenerated via gen_ufuncs.py -# do not edit manually! - +import functools import torch from . import _util -from ._ndarray import asarray_replacer - -from ._ndarray import asarray, ndarray, can_cast +from ._ndarray import asarray from . import _dtypes from . import _helpers - -def add(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or not where: - raise NotImplementedError - - # XXX: dtype=... parameter is silently ignored - - x1_array = asarray(x1) - x2_array = asarray(x2) - - arrays = (x1_array, x2_array) - x1_tensor, x2_tensor = _helpers.cast_and_broadcast(arrays, out, casting) - - result = torch.add(x1_tensor, x2_tensor) - - return _helpers.result_or_out(result, out) - - - - - - -##################################### - -''' -@asarray_replacer("two") -def add(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.add(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result -''' - - -@asarray_replacer("two") -def arctan2(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.arctan2(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def bitwise_and(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.bitwise_and(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def bitwise_or(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.bitwise_or(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def bitwise_xor(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.bitwise_xor(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def copysign(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - 
result = torch.copysign(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def divide(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.divide(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.eq(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def float_power(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.float_power(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def floor_divide(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.floor_divide(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def fmax(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.fmax(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def fmin(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.fmin(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def fmod(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.fmod(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def gcd(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out 
= out.to(dtype) - result = torch.gcd(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def greater(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.greater(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def greater_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.greater_equal(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def heaviside(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.heaviside(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def hypot(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.hypot(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def lcm(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.lcm(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def ldexp(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.ldexp(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def left_shift(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.left_shift(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def less(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - 
# XXX: dtypes, casting - out = out.to(dtype) - result = torch.less(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def less_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.less_equal(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def logaddexp(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.logaddexp(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def logaddexp2(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.logaddexp2(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def logical_and(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.logical_and(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def logical_or(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.logical_or(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def logical_xor(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.logical_xor(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def matmul(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.matmul(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def maximum(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or 
not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.maximum(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def minimum(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.minimum(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def remainder(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.remainder(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.multiply(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def nextafter(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.nextafter(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def not_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.not_equal(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def power(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.pow(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def remainder(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.remainder(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def right_shift(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - 
if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.right_shift(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def subtract(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.subtract(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result - - - -@asarray_replacer("two") -def divide(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX: dtypes, casting - out = out.to(dtype) - result = torch.divide(x1, x2, out=out) - if dtype is not None: - result = result.to(dtype) - return result +from . import _ufunc_impl + +# +# Functions in _ufunc_impl receive arrays, implement common tasks with ufunc args +# and delegate heavy lifting to pytorch equivalents. +# +# Functions in this file implement binary ufuncs: wrap two first arguments in +# asarray and delegate to functions from _ufunc_impl. +# +# One other user of _ufunc_impl functions in ndarray, where its __add__ method +# calls _ufunc_impl.add and so on. Note that ndarray dunders already know +# that its first arg is an array, so they only convert the second argument. +# +# XXX: While it sounds tempting to merge _binary_ufuncs.py and _ufunc_impl.py +# files, doing it would currently create import cycles. 
+# + +# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl, +# _ndarray/asarray_replacer, and _wrapper/concatenate et al +def deco_ufunc_from_impl(impl_func): + @functools.wraps(impl_func) + def wrapped(x1, x2, *args, **kwds): + x1_array = asarray(x1) + x2_array = asarray(x2) + return impl_func(x1_array, x2_array, *args, **kwds) + return wrapped + + +# the list is autogenerated, cf autogen/gen_ufunc_2.py +add = deco_ufunc_from_impl(_ufunc_impl.add) +arctan2 = deco_ufunc_from_impl(_ufunc_impl.arctan2) +bitwise_and = deco_ufunc_from_impl(_ufunc_impl.bitwise_and) +bitwise_or = deco_ufunc_from_impl(_ufunc_impl.bitwise_or) +bitwise_xor = deco_ufunc_from_impl(_ufunc_impl.bitwise_xor) +copysign = deco_ufunc_from_impl(_ufunc_impl.copysign) +divide = deco_ufunc_from_impl(_ufunc_impl.divide) +equal = deco_ufunc_from_impl(_ufunc_impl.equal) +float_power = deco_ufunc_from_impl(_ufunc_impl.float_power) +floor_divide = deco_ufunc_from_impl(_ufunc_impl.floor_divide) +fmax = deco_ufunc_from_impl(_ufunc_impl.fmax) +fmin = deco_ufunc_from_impl(_ufunc_impl.fmin) +fmod = deco_ufunc_from_impl(_ufunc_impl.fmod) +gcd = deco_ufunc_from_impl(_ufunc_impl.gcd) +greater = deco_ufunc_from_impl(_ufunc_impl.greater) +greater_equal = deco_ufunc_from_impl(_ufunc_impl.greater_equal) +heaviside = deco_ufunc_from_impl(_ufunc_impl.heaviside) +hypot = deco_ufunc_from_impl(_ufunc_impl.hypot) +lcm = deco_ufunc_from_impl(_ufunc_impl.lcm) +ldexp = deco_ufunc_from_impl(_ufunc_impl.ldexp) +left_shift = deco_ufunc_from_impl(_ufunc_impl.left_shift) +less = deco_ufunc_from_impl(_ufunc_impl.less) +less_equal = deco_ufunc_from_impl(_ufunc_impl.less_equal) +logaddexp = deco_ufunc_from_impl(_ufunc_impl.logaddexp) +logaddexp2 = deco_ufunc_from_impl(_ufunc_impl.logaddexp2) +logical_and = deco_ufunc_from_impl(_ufunc_impl.logical_and) +logical_or = deco_ufunc_from_impl(_ufunc_impl.logical_or) +logical_xor = deco_ufunc_from_impl(_ufunc_impl.logical_xor) +matmul = deco_ufunc_from_impl(_ufunc_impl.matmul) +maximum = deco_ufunc_from_impl(_ufunc_impl.maximum) +minimum = deco_ufunc_from_impl(_ufunc_impl.minimum) +remainder = deco_ufunc_from_impl(_ufunc_impl.remainder) +multiply = deco_ufunc_from_impl(_ufunc_impl.multiply) +nextafter = deco_ufunc_from_impl(_ufunc_impl.nextafter) +not_equal = deco_ufunc_from_impl(_ufunc_impl.not_equal) +power = deco_ufunc_from_impl(_ufunc_impl.power) +remainder = deco_ufunc_from_impl(_ufunc_impl.remainder) +right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift) +subtract = deco_ufunc_from_impl(_ufunc_impl.subtract) +divide = deco_ufunc_from_impl(_ufunc_impl.divide) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 904d8229..16a0691a 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -5,6 +5,7 @@ from . import _util from . import _helpers from . import _dtypes +from . import _ufunc_impl NoValue = None newaxis = None @@ -106,6 +107,9 @@ def copy(self, order='C'): tensor = self._tensor.clone() # XXX: clone or detach? return ndarray._from_tensor_and_base(tensor, None) + def tolist(self): + return self._tensor.tolist() + ### niceties ### def __str__(self): return str(self._tensor).replace("tensor", "array_w").replace("dtype=torch.", "dtype=") @@ -115,28 +119,31 @@ def __str__(self): ### comparisons ### def __eq__(self, other): try: - t_other = asarray(other).get + return _ufunc_impl.equal(self, asarray(other)) except RuntimeError: # Failed to convert other to array: definitely not equal. 
- # TODO: generalize, delegate to ufuncs falsy = torch.full(self.shape, fill_value=False, dtype=bool) return asarray(falsy) - return asarray(self._tensor == asarray(other).get()) def __neq__(self, other): - return asarray(self._tensor != asarray(other).get()) + try: + return _ufunc_impl.not_equal(self, asarray(other)) + except RuntimeError: + # Failed to convert other to array: definitely not equal. + falsy = torch.full(self.shape, fill_value=True, dtype=bool) + return asarray(falsy) def __gt__(self, other): - return asarray(self._tensor > asarray(other).get()) + return _ufunc_impl.greater(self, asarray(other)) def __lt__(self, other): - return asarray(self._tensor < asarray(other).get()) + return _ufunc_impl.less(self, asarray(other)) def __ge__(self, other): - return asarray(self._tensor >= asarray(other).get()) + return _ufunc_impl.greater_equal(self, asarray(other)) def __le__(self, other): - return asarray(self._tensor <= asarray(other).get()) + return _ufunc_impl.less_equal(self, asarray(other)) def __bool__(self): try: @@ -178,76 +185,130 @@ def __len__(self): ### arithmetic ### + # add, self + other def __add__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__add__(other_tensor)) + return _ufunc_impl.add(self, asarray(other)) + + def __radd__(self, other): + return _ufunc_impl.add(self, asarray(other)) def __iadd__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__iadd__(other_tensor)) + return _ufunc_impl.add(self, asarray(other), out=self) + + # sub, self - other def __sub__(self, other): - other_tensor = asarray(other).get() - try: - return asarray(self._tensor.__sub__(other_tensor)) - except RuntimeError as e: - raise TypeError(e.args) + return _ufunc_impl.subtract(self, asarray(other)) + + def __rsub__(self, other): + return _ufunc_impl.subtract(self, asarray(other)) + + def __isub__(self, other): + return _ufunc_impl.subtract(self, asarray(other), out=self) + + # mul, self * other def __mul__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__mul__(other_tensor)) + return _ufunc_impl.multiply(self, asarray(other)) def __rmul__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__rmul__(other_tensor)) + return _ufunc_impl.multiply(self, asarray(other)) - def __floordiv__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__floordiv__(other_tensor)) + def __imul__(self, other): + return _ufunc_impl.multiply(self, asarray(other), out=self) - def __ifloordiv__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__ifloordiv__(other_tensor)) + # div, self / other def __truediv__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__truediv__(other_tensor)) + return _ufunc_impl.divide(self, asarray(other)) + + def __rtruediv__(self, other): + return _ufunc_impl.divide(self, asarray(other)) def __itruediv__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__itruediv__(other_tensor)) + return _ufunc_impl.divide(self, asarray(other), out=self) + + + # floordiv, self // other + def __floordiv__(self, other): + return _ufunc_impl.floor_divide(self, asarray(other)) + + def __rfloordiv__(self, other): + return _ufunc_impl.floor_divide(self, asarray(other)) + + def __ifloordiv__(self, other): + return _ufunc_impl.floor_divide(self, asarray(other), out=self) + + + # power, self**exponent + def __pow__(self, exponent): + return 
_ufunc_impl.float_power(self, asarray(exponent)) + + def __rpow__(self, exponent): + return _ufunc_impl.float_power(self, asarray(exponent)) + + def __ipow__(self, exponent): + return _ufunc_impl.float_power(self, asarray(exponent), out=self) + + # remainder, self % other def __mod__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__mod__(other_tensor)) + return _ufunc_impl.remainder(self, asarray(other)) + + def __rmod__(self, other): + return _ufunc_impl.remainder(self, asarray(other)) def __imod__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__imod__(other_tensor)) + return _ufunc_impl.remainder(self, asarray(other), out=self) + + + # bitwise ops + # and, self & other + def __and__(self, other): + return _ufunc_impl.bitwise_and(self, asarray(other)) + + def __rand__(self, other): + return _ufunc_impl.bitwise_and(self, asarray(other)) + + def __iand__(self, other): + return _ufunc_impl.bitwise_and(self, asarray(other), out=self) + + # or, self | other def __or__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__or__(other_tensor)) + return _ufunc_impl.bitwise_or(self, asarray(other)) + + def __ror__(self, other): + return _ufunc_impl.bitwise_or(self, asarray(other)) def __ior__(self, other): - other_tensor = asarray(other).get() - return asarray(self._tensor.__ior__(other_tensor)) + return _ufunc_impl.bitwise_or(self, asarray(other), out=self) + + + # xor, self ^ other + def __xor__(self, other): + return _ufunc_impl.bitwise_xor(self, asarray(other)) + + def __rxor__(self, other): + return _ufunc_impl.bitwise_xor(self, asarray(other)) + + def __ixor__(self, other): + return _ufunc_impl.bitwise_xor(self, asarray(other), out=self) + + # unary ops def __invert__(self): - return asarray(self._tensor.__invert__()) + return _ufunc_impl.invert(self) def __abs__(self): - return asarray(self._tensor.__abs__()) + return _ufunc_impl.absolute(self) + + def __pos__(self): + return _ufunc_impl.positive(self) def __neg__(self): - try: - return asarray(self._tensor.__neg__()) - except RuntimeError as e: - raise TypeError(e.args) + return _ufunc_impl.negative(self) - def __pow__(self, exponent): - exponent_tensor = asarray(exponent).get() - return asarray(self._tensor.__pow__(exponent_tensor)) ### methods to match namespace functions @@ -483,22 +544,12 @@ def __init__(self, dispatch='one'): self._dispatch = dispatch def __call__(self, func): - if self._dispatch == 'one': @functools.wraps(func) def wrapped(x, *args, **kwds): x_tensor = asarray(x).get() return asarray(func(x_tensor, *args, **kwds)) return wrapped - - elif self._dispatch == 'two': - @functools.wraps(func) - def wrapped(x, y, *args, **kwds): - x_tensor = asarray(x).get() - y_tensor = asarray(y).get() - return asarray(func(x_tensor, y_tensor, *args, **kwds)) - return wrapped - else: raise ValueError @@ -513,8 +564,17 @@ def can_cast(from_, to, casting='safe'): def result_type(*arrays_and_dtypes): - dtypes = [elem if isinstance(elem, _dtypes.dtype) else asarray(elem).dtype - for elem in arrays_and_dtypes] + dtypes = [] + + from ._dtypes import issubclass_ + + for entry in arrays_and_dtypes: + if issubclass_(entry, _dtypes._scalar_types.generic): + dtypes.append(_dtypes.dtype(entry)) + elif isinstance(entry, _dtypes.dtype): + dtypes.append(entry) + else: + dtypes.append(asarray(entry).dtype) dtyp = dtypes[0] if len(dtypes) == 1: diff --git a/torch_np/_ufunc_impl.py b/torch_np/_ufunc_impl.py new file mode 100644 index 00000000..894ace34 --- 
/dev/null +++ b/torch_np/_ufunc_impl.py @@ -0,0 +1,156 @@ +import torch + +from . import _util +from . import _helpers + +def deco_binary_ufunc(torch_func): + """Common infra for binary ufuncs: receive arrays, sort out type casting, + broadcasting, out array handling etc, and delegate to the + pytorch function for actual work, then wrap the results into an array. + + x1, x2 are arrays! array_like -> array conversion is the caller responsibility. + """ + def wrapped(x1, x2, /, out=None, *, where=True, + casting='same_kind', order='K', dtype=None, subok=False, **kwds): + _util.subok_not_ok(subok=subok) + if order != 'K' or not where: + raise NotImplementedError + + # XXX: dtype=... parameter + if dtype is not None: + raise NotImplementedError + + arrays = (x1, x2) + tensors = _helpers.cast_and_broadcast(arrays, out, casting) + + result = torch_func(*tensors) + + return _helpers.result_or_out(result, out) + return wrapped + + +def deco_unary_ufunc(torch_func): + # TODO: deduplicate with `deco_binary_ufunc` above. Need to figure out the + # effect of the `dtype` parameter, does it differ between unary and binary ufuncs. + def wrapped(x1, /, out=None, *, where=True, + casting='same_kind', order='K', dtype=None, subok=False, **kwds): + _util.subok_not_ok(subok=subok) + if order != 'K' or not where: + raise NotImplementedError + + # XXX: dtype=... parameter + if dtype is not None: + raise NotImplementedError + + arrays = (x1, ) + tensors = _helpers.cast_and_broadcast(arrays, out, casting) + + result = torch_func(*tensors) + + return _helpers.result_or_out(result, out) + return wrapped + + + + +# binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py +# And edited manually! np.equal <--> torch.eq, not torch.equal +add = deco_binary_ufunc(torch.add) +arctan2 = deco_binary_ufunc(torch.arctan2) +bitwise_and = deco_binary_ufunc(torch.bitwise_and) +bitwise_or = deco_binary_ufunc(torch.bitwise_or) +bitwise_xor = deco_binary_ufunc(torch.bitwise_xor) +copysign = deco_binary_ufunc(torch.copysign) +divide = deco_binary_ufunc(torch.divide) +equal = deco_binary_ufunc(torch.eq) +float_power = deco_binary_ufunc(torch.float_power) +floor_divide = deco_binary_ufunc(torch.floor_divide) +fmax = deco_binary_ufunc(torch.fmax) +fmin = deco_binary_ufunc(torch.fmin) +fmod = deco_binary_ufunc(torch.fmod) +gcd = deco_binary_ufunc(torch.gcd) +greater = deco_binary_ufunc(torch.greater) +greater_equal = deco_binary_ufunc(torch.greater_equal) +heaviside = deco_binary_ufunc(torch.heaviside) +hypot = deco_binary_ufunc(torch.hypot) +lcm = deco_binary_ufunc(torch.lcm) +ldexp = deco_binary_ufunc(torch.ldexp) +left_shift = deco_binary_ufunc(torch.bitwise_left_shift) +less = deco_binary_ufunc(torch.less) +less_equal = deco_binary_ufunc(torch.less_equal) +logaddexp = deco_binary_ufunc(torch.logaddexp) +logaddexp2 = deco_binary_ufunc(torch.logaddexp2) +logical_and = deco_binary_ufunc(torch.logical_and) +logical_or = deco_binary_ufunc(torch.logical_or) +logical_xor = deco_binary_ufunc(torch.logical_xor) +matmul = deco_binary_ufunc(torch.matmul) +maximum = deco_binary_ufunc(torch.maximum) +minimum = deco_binary_ufunc(torch.minimum) +remainder = deco_binary_ufunc(torch.remainder) +multiply = deco_binary_ufunc(torch.multiply) +nextafter = deco_binary_ufunc(torch.nextafter) +not_equal = deco_binary_ufunc(torch.not_equal) +power = deco_binary_ufunc(torch.pow) +remainder = deco_binary_ufunc(torch.remainder) +right_shift = deco_binary_ufunc(torch.bitwise_right_shift) +subtract = deco_binary_ufunc(torch.subtract) +divide = 
deco_binary_ufunc(torch.divide) + + + +# unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py +absolute = deco_unary_ufunc(torch.absolute) +#absolute = deco_unary_ufunc(torch.absolute) +arccos = deco_unary_ufunc(torch.arccos) +arccosh = deco_unary_ufunc(torch.arccosh) +arcsin = deco_unary_ufunc(torch.arcsin) +arcsinh = deco_unary_ufunc(torch.arcsinh) +arctan = deco_unary_ufunc(torch.arctan) +arctanh = deco_unary_ufunc(torch.arctanh) +ceil = deco_unary_ufunc(torch.ceil) +conjugate = deco_unary_ufunc(torch.conj_physical) +#conjugate = deco_unary_ufunc(torch.conj_physical) +cos = deco_unary_ufunc(torch.cos) +cosh = deco_unary_ufunc(torch.cosh) +deg2rad = deco_unary_ufunc(torch.deg2rad) +degrees = deco_unary_ufunc(torch.rad2deg) +exp = deco_unary_ufunc(torch.exp) +exp2 = deco_unary_ufunc(torch.exp2) +expm1 = deco_unary_ufunc(torch.expm1) +fabs = deco_unary_ufunc(torch.absolute) +floor = deco_unary_ufunc(torch.floor) +isfinite = deco_unary_ufunc(torch.isfinite) +isinf = deco_unary_ufunc(torch.isinf) +isnan = deco_unary_ufunc(torch.isnan) +log = deco_unary_ufunc(torch.log) +log10 = deco_unary_ufunc(torch.log10) +log1p = deco_unary_ufunc(torch.log1p) +log2 = deco_unary_ufunc(torch.log2) +logical_not = deco_unary_ufunc(torch.logical_not) +negative = deco_unary_ufunc(torch.negative) +rad2deg = deco_unary_ufunc(torch.rad2deg) +radians = deco_unary_ufunc(torch.deg2rad) +reciprocal = deco_unary_ufunc(torch.reciprocal) +rint = deco_unary_ufunc(torch.round) +sign = deco_unary_ufunc(torch.sign) +signbit = deco_unary_ufunc(torch.signbit) +sin = deco_unary_ufunc(torch.sin) +sinh = deco_unary_ufunc(torch.sinh) +sqrt = deco_unary_ufunc(torch.sqrt) +square = deco_unary_ufunc(torch.square) +tan = deco_unary_ufunc(torch.tan) +tanh = deco_unary_ufunc(torch.tanh) +trunc = deco_unary_ufunc(torch.trunc) + +invert = deco_unary_ufunc(torch.bitwise_not) + +# special cases: torch does not export these names +def _cbrt(x): + return torch.pow(x, 1/3) + +def _positive(x): + return +x + +cbrt = deco_unary_ufunc(_cbrt) +positive = deco_unary_ufunc(_positive) + diff --git a/torch_np/_unary_ufuncs.py b/torch_np/_unary_ufuncs.py index 8d1865ab..657e03ff 100644 --- a/torch_np/_unary_ufuncs.py +++ b/torch_np/_unary_ufuncs.py @@ -1,793 +1,75 @@ -# this file is autogenerated via gen_ufuncs.py -# do not edit manually! - +import functools import torch from . import _util -from ._ndarray import asarray_replacer -from ._ndarray import asarray, ndarray, can_cast +from ._ndarray import asarray from . import _dtypes from . import _helpers +from . 
import _ufunc_impl + +__all__ = ['abs', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'asarray', 'cbrt', 'ceil', 'conj', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc', 'invert'] + + + +def deco_unary_ufunc_from_impl(impl_func): + @functools.wraps(impl_func) + def wrapped(x1, *args, **kwds): + x1_array = asarray(x1) + return impl_func(x1_array, *args, **kwds) + return wrapped + + +absolute = deco_unary_ufunc_from_impl(_ufunc_impl.absolute) +arccos = deco_unary_ufunc_from_impl(_ufunc_impl.arccos) +arccosh = deco_unary_ufunc_from_impl(_ufunc_impl.arccosh) +arcsin = deco_unary_ufunc_from_impl(_ufunc_impl.arcsin) +arcsinh = deco_unary_ufunc_from_impl(_ufunc_impl.arcsinh) +arctan = deco_unary_ufunc_from_impl(_ufunc_impl.arctan) +arctanh = deco_unary_ufunc_from_impl(_ufunc_impl.arctanh) +ceil = deco_unary_ufunc_from_impl(_ufunc_impl.ceil) +conjugate = deco_unary_ufunc_from_impl(_ufunc_impl.conjugate) +cos = deco_unary_ufunc_from_impl(_ufunc_impl.cos) +cosh = deco_unary_ufunc_from_impl(_ufunc_impl.cosh) +deg2rad = deco_unary_ufunc_from_impl(_ufunc_impl.deg2rad) +degrees = deco_unary_ufunc_from_impl(_ufunc_impl.rad2deg) +exp = deco_unary_ufunc_from_impl(_ufunc_impl.exp) +exp2 = deco_unary_ufunc_from_impl(_ufunc_impl.exp2) +expm1 = deco_unary_ufunc_from_impl(_ufunc_impl.expm1) +fabs = deco_unary_ufunc_from_impl(_ufunc_impl.absolute) +floor = deco_unary_ufunc_from_impl(_ufunc_impl.floor) +isfinite = deco_unary_ufunc_from_impl(_ufunc_impl.isfinite) +isinf = deco_unary_ufunc_from_impl(_ufunc_impl.isinf) +isnan = deco_unary_ufunc_from_impl(_ufunc_impl.isnan) +log = deco_unary_ufunc_from_impl(_ufunc_impl.log) +log10 = deco_unary_ufunc_from_impl(_ufunc_impl.log10) +log1p = deco_unary_ufunc_from_impl(_ufunc_impl.log1p) +log2 = deco_unary_ufunc_from_impl(_ufunc_impl.log2) +logical_not = deco_unary_ufunc_from_impl(_ufunc_impl.logical_not) +negative = deco_unary_ufunc_from_impl(_ufunc_impl.negative) +rad2deg = deco_unary_ufunc_from_impl(_ufunc_impl.rad2deg) +radians = deco_unary_ufunc_from_impl(_ufunc_impl.deg2rad) +reciprocal = deco_unary_ufunc_from_impl(_ufunc_impl.reciprocal) +rint = deco_unary_ufunc_from_impl(_ufunc_impl.rint) +sign = deco_unary_ufunc_from_impl(_ufunc_impl.sign) +signbit = deco_unary_ufunc_from_impl(_ufunc_impl.signbit) +sin = deco_unary_ufunc_from_impl(_ufunc_impl.sin) +sinh = deco_unary_ufunc_from_impl(_ufunc_impl.sinh) +sqrt = deco_unary_ufunc_from_impl(_ufunc_impl.sqrt) +square = deco_unary_ufunc_from_impl(_ufunc_impl.square) +tan = deco_unary_ufunc_from_impl(_ufunc_impl.tan) +tanh = deco_unary_ufunc_from_impl(_ufunc_impl.tanh) +trunc = deco_unary_ufunc_from_impl(_ufunc_impl.trunc) + +invert = deco_unary_ufunc_from_impl(_ufunc_impl.invert) + + +cbrt = deco_unary_ufunc_from_impl(_ufunc_impl.cbrt) +positive = deco_unary_ufunc_from_impl(_ufunc_impl.positive) + +# numpy has these aliases while torch does not +abs = absolute +conj = conjugate +bitwise_not = invert -def sin(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or not where: - raise NotImplementedError - - # XXX: dtype=... 
parameter is silently ignored - - x_array = asarray(x) - - arrays = (x_array,) - x_tensor, = _helpers.cast_and_broadcast(arrays, out, casting) - - result = torch.sin(x_tensor) - - return _helpers.result_or_out(result, out) - -''' - # XXX: or this, which one is better for TorchInductor? - # result = {torch_stanza} - if out is not None: - torch.sin(x_tensor, out=out_tensor) - return out - else: - result = torch.sin(x_tensor) - return asarray(result) -''' - - -################################# - - -@asarray_replacer() -def absolute(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.absolute(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def absolute(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.absolute(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def arccos(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arccos(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def arccosh(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arccosh(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def arcsin(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arcsin(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def arcsinh(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arcsinh(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def arctan(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arctan(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - 
-@asarray_replacer() -def arctanh(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.arctanh(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def cbrt(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.pow(x, 1/3, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def ceil(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.ceil(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def conjugate(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.conj_physical(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def conjugate(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.conj_physical(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def cos(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.cos(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def cosh(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.cosh(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def deg2rad(x, /, out=None, *, where=True, casting='same_kind', order='K', - dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or casting != 'same_kind' or not where: - raise NotImplementedError - if out is not None: - # XXX dtypes, casting - out = out.to(dtype) - result = torch.deg2rad(x, out=out) - if dtype is not None: - result = result.to(dtype) - - return result - - - -@asarray_replacer() -def degrees(x, /, out=None, *, where=True, casting='same_kind', 
order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.rad2deg(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def exp(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.exp(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def exp2(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.exp2(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def expm1(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.expm1(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def fabs(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.absolute(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def floor(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.floor(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def isfinite(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.isfinite(x)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    if out is not None:
-        out[...] = result
-
-    return result
-
-
-
-@asarray_replacer()
-def isinf(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.isinf(x)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    if out is not None:
-        out[...] = result
-
-    return result
-
-
-
-@asarray_replacer()
-def isnan(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.isnan(x)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    if out is not None:
-        out[...] = result
-
-    return result
-
-
-
-@asarray_replacer()
-def log(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.log(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def log10(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.log10(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def log1p(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.log1p(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def log2(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.log2(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def logical_not(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.logical_not(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def negative(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.negative(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def positive(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = +x
-    if dtype is not None:
-        result = result.to(dtype)
-
-    if out is not None:
-        out[...] = result
-
-    return result
-
-
-
-@asarray_replacer()
-def rad2deg(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.rad2deg(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def radians(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.deg2rad(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def reciprocal(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.reciprocal(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def rint(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.round(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def sign(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.sign(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def signbit(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.signbit(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-
-
-
-
-@asarray_replacer()
-def sinh(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.sinh(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def sqrt(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.sqrt(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def square(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.square(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def tan(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.tan(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def tanh(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.tanh(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-@asarray_replacer()
-def trunc(x, /, out=None, *, where=True, casting='same_kind', order='K',
-        dtype=None, subok=False, **kwds):
-    _util.subok_not_ok(subok=subok)
-    if order != 'K' or casting != 'same_kind' or not where:
-        raise NotImplementedError
-    if out is not None:
-        # XXX dtypes, casting
-        out = out.to(dtype)
-    result = torch.trunc(x, out=out)
-    if dtype is not None:
-        result = result.to(dtype)
-
-    return result
-
-
-
-__all__ = ['absolute', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'cbrt', 'ceil', 'conjugate', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc']
diff --git a/torch_np/_util.py b/torch_np/_util.py
index d1e22ec1..e96e491a 100644
--- a/torch_np/_util.py
+++ b/torch_np/_util.py
@@ -23,6 +23,10 @@ class AxisError(ValueError, IndexError):
     pass
 
 
+class UFuncTypeError(TypeError, RuntimeError):
+    pass
+
+
 # a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
 def normalize_axis_index(ax, ndim, argname=None):
     if ax < -ndim or ax >= ndim:
diff --git a/torch_np/testing/testing.py b/torch_np/testing/testing.py
index ebaf97b9..fbe0d34b 100644
--- a/torch_np/testing/testing.py
+++ b/torch_np/testing/testing.py
@@ -1,19 +1,21 @@
 import torch
 
-from .._ndarray import asarray_replacer, asarray
+from .._ndarray import asarray
 
 import torch_np as np
 
-@asarray_replacer("two")
 def assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True,
                     err_msg='', verbose=True, check_dtype=True):
+    actual = asarray(actual).get()
+    desired = asarray(desired).get()
     result = torch.testing.assert_close(actual, desired, atol=atol, rtol=rtol,
                                         check_dtype=check_dtype)
-    return True
+    return result
 
-@asarray_replacer("two")
 def assert_equal(actual, desired):
     """Check `actual == desired`, broadcast if needed """
+    actual = np.asarray(actual)
+    desired = np.asarray(desired)
     eq = np.all(actual == desired)
     if not eq:
         raise AssertionError('not equal')
diff --git a/torch_np/tests/numpy_tests/core/test_scalarmath.py b/torch_np/tests/numpy_tests/core/test_scalarmath.py
index f02b48ef..4a9a327a 100644
--- a/torch_np/tests/numpy_tests/core/test_scalarmath.py
+++ b/torch_np/tests/numpy_tests/core/test_scalarmath.py
@@ -619,7 +619,8 @@ def __array__(self):
 class TestNegative:
     def test_exceptions(self):
         a = np.ones((), dtype=np.bool_)[()]
-        assert_raises(TypeError, operator.neg, a)
+        # XXX: TypeError from numpy, RuntimeError from torch
+        assert_raises((TypeError, RuntimeError), operator.neg, a)
 
     def test_result(self):
         types = np.typecodes['AllInteger'] + np.typecodes['AllFloat']
@@ -637,8 +638,8 @@ def test_result(self):
 class TestSubtract:
     def test_exceptions(self):
         a = np.ones((), dtype=np.bool_)[()]
-        with assert_raises(TypeError):
-            operator.sub(a, a)
+        with assert_raises((TypeError, RuntimeError)):  # XXX: TypeError from numpy
+            operator.sub(a, a)                          # RuntimeError from torch
 
     def test_result(self):
         types = np.typecodes['AllInteger'] + np.typecodes['AllFloat']
diff --git a/torch_np/tests/test_basic.py b/torch_np/tests/test_basic.py
index a894a2d5..f9e49164 100644
--- a/torch_np/tests/test_basic.py
+++ b/torch_np/tests/test_basic.py
@@ -27,6 +27,8 @@
 
 one_arg_funcs += [getattr(w, name) for name in _unary_ufuncs.__all__]
 
+one_arg_funcs = one_arg_funcs[:-1]  # FIXME: remove np.invert
+
 
 @pytest.mark.parametrize('func', one_arg_funcs)
diff --git a/torch_np/tests/test_binary_ufuncs.py b/torch_np/tests/test_binary_ufuncs.py
index a1a62980..e29a864e 100644
--- a/torch_np/tests/test_binary_ufuncs.py
+++ b/torch_np/tests/test_binary_ufuncs.py
@@ -20,20 +20,20 @@ def test_arctan2():
 
 
 def test_bitwise_and():
-    assert_allclose(np.bitwise_and(0.5, 0.6),
-                    bitwise_and(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.bitwise_and(5, 6),
+                    bitwise_and(5, 6), atol=1e-7, check_dtype=False)
 
 
 def test_bitwise_or():
-    assert_allclose(np.bitwise_or(0.5, 0.6),
-                    bitwise_or(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.bitwise_or(5, 6),
+                    bitwise_or(5, 6), atol=1e-7, check_dtype=False)
 
 
 def test_bitwise_xor():
-    assert_allclose(np.bitwise_xor(0.5, 0.6),
-                    bitwise_xor(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.bitwise_xor(5, 6),
+                    bitwise_xor(5, 6), atol=1e-7, check_dtype=False)
 
 
@@ -86,8 +86,8 @@ def test_fmod():
 
 
 def test_gcd():
-    assert_allclose(np.gcd(0.5, 0.6),
-                    gcd(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.gcd(5, 6),
+                    gcd(5, 6), atol=1e-7, check_dtype=False)
 
 
@@ -116,20 +116,20 @@ def test_hypot():
 
 
 def test_lcm():
-    assert_allclose(np.lcm(0.5, 0.6),
-                    lcm(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.lcm(5, 6),
+                    lcm(5, 6), atol=1e-7, check_dtype=False)
 
 
 def test_ldexp():
-    assert_allclose(np.ldexp(0.5, 0.6),
-                    ldexp(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.ldexp(0.5, 6),
+                    ldexp(0.5, 6), atol=1e-7, check_dtype=False)
 
 
 def test_left_shift():
-    assert_allclose(np.left_shift(0.5, 0.6),
-                    left_shift(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.left_shift(5, 6),
+                    left_shift(5, 6), atol=1e-7, check_dtype=False)
 
 
@@ -176,8 +176,8 @@ def test_logical_xor():
 
 
 def test_matmul():
-    assert_allclose(np.matmul(0.5, 0.6),
-                    matmul(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.matmul([0.5], [0.6]),
+                    matmul([0.5], [0.6]), atol=1e-7, check_dtype=False)
 
 
@@ -230,8 +230,8 @@ def test_remainder():
 
 
 def test_right_shift():
-    assert_allclose(np.right_shift(0.5, 0.6),
-                    right_shift(0.5, 0.6), atol=1e-7, check_dtype=False)
+    assert_allclose(np.right_shift(5, 6),
+                    right_shift(5, 6), atol=1e-7, check_dtype=False)
 
 
diff --git a/torch_np/tests/test_ndarray_methods.py b/torch_np/tests/test_ndarray_methods.py
index 7a880d8d..2f039844 100644
--- a/torch_np/tests/test_ndarray_methods.py
+++ b/torch_np/tests/test_ndarray_methods.py
@@ -549,4 +549,3 @@ def test_basic(self):
 
         arr = np.asarray(a)
         assert_equal(np.amin(arr), arr.min())
-
diff --git a/torch_np/tests/test_ufuncs_basic.py b/torch_np/tests/test_ufuncs_basic.py
index e4bd5f6f..4ccb3b11 100644
--- a/torch_np/tests/test_ufuncs_basic.py
+++ b/torch_np/tests/test_ufuncs_basic.py
@@ -7,12 +7,28 @@
 by
 >>> import torch_np as np
 """
+import operator
+
 import pytest
 from pytest import raises as assert_raises
 
 import torch_np as np
 from torch_np.testing import assert_equal
 
+#import numpy as np
+#from numpy.testing import assert_equal
+
+try:
+    import numpy as _np
+    HAVE_NUMPY = True
+
+    def _numpy_result(op, a, b):
+        """what would numpy do"""
+        return op(a._tensor.numpy(), b._tensor.numpy())
+
+except ImportError:
+    HAVE_NUMPY = False
+
 parametrize_unary_ufuncs = pytest.mark.parametrize('ufunc', [np.sin])
 
 parametrize_casting = pytest.mark.parametrize("casting",
@@ -84,7 +100,56 @@ def test_x_and_out_broadcast(self, ufunc):
 
 
-parametrize_binary_ufuncs = pytest.mark.parametrize('ufunc', [np.add]) #, np.logaddexp, np.hypot])
+ufunc_op_iop_numeric = [
+    (np.add, operator.__add__, operator.__iadd__),
+    (np.subtract, operator.__sub__, operator.__isub__),
+    (np.multiply, operator.__mul__, operator.__imul__),
+    (np.divide, operator.__truediv__, operator.__itruediv__),
+    (np.floor_divide, operator.__floordiv__, operator.__ifloordiv__),
+    (np.float_power, operator.__pow__, operator.__ipow__),
+    ## (np.remainder, operator.__mod__, operator.__imod__),  # does not handle complex
+
+# remainder vs fmod?
+# pow vs power vs float_power
+]
+
+ufuncs_with_dunders = [ufunc for ufunc, _, _ in ufunc_op_iop_numeric]
+numeric_binary_ufuncs = [np.float_power, np.power,]
+
+# these are not implemented for complex inputs
+no_complex = [np.floor_divide, np.hypot, np.arctan2, np.copysign, np.fmax,
+              np.fmin, np.fmod, np.heaviside, np.logaddexp, np.logaddexp2,
+              np.maximum, np.minimum,
+]
+
+parametrize_binary_ufuncs = pytest.mark.parametrize(
+    'ufunc', ufuncs_with_dunders + numeric_binary_ufuncs + no_complex)
+
+
+
+# TODO: these snowflakes need special handling
+"""
+    'bitwise_and',
+    'bitwise_or',
+    'bitwise_xor',
+    'equal',
+    'lcm',
+    'ldexp',
+    'left_shift',
+    'less',
+    'less_equal',
+    'gcd',
+    'greater',
+    'greater_equal',
+    'logical_and',
+    'logical_or',
+    'logical_xor',
+    'matmul',
+    'not_equal',
+"""
+
+
 class TestBinaryUfuncs:
@@ -98,6 +163,11 @@ def test_scalar(self, ufunc):
         x, y = xy[0][0], xy[1][0]
         float(ufunc(x, y))
 
+    @parametrize_binary_ufuncs
+    def test_vector_vs_scalar(self, ufunc):
+        x, y = self.get_xy(ufunc)
+        assert_equal(ufunc(x, y), [ufunc(a, b) for a, b in zip(x, y)])
+
     @parametrize_casting
     @parametrize_binary_ufuncs
     @pytest.mark.parametrize('out_dtype', ['float64', 'complex128', 'float32'])
@@ -105,6 +175,9 @@ def test_xy_and_out_casting(self, ufunc, casting, out_dtype):
         x, y = self.get_xy(ufunc)
         out = np.empty_like(x, dtype=out_dtype)
 
+        if ufunc in no_complex and np.issubdtype(out_dtype, np.complexfloating):
+            pytest.skip(f'{ufunc} does not accept complex.')
+
         can_cast_x = np.can_cast(x, out_dtype, casting=casting)
         can_cast_y = np.can_cast(y, out_dtype, casting=casting)
@@ -131,3 +204,164 @@ def test_xy_and_out_broadcast(self, ufunc):
         assert_equal(res_out, res_bcast)
         assert res_out is out
 
+
+dtypes_numeric = [np.int32, np.float32, np.float64, np.complex128]
+
+
+class TestNdarrayDunderVsUfunc:
+    """Test ndarray dunders which delegate to ufuncs, vs ufuncs."""
+
+    @pytest.mark.parametrize("ufunc, op, iop", ufunc_op_iop_numeric)
+    def test_basic(self, ufunc, op, iop):
+        """basic op/rop/iop, no dtypes, no broadcasting"""
+
+        # __add__
+        a = np.array([1, 2, 3])
+        assert_equal(op(a, 1), ufunc(a, 1))
+        assert_equal(op(a, a.tolist()), ufunc(a, a.tolist()))
+        assert_equal(op(a, a), ufunc(a, a))
+
+        # __radd__
+        a = np.array([1, 2, 3])
+        assert_equal(op(1, a), ufunc(a, 1))
+        assert_equal(op(a.tolist(), a), ufunc(a, a.tolist()))
+
+        # __iadd__
+        a0 = np.array([2, 4, 6])
+        a = a0.copy()
+
+        iop(a, 2)     # modifies a in-place
+        assert_equal(a, op(a0, 2))
+
+        a0 = np.array([2, 4, 6])
+        a = a0.copy()
+        iop(a, a)
+        assert_equal(a, op(a0, a0))
+
+    @pytest.mark.parametrize("ufunc, op, iop", ufunc_op_iop_numeric)
+    @pytest.mark.parametrize("other_dtype", dtypes_numeric)
+    def test_other_scalar(self, ufunc, op, iop, other_dtype):
+        """Test op/iop/rop when the other argument is a scalar of a different dtype."""
+        a = np.array([1, 2, 3])
+        b = other_dtype(3)
+
+        if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
+            pytest.skip(f'{ufunc} does not accept complex.')
+
+        # __op__
+        result = op(a, b)
+        assert_equal(result, ufunc(a, b))
+
+        if result.dtype != np.result_type(a, b):
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result.dtype == np.result_type(a, b)
+
+        # __rop__
+        result = op(b, a)
+        assert_equal(result, ufunc(b, a))
+        if result.dtype != np.result_type(a, b):
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result.dtype == np.result_type(a, b)
+
+        # __iop__ : casts the result to self.dtype, raises if cannot
+        can_cast = np.can_cast(np.result_type(a.dtype, other_dtype),
+                               a.dtype,
+                               casting="same_kind")
+        if can_cast:
+            a0 = a.copy()
+            result = iop(a, b)
+            assert_equal(result, ufunc(a0, b))
+            if result.dtype != np.result_type(a, b):
+                pytest.xfail(reason="prob need weak type promotion (scalars)")
+            assert result.dtype == np.result_type(a0, b)
+
+        else:
+            with assert_raises((TypeError, RuntimeError)):  # XXX np.UFuncTypeError
+                iop(a, b)
+
+
+    @pytest.mark.parametrize("ufunc, op, iop", ufunc_op_iop_numeric)
+    @pytest.mark.parametrize("other_dtype", dtypes_numeric)
+    def test_other_array(self, ufunc, op, iop, other_dtype):
+        """Test op/iop/rop when the other argument is an array of a different dtype."""
+        a = np.array([1, 2, 3])
+        b = np.array([5, 6, 7], dtype=other_dtype)
+
+        if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
+            pytest.skip(f'{ufunc} does not accept complex.')
+
+        # __op__
+        result = op(a, b)
+        assert_equal(result, ufunc(a, b))
+        if result.dtype != np.result_type(a, b):
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result.dtype == np.result_type(a, b)
+
+        # __rop__(other array)
+        result = op(b, a)
+        assert_equal(result, ufunc(b, a))
+        if result.dtype != np.result_type(a, b):
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result.dtype == np.result_type(a, b)
+
+        # __iop__
+        can_cast = np.can_cast(np.result_type(a.dtype, other_dtype),
+                               a.dtype,
+                               casting="same_kind")
+        if can_cast:
+            a0 = a.copy()
+            result = iop(a, b)
+            assert_equal(result, ufunc(a0, b))
+            if result.dtype != np.result_type(a, b):
+                pytest.xfail(reason="prob need weak type promotion (scalars)")
+            assert result.dtype == np.result_type(a0, b)
+        else:
+            with assert_raises((TypeError, RuntimeError)):  # XXX np.UFuncTypeError
+                iop(a, b)
+
+
+    @pytest.mark.parametrize("ufunc, op, iop", ufunc_op_iop_numeric)
+    def test_other_array_bcast(self, ufunc, op, iop):
+        """Test op/rop/iop with broadcasting """
+        # __op__
+        a = np.array([1, 2, 3])
+        result_op = op(a, a[:, None])
+        result_ufunc = ufunc(a, a[:, None])
+        assert result_op.shape == result_ufunc.shape
+        assert_equal(result_op, result_ufunc)
+
+        if result_op.dtype != result_ufunc.dtype:
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result_op.dtype == result_ufunc.dtype
+
+        # __rop__
+        a = np.array([1, 2, 3])
+        result_op = op(a[:, None], a)
+        result_ufunc = ufunc(a[:, None], a)
+        assert result_op.shape == result_ufunc.shape
+        assert_equal(result_op, result_ufunc)
+
+        if result_op.dtype != result_ufunc.dtype:
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result_op.dtype == result_ufunc.dtype
+
+
+        # __iop__ : in-place ops (`self += other` etc) do not broadcast self
+        b = a[:, None].copy()
+        with assert_raises((ValueError, RuntimeError)):  # XXX ValueError in numpy
+            iop(a, b)
+
+        # however, `self += other` broadcasts other
+        aa = np.broadcast_to(a, (3, 3)).copy()
+        aa0 = aa.copy()
+
+        result = iop(aa, a)
+        result_ufunc = ufunc(aa0, a)
+
+        assert result.shape == result_ufunc.shape
+        assert_equal(result, result_ufunc)
+
+        if result.dtype != result_ufunc.dtype:
+            pytest.xfail(reason="prob need weak type promotion (scalars)")
+        assert result.dtype == result_ufunc.dtype
+
diff --git a/torch_np/tests/test_unary_ufuncs.py b/torch_np/tests/test_unary_ufuncs.py
index b9033885..df5a5117 100644
--- a/torch_np/tests/test_unary_ufuncs.py
+++ b/torch_np/tests/test_unary_ufuncs.py
@@ -27,8 +27,8 @@ def test_arccos():
 
 
 def test_arccosh():
-    assert_allclose(np.arccosh(0.5),
-                    arccosh(0.5), atol=1e-14, check_dtype=False)
+    assert_allclose(np.arccosh(1.5),
+                    arccosh(1.5), atol=1e-14, check_dtype=False)