diff --git a/autogen/gen_dtypes.py b/autogen/gen_dtypes.py index e5f1fbac..b951fed0 100644 --- a/autogen/gen_dtypes.py +++ b/autogen/gen_dtypes.py @@ -22,11 +22,18 @@ def __init__(self, name): 'int64', 'bool'] + +tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype for dt in dt_names} + + + templ = """\ {name} = dtype("{name}") """ + + ############### Output the dtypes ############# src_lines = [templ.format(name=name) for name in dt_names] @@ -51,8 +58,8 @@ def generate_can_cast(casting): for dtyp2 in dt_names: can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2), casting=casting) - dct_dtyp1[dtyp2] = can_cast - dct[dtyp1] = dct_dtyp1 + dct_dtyp1[tmap[dtyp2]] = can_cast + dct[tmap[dtyp1]] = dct_dtyp1 return dct @@ -63,8 +70,8 @@ def generate_result_type(): dct_dtyp1 = {} for dtyp2 in dt_names: result_type = np.result_type(np.dtype(dtyp1), np.dtype(dtyp2)) - dct_dtyp1[dtyp2] = result_type.name - dct[dtyp1] = dct_dtyp1 + dct_dtyp1[tmap[dtyp2]] = tmap[result_type.name] + dct[tmap[dtyp1]] = dct_dtyp1 return dct diff --git a/torch_np/__init__.py b/torch_np/__init__.py index 5b2eee7f..a6416a53 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -1,20 +1,15 @@ from ._dtypes import * -from ._scalar_types import * +from ._detail._scalar_types import * from ._wrapper import * #from . import testing from ._unary_ufuncs import * from ._binary_ufuncs import * from ._ndarray import can_cast, result_type, newaxis -from ._util import AxisError, UFuncTypeError +from ._detail._util import AxisError, UFuncTypeError from ._getlimits import iinfo, finfo from ._getlimits import errstate inf = float('inf') nan = float('nan') - -#### HACK HACK HACK #### -import torch -torch.set_default_dtype(torch.float64) -del torch diff --git a/torch_np/_binary_ufuncs.py b/torch_np/_binary_ufuncs.py index 3f1ae960..e6acbd7c 100644 --- a/torch_np/_binary_ufuncs.py +++ b/torch_np/_binary_ufuncs.py @@ -1,79 +1,53 @@ -import functools -import torch +from ._decorators import deco_binary_ufunc_from_impl +from ._detail import _ufunc_impl -from . import _util - -from ._ndarray import asarray -from . import _dtypes -from . import _helpers - -from . import _ufunc_impl - -# -# Functions in _ufunc_impl receive arrays, implement common tasks with ufunc args -# and delegate heavy lifting to pytorch equivalents. # # Functions in this file implement binary ufuncs: wrap two first arguments in # asarray and delegate to functions from _ufunc_impl. # -# One other user of _ufunc_impl functions in ndarray, where its __add__ method -# calls _ufunc_impl.add and so on. Note that ndarray dunders already know -# that its first arg is an array, so they only convert the second argument. +# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks +# with ufunc args, and delegate heavy lifting to pytorch equivalents. # -# XXX: While it sounds tempting to merge _binary_ufuncs.py and _ufunc_impl.py -# files, doing it would currently create import cycles. -# - -# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl, -# _ndarray/asarray_replacer, and _wrapper/concatenate et al -def deco_ufunc_from_impl(impl_func): - @functools.wraps(impl_func) - def wrapped(x1, x2, *args, **kwds): - x1_array = asarray(x1) - x2_array = asarray(x2) - return impl_func(x1_array, x2_array, *args, **kwds) - return wrapped - # the list is autogenerated, cf autogen/gen_ufunc_2.py -add = deco_ufunc_from_impl(_ufunc_impl.add) -arctan2 = deco_ufunc_from_impl(_ufunc_impl.arctan2) -bitwise_and = deco_ufunc_from_impl(_ufunc_impl.bitwise_and) -bitwise_or = deco_ufunc_from_impl(_ufunc_impl.bitwise_or) -bitwise_xor = deco_ufunc_from_impl(_ufunc_impl.bitwise_xor) -copysign = deco_ufunc_from_impl(_ufunc_impl.copysign) -divide = deco_ufunc_from_impl(_ufunc_impl.divide) -equal = deco_ufunc_from_impl(_ufunc_impl.equal) -float_power = deco_ufunc_from_impl(_ufunc_impl.float_power) -floor_divide = deco_ufunc_from_impl(_ufunc_impl.floor_divide) -fmax = deco_ufunc_from_impl(_ufunc_impl.fmax) -fmin = deco_ufunc_from_impl(_ufunc_impl.fmin) -fmod = deco_ufunc_from_impl(_ufunc_impl.fmod) -gcd = deco_ufunc_from_impl(_ufunc_impl.gcd) -greater = deco_ufunc_from_impl(_ufunc_impl.greater) -greater_equal = deco_ufunc_from_impl(_ufunc_impl.greater_equal) -heaviside = deco_ufunc_from_impl(_ufunc_impl.heaviside) -hypot = deco_ufunc_from_impl(_ufunc_impl.hypot) -lcm = deco_ufunc_from_impl(_ufunc_impl.lcm) -ldexp = deco_ufunc_from_impl(_ufunc_impl.ldexp) -left_shift = deco_ufunc_from_impl(_ufunc_impl.left_shift) -less = deco_ufunc_from_impl(_ufunc_impl.less) -less_equal = deco_ufunc_from_impl(_ufunc_impl.less_equal) -logaddexp = deco_ufunc_from_impl(_ufunc_impl.logaddexp) -logaddexp2 = deco_ufunc_from_impl(_ufunc_impl.logaddexp2) -logical_and = deco_ufunc_from_impl(_ufunc_impl.logical_and) -logical_or = deco_ufunc_from_impl(_ufunc_impl.logical_or) -logical_xor = deco_ufunc_from_impl(_ufunc_impl.logical_xor) -matmul = deco_ufunc_from_impl(_ufunc_impl.matmul) -maximum = deco_ufunc_from_impl(_ufunc_impl.maximum) -minimum = deco_ufunc_from_impl(_ufunc_impl.minimum) -remainder = deco_ufunc_from_impl(_ufunc_impl.remainder) -multiply = deco_ufunc_from_impl(_ufunc_impl.multiply) -nextafter = deco_ufunc_from_impl(_ufunc_impl.nextafter) -not_equal = deco_ufunc_from_impl(_ufunc_impl.not_equal) -power = deco_ufunc_from_impl(_ufunc_impl.power) -remainder = deco_ufunc_from_impl(_ufunc_impl.remainder) -right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift) -subtract = deco_ufunc_from_impl(_ufunc_impl.subtract) -divide = deco_ufunc_from_impl(_ufunc_impl.divide) +add = deco_binary_ufunc_from_impl(_ufunc_impl.add) +arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2) +bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and) +bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or) +bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor) +copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign) +divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide) +equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal) +float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power) +floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide) +fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax) +fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin) +fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod) +gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd) +greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater) +greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal) +heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside) +hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot) +lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm) +ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp) +left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift) +less = deco_binary_ufunc_from_impl(_ufunc_impl.less) +less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal) +logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp) +logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2) +logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and) +logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or) +logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor) +matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul) +maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum) +minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum) +remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder) +multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply) +nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter) +not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal) +power = deco_binary_ufunc_from_impl(_ufunc_impl.power) +remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder) +right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift) +subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract) +divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide) diff --git a/torch_np/_decorators.py b/torch_np/_decorators.py new file mode 100644 index 00000000..10730453 --- /dev/null +++ b/torch_np/_decorators.py @@ -0,0 +1,107 @@ +import functools +import operator + +import torch + +from . import _dtypes +from . import _helpers +from ._detail import _util + +NoValue = None + + +def dtype_to_torch(func): + @functools.wraps(func) + def wrapped(*args, dtype=None, **kwds): + torch_dtype = None + if dtype is not None: + dtype = _dtypes.dtype(dtype) + torch_dtype = dtype._scalar_type.torch_dtype + return func(*args, dtype=torch_dtype, **kwds) + return wrapped + + +def emulate_out_arg(func): + """Simulate the out=... handling: move the result tensor to the out array. + + With this decorator, the inner function just does not see the out array. + """ + @functools.wraps(func) + def wrapped(*args, out=None, **kwds): + result_tensor = func(*args, **kwds) + return _helpers.result_or_out(result_tensor, out) + + return wrapped + + +def out_shape_dtype(func): + """Handle out=... kwarg for ufuncs. + + With ufuncs, `out` array can typcast and broadcast ufunc arguments, hence + extract the shape and dtype of the tensor which backs the `out` array + and pass these through. + """ + @functools.wraps(func) + def wrapped(*args, out=None, **kwds): + if out is not None: + kwds.update({'out_shape_dtype': (out.get().dtype, out.get().shape)}) + result_tensor = func(*args, **kwds) + return _helpers.result_or_out(result_tensor, out) + + return wrapped + + +def deco_unary_ufunc_from_impl(impl_func): + @functools.wraps(impl_func) + @dtype_to_torch + @out_shape_dtype + def wrapped(x1, *args, **kwds): + from ._ndarray import asarray + x1_tensor = asarray(x1).get() + result = impl_func((x1_tensor,), *args, **kwds) + return result + return wrapped + + +# TODO: deduplicate with _ndarray/asarray_replacer, +# and _wrapper/concatenate et al +def deco_binary_ufunc_from_impl(impl_func): + @functools.wraps(impl_func) + @dtype_to_torch + @out_shape_dtype + def wrapped(x1, x2, *args, **kwds): + from ._ndarray import asarray + x1_tensor = asarray(x1).get() + x2_tensor = asarray(x2).get() + return impl_func((x1_tensor, x2_tensor), *args, **kwds) + return wrapped + + + +def axis_keepdims_wrapper(func): + """`func` accepts an array-like as a 1st arg, returns a tensor. + + This decorator implements the generic handling of axis, out and keepdims + arguments for reduction functions. + + Note that we peel off `out=...` and `keepdims=...` args (torch functions never + see them). The `axis` argument we normalize and pass through to pytorch functions. + + """ + # XXX: move this out of _ndarray.py (circular imports) + # + # TODO: 1. get rid of _helpers.result_or_out + # 2. sort out function signatures: how they flow through all decorators etc + @functools.wraps(func) + def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds): + from ._ndarray import ndarray, asarray + tensor = asarray(a).get() + + # standardize the axis argument + if isinstance(axis, ndarray): + axis = operator.index(axis) + + result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds) + return result + + return wrapped diff --git a/torch_np/_detail/__init__.py b/torch_np/_detail/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_np/_detail/_casting_dicts.py b/torch_np/_detail/_casting_dicts.py new file mode 100644 index 00000000..3e82b9b6 --- /dev/null +++ b/torch_np/_detail/_casting_dicts.py @@ -0,0 +1,10 @@ +import torch + +# These two dicts are autogenerated with autogen/gen_dtypes.py, +# using numpy version 1.23.5. + +_can_cast_dict = {'no': {torch.float16: {torch.float16: True, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float32: {torch.float16: False, torch.float32: True, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float64: {torch.float16: False, torch.float32: False, torch.float64: True, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: True, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex128: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.uint8: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: True, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int8: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: True, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int16: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: True, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int32: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: True, torch.int64: False, torch.bool: False}, torch.int64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: True, torch.bool: False}, torch.bool: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: True}}, 'equiv': {torch.float16: {torch.float16: True, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float32: {torch.float16: False, torch.float32: True, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float64: {torch.float16: False, torch.float32: False, torch.float64: True, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: True, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex128: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.uint8: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: True, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int8: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: True, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int16: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: True, torch.int32: False, torch.int64: False, torch.bool: False}, torch.int32: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: True, torch.int64: False, torch.bool: False}, torch.int64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: True, torch.bool: False}, torch.bool: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: False, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: True}}, 'safe': {torch.float16: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float32: {torch.float16: False, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float64: {torch.float16: False, torch.float32: False, torch.float64: True, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex128: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.uint8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: False, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int16: {torch.float16: False, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int32: {torch.float16: False, torch.float32: False, torch.float64: True, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int64: {torch.float16: False, torch.float32: False, torch.float64: True, torch.complex64: False, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: True, torch.bool: False}, torch.bool: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}}, 'same_kind': {torch.float16: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float32: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.float64: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex64: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.complex128: {torch.float16: False, torch.float32: False, torch.float64: False, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: False, torch.int16: False, torch.int32: False, torch.int64: False, torch.bool: False}, torch.uint8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int16: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int32: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.int64: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: False, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False}, torch.bool: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}}, 'unsafe': {torch.float16: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.float32: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.float64: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.complex64: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.complex128: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.uint8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.int8: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.int16: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.int32: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.int64: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}, torch.bool: {torch.float16: True, torch.float32: True, torch.float64: True, torch.complex64: True, torch.complex128: True, torch.uint8: True, torch.int8: True, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: True}}} + + +_result_type_dict = {torch.float16: {torch.float16: torch.float16, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float16, torch.int8: torch.float16, torch.int16: torch.float32, torch.int32: torch.float64, torch.int64: torch.float64, torch.bool: torch.float16}, torch.float32: {torch.float16: torch.float32, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float32, torch.int8: torch.float32, torch.int16: torch.float32, torch.int32: torch.float64, torch.int64: torch.float64, torch.bool: torch.float32}, torch.float64: {torch.float16: torch.float64, torch.float32: torch.float64, torch.float64: torch.float64, torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.float64, torch.int8: torch.float64, torch.int16: torch.float64, torch.int32: torch.float64, torch.int64: torch.float64, torch.bool: torch.float64}, torch.complex64: {torch.float16: torch.complex64, torch.float32: torch.complex64, torch.float64: torch.complex128, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.complex64, torch.int8: torch.complex64, torch.int16: torch.complex64, torch.int32: torch.complex128, torch.int64: torch.complex128, torch.bool: torch.complex64}, torch.complex128: {torch.float16: torch.complex128, torch.float32: torch.complex128, torch.float64: torch.complex128, torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.complex128, torch.int8: torch.complex128, torch.int16: torch.complex128, torch.int32: torch.complex128, torch.int64: torch.complex128, torch.bool: torch.complex128}, torch.uint8: {torch.float16: torch.float16, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.uint8}, torch.int8: {torch.float16: torch.float16, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.int8}, torch.int16: {torch.float16: torch.float32, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.int16}, torch.int32: {torch.float16: torch.float64, torch.float32: torch.float64, torch.float64: torch.float64, torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int32, torch.int8: torch.int32, torch.int16: torch.int32, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.int32}, torch.int64: {torch.float16: torch.float64, torch.float32: torch.float64, torch.float64: torch.float64, torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64, torch.int64: torch.int64, torch.bool: torch.int64}, torch.bool: {torch.float16: torch.float16, torch.float32: torch.float32, torch.float64: torch.float64, torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.bool}} + diff --git a/torch_np/_detail/_reductions.py b/torch_np/_detail/_reductions.py new file mode 100644 index 00000000..d95ca4ff --- /dev/null +++ b/torch_np/_detail/_reductions.py @@ -0,0 +1,163 @@ +""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc +in the 'public' layer. + +Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc +""" + +import torch +from . import _util +from . import _scalar_types + +NoValue = None + + +def _atleast_float(dtype, other_dtype): + """Return a dtype that is real or complex floating-point. + + For inputs that are boolean or integer dtypes, this returns the default + float dtype; inputs that are complex get converted to the default complex + dtype; real floating-point dtypes (`float*`) get passed through unchanged + """ + if dtype is None: + dtype = other_dtype + if not (dtype.is_floating_point or dtype.is_complex): + sctype = _scalar_types.default_float_type + dtype = sctype.torch_dtype + return dtype + + +def count_nonzero(a, axis=None): + # XXX: this all should probably be generalized to a sum(a != 0, dtype=bool) + try: + return a.count_nonzero(axis) + except RuntimeError: + raise ValueError + return tensor + + +def argmax(tensor, axis=None): + axis = _util.allow_only_single_axis(axis) + tensor = torch.argmax(tensor, axis) + return tensor + +def argmin(tensor, axis=None): + axis = _util.allow_only_single_axis(axis) + tensor = torch.argmin(tensor, axis) + return tensor + + +def any(tensor, axis=None, *, where=NoValue): + if where is not NoValue: + raise NotImplementedError + + axis = _util.allow_only_single_axis(axis) + + if axis is None: + result = tensor.any() + else: + result = tensor.any(axis) + return result + + +def all(tensor, axis=None, *, where=NoValue): + if where is not NoValue: + raise NotImplementedError + + axis = _util.allow_only_single_axis(axis) + + if axis is None: + result = tensor.all() + else: + result = tensor.all(axis) + return result + + +def max(tensor, axis=None, initial=NoValue, where=NoValue): + if initial is not NoValue or where is not NoValue: + raise NotImplementedError + + result = tensor.amax(axis) + return result + + +def min(tensor, axis=None, initial=NoValue, where=NoValue): + if initial is not NoValue or where is not NoValue: + raise NotImplementedError + + result = tensor.amin(axis) + return result + + +def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): + if initial is not NoValue or where is not NoValue: + raise NotImplementedError + + assert dtype is None or isinstance(dtype, torch.dtype) + + if dtype == torch.bool: + dtype = _scalar_types.default_int_type.dtype + + if axis is None: + result = tensor.sum(dtype=dtype) + else: + result = tensor.sum(dtype=dtype, dim=axis) + + return result + + +def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): + if initial is not NoValue or where is not NoValue: + raise NotImplementedError + + axis = _util.allow_only_single_axis(axis) + + if dtype == torch.bool: + dtype = _scalar_types.default_int_type.dtype + + if axis is None: + result = tensor.prod(dtype=dtype) + else: + result = tensor.prod(dtype=dtype, dim=axis) + + return result + + +def mean(tensor, axis=None, dtype=None, *, where=NoValue): + if where is not NoValue: + raise NotImplementedError + + dtype = _atleast_float(dtype, tensor.dtype) + + if axis is None: + result = tensor.mean(dtype=dtype) + else: + result = tensor.mean(dtype=dtype, dim=axis) + + return result + + +def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): + if where is not NoValue: + raise NotImplementedError + + dtype = _atleast_float(dtype, tensor.dtype) + + if dtype is not None: + tensor = tensor.to(dtype) + result = tensor.std(dim=axis, correction=ddof) + + return result + + +def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): + if where is not NoValue: + raise NotImplementedError + + dtype = _atleast_float(dtype, tensor.dtype) + + if dtype is not None: + tensor = tensor.to(dtype) + result = tensor.var(dim=axis, correction=ddof) + + return result + diff --git a/torch_np/_detail/_scalar_types.py b/torch_np/_detail/_scalar_types.py new file mode 100644 index 00000000..b487dda0 --- /dev/null +++ b/torch_np/_detail/_scalar_types.py @@ -0,0 +1,285 @@ +"""Replicate the NumPy scalar type hierarchy +""" +import builtins +import abc +import torch + +class generic(abc.ABC): + @property + @abc.abstractmethod + def name(self): + return self.__class__.__name__ + + def __new__(self, value): + # + # Yes, a call to np.float32(4) produces a zero-dim array. + # + from .. import _ndarray + + if isinstance(value, str) and value in ['inf', 'nan']: + value = {'inf': torch.inf, 'nan': torch.nan}[value] + + if isinstance(value, _ndarray.ndarray): + tensor = value.get() + else: + try: + tensor = torch.as_tensor(value, dtype=self.torch_dtype) + except RuntimeError as e: + if "Overflow" in str(e): + raise OverflowError(e.args) + raise e + # + # With numpy: + # >>> a = np.ones(3) + # >>> np.float64(a) is a # True + # >>> np.float64(a[0]) is a[0] # False + # + # A reasonable assumption is that the second case is more common, + # and here we follow the second approach and create a new object + # *for all inputs*. + # + return _ndarray.ndarray._from_tensor_and_base(tensor, None) + + +##### these are abstract types + +class number(generic): + pass + + +class integer(number): + pass + + +class inexact(number): + pass + + +class signedinteger(integer): + pass + + +class unsignedinteger(integer): + pass + + +class floating(inexact): + pass + + +class complexfloating(inexact): + pass + + +# ##### concrete types + +# signed integers + +class int8(signedinteger): + name = 'int8' + typecode = 'b' + torch_dtype = torch.int8 + + +class int16(signedinteger): + name = 'int16' + typecode = 'h' + torch_dtype = torch.int16 + + +class int32(signedinteger): + name = 'int32' + typecode = 'i' + torch_dtype = torch.int32 + + +class int64(signedinteger): + name = 'int64' + typecode = 'l' + torch_dtype = torch.int64 + + +# unsigned integers + +class uint8(unsignedinteger): + name = 'uint8' + typecode = 'B' + torch_dtype = torch.uint8 + + +# floating point + +class float16(floating): + name = 'float16' + typecode = 'e' + torch_dtype = torch.float16 + + +class float32(floating): + name = 'float32' + typecode = 'f' + torch_dtype = torch.float32 + +class float64(floating): + name = 'float64' + typecode = 'd' + torch_dtype = torch.float64 + + +class complex64(complexfloating): + name = 'complex64' + typecode = 'F' + torch_dtype = torch.complex64 + + +class complex128(complexfloating): + name = 'complex128' + typecode = 'D' + torch_dtype = torch.complex128 + + +class bool_(generic): + name = 'bool_' + typecode = '?' + torch_dtype = torch.bool + + +# name aliases : FIXME (OS, bitness) +intp = int64 +int_ = int64 +intc = int32 + +byte = int8 +short = int16 +longlong = int64 # XXX: is this correct? + +ubyte = uint8 + +half = float16 +single = float32 +double = float64 +float_ = float64 + +csingle = complex64 +cdouble = complex128 + + +# Replicate this NumPy-defined way of grouping scalar types, +# cf tests/core/test_scalar_methods.py +sctypes = { + 'int': [int8, int16, int32, int64], + 'uint': [uint8,], + 'float': [float16, float32, float64], + 'complex': [complex64, complex128], + 'others': [bool_], +} + + +_names = {st.name: st for cat in sctypes for st in sctypes[cat]} +_typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]} +_torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]} + +_aliases = { + 'u1' : uint8, + 'i1' : int8, + 'i2' : int16, + 'i4' : int32, + 'i8' : int64, + 'b' : int8, # XXX: srsly? + 'f2' : float16, + 'f4' : float32, + 'f8' : float64, + 'c8' : complex64, + 'c16': complex128, + # numpy-specific trailing underscore + 'bool_': bool_, +} + + +_python_types = { + int: int64, + float: float64, + complex: complex128, + builtins.bool: bool_, + # also allow stringified names of python types + int.__name__ : int64, + float.__name__ : float64, + complex.__name__: complex128, + builtins.bool.__name__ : bool_, +} + + +def sctype_from_string(s): + """Normalize a string value: a type 'name' or a typecode or a width alias. + """ + if s in _names: + return _names[s] + if s in _typecodes: + return _typecodes[s] + if s in _aliases: + return _aliases[s] + if s in _python_types: + return _python_types[s] + raise TypeError(f"data type '{s}' not understood") + + +def sctype_from_torch_dtype(torch_dtype): + return _torch_dtypes[torch_dtype] + + +#### default : mimic NumPy +default_scalar_type = float64 +default_int_type = int64 +default_float_type = float64 +default_complex_type = complex128 +########################## + + +def get_default_type_for(sctype): + """Default scalar type given sctype category.""" + if issubclass(sctype, integer): + result = default_int_type + elif issubclass(sctype, floating): + result = default_float_type + elif issubclass(sctype, complexfloating): + result = default_complex_type + elif issubclass(sctype, bool_): + result = bool_ + else: + raise RuntimeError("cannot be here with sctype= %s" % sctype) + return result + + +# XXX: is it ever used? cf _detail/reductions.py::_atleast_float(...) +def float_or_default(sctype, enforce_float=False): + """bool -> int; int -> float""" + if issubclass(sctype, bool_): + sctype = default_int_type + if enforce_float and issubclass(sctype, integer): + sctype = default_float_type + return sctype + + +from . import _casting_dicts as _cd + +def _can_cast_sctypes(from_sctype, to_sctype, casting): + return _can_cast_impl(from_sctype.torch_dtype, to_sctype.torch_dtype, casting) + + +def _can_cast_impl(from_torch_dtype, to_torch_dtype, casting): + return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype] + + + +__all__ = list(_names.keys()) +__all__ += ['intp', 'int_', 'intc', 'byte', 'short', 'longlong', + 'ubyte', 'half', 'single', 'double', 'csingle', 'cdouble', 'float_'] +__all__ += ['sctypes'] +__all__ += ['generic', 'number', + 'integer', 'signedinteger', 'unsignedinteger', + 'inexact', 'floating', 'complexfloating'] + + + + + diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py new file mode 100644 index 00000000..ff65b0d3 --- /dev/null +++ b/torch_np/_detail/_ufunc_impl.py @@ -0,0 +1,129 @@ +import torch + +from . import _util + +def deco_ufunc(torch_func): + """Common infra for binary ufuncs: receive tensors, sort out type casting, + broadcasting, and delegate to the pytorch function for actual work. + + + Converting array-likes into arrays, unwrapping them into tensors etc + is the caller responsibility. + """ + def wrapped(tensors, /, out_shape_dtype=None , *, where=True, + casting='same_kind', order='K', dtype=None, subok=False, **kwds): + _util.subok_not_ok(subok=subok) + if order != 'K' or not where: + raise NotImplementedError + + # XXX: dtype=... parameter + if dtype is not None: + raise NotImplementedError + + tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting) + + result = torch_func(*tensors) + return result + return wrapped + + +# binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py +# And edited manually! np.equal <--> torch.eq, not torch.equal +add = deco_ufunc(torch.add) +arctan2 = deco_ufunc(torch.arctan2) +bitwise_and = deco_ufunc(torch.bitwise_and) +bitwise_or = deco_ufunc(torch.bitwise_or) +bitwise_xor = deco_ufunc(torch.bitwise_xor) +copysign = deco_ufunc(torch.copysign) +divide = deco_ufunc(torch.divide) +equal = deco_ufunc(torch.eq) +float_power = deco_ufunc(torch.float_power) +floor_divide = deco_ufunc(torch.floor_divide) +fmax = deco_ufunc(torch.fmax) +fmin = deco_ufunc(torch.fmin) +fmod = deco_ufunc(torch.fmod) +gcd = deco_ufunc(torch.gcd) +greater = deco_ufunc(torch.greater) +greater_equal = deco_ufunc(torch.greater_equal) +heaviside = deco_ufunc(torch.heaviside) +hypot = deco_ufunc(torch.hypot) +lcm = deco_ufunc(torch.lcm) +ldexp = deco_ufunc(torch.ldexp) +left_shift = deco_ufunc(torch.bitwise_left_shift) +less = deco_ufunc(torch.less) +less_equal = deco_ufunc(torch.less_equal) +logaddexp = deco_ufunc(torch.logaddexp) +logaddexp2 = deco_ufunc(torch.logaddexp2) +logical_and = deco_ufunc(torch.logical_and) +logical_or = deco_ufunc(torch.logical_or) +logical_xor = deco_ufunc(torch.logical_xor) +matmul = deco_ufunc(torch.matmul) +maximum = deco_ufunc(torch.maximum) +minimum = deco_ufunc(torch.minimum) +remainder = deco_ufunc(torch.remainder) +multiply = deco_ufunc(torch.multiply) +nextafter = deco_ufunc(torch.nextafter) +not_equal = deco_ufunc(torch.not_equal) +power = deco_ufunc(torch.pow) +remainder = deco_ufunc(torch.remainder) +right_shift = deco_ufunc(torch.bitwise_right_shift) +subtract = deco_ufunc(torch.subtract) +divide = deco_ufunc(torch.divide) + + +# unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py +absolute = deco_ufunc(torch.absolute) +#absolute = deco_ufunc(torch.absolute) +arccos = deco_ufunc(torch.arccos) +arccosh = deco_ufunc(torch.arccosh) +arcsin = deco_ufunc(torch.arcsin) +arcsinh = deco_ufunc(torch.arcsinh) +arctan = deco_ufunc(torch.arctan) +arctanh = deco_ufunc(torch.arctanh) +ceil = deco_ufunc(torch.ceil) +conjugate = deco_ufunc(torch.conj_physical) +#conjugate = deco_ufunc(torch.conj_physical) +cos = deco_ufunc(torch.cos) +cosh = deco_ufunc(torch.cosh) +deg2rad = deco_ufunc(torch.deg2rad) +degrees = deco_ufunc(torch.rad2deg) +exp = deco_ufunc(torch.exp) +exp2 = deco_ufunc(torch.exp2) +expm1 = deco_ufunc(torch.expm1) +fabs = deco_ufunc(torch.absolute) +floor = deco_ufunc(torch.floor) +isfinite = deco_ufunc(torch.isfinite) +isinf = deco_ufunc(torch.isinf) +isnan = deco_ufunc(torch.isnan) +log = deco_ufunc(torch.log) +log10 = deco_ufunc(torch.log10) +log1p = deco_ufunc(torch.log1p) +log2 = deco_ufunc(torch.log2) +logical_not = deco_ufunc(torch.logical_not) +negative = deco_ufunc(torch.negative) +rad2deg = deco_ufunc(torch.rad2deg) +radians = deco_ufunc(torch.deg2rad) +reciprocal = deco_ufunc(torch.reciprocal) +rint = deco_ufunc(torch.round) +sign = deco_ufunc(torch.sign) +signbit = deco_ufunc(torch.signbit) +sin = deco_ufunc(torch.sin) +sinh = deco_ufunc(torch.sinh) +sqrt = deco_ufunc(torch.sqrt) +square = deco_ufunc(torch.square) +tan = deco_ufunc(torch.tan) +tanh = deco_ufunc(torch.tanh) +trunc = deco_ufunc(torch.trunc) + +invert = deco_ufunc(torch.bitwise_not) + +# special cases: torch does not export these names +def _cbrt(x): + return torch.pow(x, 1/3) + +def _positive(x): + return +x + +cbrt = deco_ufunc(_cbrt) +positive = deco_ufunc(_positive) + diff --git a/torch_np/_detail/_util.py b/torch_np/_detail/_util.py new file mode 100644 index 00000000..7d2c4ac0 --- /dev/null +++ b/torch_np/_detail/_util.py @@ -0,0 +1,286 @@ +"""Assorted utilities, which do not need anything other then torch and stdlib. +""" + +import operator + +import torch + +from . import _scalar_types + +# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504 +def is_sequence(seq): + if isinstance(seq, str): + return False + try: + len(seq) + except Exception: + return False + return True + + +def subok_not_ok(like=None, subok=False): + if like is not None: + raise ValueError("like=... parameter is not supported.") + if subok: + raise ValueError("subok parameter is not supported.") + + +class AxisError(ValueError, IndexError): + pass + + +class UFuncTypeError(TypeError, RuntimeError): + pass + + +# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h +def normalize_axis_index(ax, ndim, argname=None): + if not (-ndim <= ax < ndim): + raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}") + if ax < 0: + ax += ndim + return ax + + +# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378 +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + Used internally by multi-axis-checking logic. + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + """ + # Optimization to speed-up the most common cases. + if type(axis) not in (tuple, list): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) + if not allow_duplicate and len(set(axis)) != len(axis): + if argname: + raise ValueError("repeated axis in `{}` argument".format(argname)) + else: + raise ValueError("repeated axis") + return axis + + +def allow_only_single_axis(axis): + if axis is None: + return axis + if len(axis) != 1: + raise NotImplementedError("does not handle tuple axis") + return axis[0] + + +def expand_shape(arr_shape, axis): + # taken from numpy 1.23.x, expand_dims function + if type(axis) not in (list, tuple): + axis = (axis,) + out_ndim = len(axis) + len(arr_shape) + axis = normalize_axis_tuple(axis, out_ndim) + shape_it = iter(arr_shape) + shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] + return shape + + +def apply_keepdims(tensor, axis, ndim): + if axis is None: + # tensor was a scalar + tensor = torch.full((1,)*ndim, fill_value=tensor, dtype=tensor.dtype) + else: + shape = expand_shape(tensor.shape, axis) + tensor = tensor.reshape(shape) + return tensor + + +def axis_none_ravel(*tensors, axis=None): + """Ravel the arrays if axis is none.""" + # XXX: is only used at `concatenate`. Inline unless reused more widely + if axis is None: + tensors = tuple(ar.ravel() for ar in tensors) + return tensors, 0 + else: + return tensors, axis + + +def cast_dont_broadcast(tensors, target_dtype, casting): + """Dtype-cast tensors to target_dtype. + + Parameters + ---------- + tensors : iterable + tuple or list of torch.Tensors to typecast + target_dtype : torch dtype object, optional + The array dtype to cast all tensors to + casting : str + The casting mode, see `np.can_cast` + + Returns + ------- + a tuple of torch.Tensors with dtype being the PyTorch counterpart + of the `target_dtype` + """ + # check if we can dtype-cast all arguments + cast_tensors = [] + can_cast = _scalar_types._can_cast_impl + + for tensor in tensors: + if not can_cast(tensor.dtype, target_dtype, casting=casting): + raise TypeError(f"Cannot cast array data from {tensor.dtype} to" + f" {target_dtype} according to the rule '{casting}'") + + # cast if needed + if tensor.dtype != target_dtype: + tensor = tensor.to(target_dtype) + cast_tensors.append(tensor) + + return tuple(cast_tensors) + + +def cast_and_broadcast(tensors, out_param, casting): + """ + Parameters + ---------- + tensors : iterable + tuple or list of torch.Tensors to broadcast/typecast + target_dtype : a torch.dtype object + The torch dtype to cast all tensors to + target_shape : tuple + The tensor shape to broadcast all `tensors` to + casting : str + The casting mode, see `np.can_cast` + + Returns + ------- + a tuple of torch.Tensors with dtype being the PyTorch counterpart + of the `target_dtype` and `target_shape` + """ + if out_param is None: + return tensors + + target_dtype, target_shape = out_param + + can_cast = _scalar_types._can_cast_impl + + processed_tensors = [] + for tensor in tensors: + # check dtypes of x and out + if not can_cast(tensor.dtype, target_dtype, casting=casting): + raise TypeError(f"Cannot cast array data from {tensor.dtype} to" + f" {target_dtype} according to the rule '{casting}'") + + # cast arr if needed + if tensor.dtype != target_dtype: + tensor = tensor.to(target_dtype) + + # `out` broadcasts `tensor` + if tensor.shape != target_shape: + tensor = torch.broadcast_to(tensor, target_shape) + + processed_tensors.append(tensor) + + return tuple(processed_tensors) + + +def axis_keepdims(func, tensor, axis, keepdims, *args, **kwds): + """Generically handle axis and keepdims arguments in reductions.""" + if axis is not None: + if not isinstance(axis, (list, tuple)): + axis = (axis,) + axis = normalize_axis_tuple(axis, tensor.ndim) + + if axis == (): + newshape = expand_shape(tensor.shape, axis=0) + tensor = tensor.reshape(newshape) + axis = (0,) + + result = func(tensor, axis=axis, *args, **kwds) + + if keepdims: + result = apply_keepdims(result, axis, tensor.ndim) + + return result + + +def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): + """The core logic of the array(...) function. + + Parameters + ---------- + obj : tensor_like + The thing to coerce + dtype : torch.dtype object or None + Coerce to this torch dtype + copy : bool + Copy or not + + Returns + ------- + tensor : torch.Tensor + a tensor object with requested dtype, ndim and copy semantics. + + Notes + ----- + This is almost a "tensor_like" coersion function. Does not handle wrapper + ndarrays (those should be handled in the ndarray-aware layer prior to + invoking this function). + """ + if isinstance(obj, torch.Tensor): + tensor = obj + base = None + else: + tensor = torch.as_tensor(obj) + base = None + + # At this point, `tensor.dtype` is the pytorch default. Our default may + # differ, so need to typecast. However, we cannot just do `tensor.to`, + # because if our desired dtype is wider then pytorch's, `tensor` + # may have lost precision: + + # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!) + + # Therefore, we treat `tensor.dtype` as a hint, and convert the + # original object *again*, this time with an explicit dtype. + sctype = _scalar_types.get_default_type_for(_scalar_types.sctype_from_torch_dtype(tensor.dtype)) + torch_dtype = sctype.torch_dtype + + tensor = torch.as_tensor(obj, dtype=torch_dtype) + + # type cast if requested + if dtype is not None: + tensor = tensor.to(dtype) + + # adjust ndim if needed + ndim_extra = ndmin - tensor.ndim + if ndim_extra > 0: + tensor = tensor.view((1,)*ndim_extra + tensor.shape) + + # copy if requested + if copy: + tensor = tensor.clone() + + return tensor diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py index 66c07267..29085a91 100644 --- a/torch_np/_dtypes.py +++ b/torch_np/_dtypes.py @@ -8,59 +8,72 @@ import builtins import torch -from . import _scalar_types +from ._detail import _scalar_types -__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype'] +__all__ = ['dtype', 'DType', 'typecodes', 'issubdtype'] # Define analogs of numpy dtypes supported by pytorch. -class dtype: - def __init__(self, name, /): - if isinstance(name, dtype): - _name = name.name - elif hasattr(name, 'dtype'): - _name = name.dtype.name - elif name in python_types_dict: - _name = python_types_dict[name] - elif name in dt_names: - _name = name - elif name in typecode_chars_dict: - _name = typecode_chars_dict[name] - elif name in dt_aliases_dict: - _name = dt_aliases_dict[name] - # the check must come last, so that 'name' is not a string - elif issubclass(name, _scalar_types.generic): - _name = name.name + +def dtype(arg): + if arg is None: + arg = _scalar_types.default_scalar_type + return DType(arg) + + +def torch_dtype_from(dtype_arg): + return dtype(dtype_arg).torch_dtype + + +class DType: + def __init__(self, arg): + # a pytorch object? + if isinstance(arg, torch.dtype): + sctype = _scalar_types._torch_dtypes[arg] + elif isinstance(arg, torch.Tensor): + sctype = _scalar_types._torch_dtypes[arg.dtype] + # a scalar type? + elif issubclass_(arg, _scalar_types.generic): + sctype = arg + # a dtype already? + elif isinstance(arg, DType): + sctype = arg._scalar_type + # a has a right attribute? + elif hasattr(arg, 'dtype'): + sctype = arg.dtype._scalar_type else: - raise TypeError(f"data type '{name}' not understood") - self._name = _name + sctype = _scalar_types.sctype_from_string(arg) + self._scalar_type = sctype @property def name(self): - return self._name + return self._scalar_type.name @property def type(self): - return _scalar_types._typemap[self._name] + return self._scalar_type @property def typecode(self): - return _typecodes_from_dtype_dict[self._name] + return self._scalar_type.typecode def __eq__(self, other): - if isinstance(other, dtype): - return self._name == other.name - else: - try: - other_instance = dtype(other) - except TypeError: - return False - return self._name == other_instance.name + if isinstance(other, DType): + return self._scalar_type == other._scalar_type + try: + other_instance = DType(other) + except TypeError: + return False + return self._scalar_type == other_instance._scalar_type + + @property + def torch_dtype(self): + return self._scalar_type.torch_dtype def __hash__(self): - return hash(self._name) + return hash(self._scalar_type.name) def __repr__(self): return f'dtype("{self.name}")' @@ -73,68 +86,10 @@ def itemsize(self): return elem.get().element_size() def __getstate__(self): - return self._name + return self._scalar_type def __setstate__(self, value): - self._name = value - - - -dt_names = ['float16', 'float32', 'float64', - 'complex64', 'complex128', - 'uint8', - 'int8', - 'int16', - 'int32', - 'int64', - 'bool'] - - -dt_aliases_dict = { - 'u1' : 'uint8', - 'i1' : 'int8', - 'i2' : 'int16', - 'i4' : 'int32', - 'i8' : 'int64', - 'b' : 'int8', # XXX: srsly? - 'f2' : 'float16', - 'f4' : 'float32', - 'f8' : 'float64', - 'c8' : 'complex64', - 'c16': 'complex128', - '?' : 'bool', -} - - -python_types_dict = { - int: 'int64', - float: 'float64', - complex: 'complex128', - builtins.bool: 'bool', - # also allow stringified names of python types - int.__name__ : 'int64', - float.__name__ : 'float64', - complex.__name__: 'complex128', -} - - -typecode_chars_dict = { - 'e': 'float16', - 'f': 'float32', - 'd': 'float64', - 'F': 'complex64', - 'D': 'complex128', - 'B': 'uint8', - 'b': 'int8', - 'h': 'int16', - 'i': 'int32', - 'l': 'int64', - '?': 'bool' -} - -# reverse mapping -_typecodes_from_dtype_dict = {typecode_chars_dict[key]: key - for key in typecode_chars_dict} + self._scalar_type = value typecodes = {'All': 'efdFDBbhil?', @@ -147,65 +102,19 @@ def __setstate__(self, value): } -# Map the torch-suppored subset dtypes to local analogs -# "quantized" types not available in numpy, skip -_dtype_from_torch_dict = { - # floating-point - torch.float16: 'float16', - torch.float32: 'float32', - torch.float64 : 'float64', - # np.complex32 does not exist - torch.complex64: 'complex64', - torch.complex128: 'complex128', - # integer, unsigned (unit8 only, torch.uint32 etc do not exist) - torch.uint8: 'uint8', - # integer - torch.int8: 'int8', - torch.int16: 'int16', - torch.int32: 'int32', - torch.int64: 'int64', - # boolean - torch.bool : 'bool' -} - - -# reverse mapping -_torch_dtype_from_dtype_dict = {_dtype_from_torch_dict[key]: key - for key in _dtype_from_torch_dict} - - -def dtype_from_torch(torch_dtype): - try: - name = _dtype_from_torch_dict[torch_dtype] - return dtype(name) - except KeyError: - # mimic numpy: >>> np.dtype('unknown') --> TypeError - raise TypeError - - -def torch_dtype_from(dtyp): - if dtyp is None: - return None - name = dtype(dtyp).name - try: - return _torch_dtype_from_dtype_dict[name] - except KeyError: - # mimic numpy: >>> np.dtype('unknown') --> TypeError - raise TypeError - # ### Defaults and dtype discovery def default_int_type(): - return dtype('int64') + return dtype(_scalar_types.default_int_type) def default_float_type(): - return dtype('float64') + return dtype(_scalar_types.default_float_type) def default_complex_type(): - return dtype('complex128') + return dtype(_scalar_types.default_complex_type) def is_floating(dtyp): @@ -219,18 +128,8 @@ def is_integer(dtyp): def get_default_dtype_for(dtyp): - typ = dtype(dtyp).type - if issubclass(typ, _scalar_types.integer): - result = default_int_type() - elif issubclass(typ, _scalar_types.floating): - result = default_float_type() - elif issubclass(typ, _scalar_types.complexfloating): - result = default_complex_type() - elif issubclass(typ, _scalar_types.bool_): - result = dtype('bool') - else: - raise TypeError("dtype %s not understood." % dtyp) - return result + sctype = dtype(dtyp).type + return _scalar_types.get_default_type_for(sctype) def issubclass_(arg, klass): @@ -249,36 +148,13 @@ def issubdtype(arg1, arg2): return issubclass(arg1, arg2) -# The casting below is defined *with dtypes only*, so no value-based casting! - -# These two dicts are autogenerated with autogen/gen_dtypes.py, -# using numpy version 1.23.5. - -_can_cast_dict = { -'no': {'float16': {'float16': True, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float32': {'float16': False, 'float32': True, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float64': {'float16': False, 'float32': False, 'float64': True, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex64': {'float16': False, 'float32': False, 'float64': False, 'complex64': True, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex128': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'uint8': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': True, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'int8': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': True, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'int16': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': True, 'int32': False, 'int64': False, 'bool': False}, 'int32': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': True, 'int64': False, 'bool': False}, 'int64': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': True, 'bool': False}, 'bool': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': True}}, - -'equiv': {'float16': {'float16': True, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float32': {'float16': False, 'float32': True, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float64': {'float16': False, 'float32': False, 'float64': True, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex64': {'float16': False, 'float32': False, 'float64': False, 'complex64': True, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex128': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'uint8': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': True, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'int8': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': True, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'int16': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': True, 'int32': False, 'int64': False, 'bool': False}, 'int32': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': True, 'int64': False, 'bool': False}, 'int64': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': True, 'bool': False}, 'bool': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': False, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': True}}, +def can_cast(from_dtype, to_dtype, casting): + from_sctype = dtype(from_dtype).type.torch_dtype + to_sctype = dtype(to_dtype).type.torch_dtype -'safe': {'float16': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float32': {'float16': False, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float64': {'float16': False, 'float32': False, 'float64': True, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex64': {'float16': False, 'float32': False, 'float64': False, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex128': {'float16': False, 'float32': False, 'float64': False, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'uint8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': False, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int16': {'float16': False, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int32': {'float16': False, 'float32': False, 'float64': True, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': True, 'int64': True, 'bool': False}, 'int64': {'float16': False, 'float32': False, 'float64': True, 'complex64': False, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': True, 'bool': False}, 'bool': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}}, - -'same_kind': {'float16': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float32': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'float64': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex64': {'float16': False, 'float32': False, 'float64': False, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'complex128': {'float16': False, 'float32': False, 'float64': False, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': False, 'int16': False, 'int32': False, 'int64': False, 'bool': False}, 'uint8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int16': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int32': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'int64': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': False, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': False}, 'bool': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}}, - -'unsafe': {'float16': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'float32': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'float64': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'complex64': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'complex128': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'uint8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'int8': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'int16': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'int32': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'int64': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}, 'bool': {'float16': True, 'float32': True, 'float64': True, 'complex64': True, 'complex128': True, 'uint8': True, 'int8': True, 'int16': True, 'int32': True, 'int64': True, 'bool': True}} -} + return _scalar_types._can_cast_impl(from_sctype, to_sctype, casting) -_result_type_dict = { -'float16': {'float16': 'float16', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'float16', 'int8': 'float16', 'int16': 'float32', 'int32': 'float64', 'int64': 'float64', 'bool': 'float16'}, -'float32': {'float16': 'float32', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'float32', 'int8': 'float32', 'int16': 'float32', 'int32': 'float64', 'int64': 'float64', 'bool': 'float32'}, -'float64': {'float16': 'float64', 'float32': 'float64', 'float64': 'float64', 'complex64': 'complex128', 'complex128': 'complex128', 'uint8': 'float64', 'int8': 'float64', 'int16': 'float64', 'int32': 'float64', 'int64': 'float64', 'bool': 'float64'}, -'complex64': {'float16': 'complex64', 'float32': 'complex64', 'float64': 'complex128', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'complex64', 'int8': 'complex64', 'int16': 'complex64', 'int32': 'complex128', 'int64': 'complex128', 'bool': 'complex64'}, -'complex128': {'float16': 'complex128', 'float32': 'complex128', 'float64': 'complex128', 'complex64': 'complex128', 'complex128': 'complex128', 'uint8': 'complex128', 'int8': 'complex128', 'int16': 'complex128', 'int32': 'complex128', 'int64': 'complex128', 'bool': 'complex128'}, -'uint8': {'float16': 'float16', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'uint8', 'int8': 'int16', 'int16': 'int16', 'int32': 'int32', 'int64': 'int64', 'bool': 'uint8'}, -'int8': {'float16': 'float16', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'int16', 'int8': 'int8', 'int16': 'int16', 'int32': 'int32', 'int64': 'int64', 'bool': 'int8'}, -'int16': {'float16': 'float32', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'int16', 'int8': 'int16', 'int16': 'int16', 'int32': 'int32', 'int64': 'int64', 'bool': 'int16'}, -'int32': {'float16': 'float64', 'float32': 'float64', 'float64': 'float64', 'complex64': 'complex128', 'complex128': 'complex128', 'uint8': 'int32', 'int8': 'int32', 'int16': 'int32', 'int32': 'int32', 'int64': 'int64', 'bool': 'int32'}, -'int64': {'float16': 'float64', 'float32': 'float64', 'float64': 'float64', 'complex64': 'complex128', 'complex128': 'complex128', 'uint8': 'int64', 'int8': 'int64', 'int16': 'int64', 'int32': 'int64', 'int64': 'int64', 'bool': 'int64'}, -'bool': {'float16': 'float16', 'float32': 'float32', 'float64': 'float64', 'complex64': 'complex64', 'complex128': 'complex128', 'uint8': 'uint8', 'int8': 'int8', 'int16': 'int16', 'int32': 'int32', 'int64': 'int64', 'bool': 'bool'}} - -########################## end autogenerated part +# XXX : used in _ndarray.py/result_type, clean up +from ._detail._casting_dicts import _result_type_dict diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py index 201f8166..58956946 100644 --- a/torch_np/_helpers.py +++ b/torch_np/_helpers.py @@ -2,10 +2,12 @@ import torch from . import _dtypes -from ._ndarray import can_cast, ndarray, asarray -from . import _util +from ._ndarray import ndarray, asarray -def cast_and_broadcast(arrays, out, casting): +from ._detail import _util + + +def cast_and_broadcast(tensors, out, casting): """Cast dtypes of arrays to out.dtype and broadcast if needed. Parameters @@ -28,68 +30,23 @@ def cast_and_broadcast(arrays, out, casting): """ if out is None: - return tuple(arr.get() for arr in arrays) + return tensors else: if not isinstance(out, ndarray): raise TypeError("Return arrays must be of ArrayType") - tensors = [] - for arr in arrays: - # check dtypes of x and out - if not can_cast(arr.dtype, out.dtype, casting=casting): - raise TypeError(f"Cannot cast array data from {arr.dtype} to" - " {out_dtype} according to the rule '{casting}'") - tensor = arr.get() - - # `out` broadcasts `arr` - if arr.shape != out.shape: - tensor = torch.broadcast_to(tensor, out.shape) - - # cast arr if needed - if arr.dtype != out.dtype: - tensor = tensor.to(_dtypes.torch_dtype_from(out.dtype)) - - tensors.append(tensor) - - return tuple(tensors) - - - -def cast_dont_broadcast(arrays, out_dtype, casting): - """Dtype-cast arrays to dtype. - """ - # check if we can dtype-cast all arguments - tensors = [] - for arr in arrays: - if not can_cast(arr.dtype, out_dtype, casting=casting): - raise TypeError(f"Cannot cast array data from {arr.dtype} to" - " {out_dtype} according to the rule '{casting}'") - tensor = arr.get() - - # cast arr if needed - if arr.dtype != out_dtype: - tensor = tensor.to(_dtypes.torch_dtype_from(out_dtype)) - - tensors.append(tensor) + tensors = _util.cast_and_broadcast(tensors, out.dtype.type.torch_dtype, out.shape, casting) return tuple(tensors) - -def axis_none_ravel(*arrays, axis=None): - """Ravel the arrays if axis is none.""" - if axis is None: - arrays = tuple(ar.ravel() for ar in arrays) - return arrays, 0 - else: - return arrays, axis - - def result_or_out(result_tensor, out_array=None): """A helper for returns with out= argument.""" if out_array is not None: + if not isinstance(out_array, ndarray): + raise TypeError("Return arrays must be of ArrayType") if result_tensor.shape != out_array.shape: - raise ValueError + raise ValueError("Bad size of the out array.") out_tensor = out_array.get() out_tensor.copy_(result_tensor) return out_array @@ -97,50 +54,15 @@ def result_or_out(result_tensor, out_array=None): return asarray(result_tensor) -def apply_keepdims(tensor, axis, ndim): - if axis is None: - # tensor was a scalar - tensor = torch.full((1,)*ndim, fill_value=tensor) - else: - shape = _util.expand_shape(tensor.shape, axis) - tensor = tensor.reshape(shape) - return tensor - - -def standardize_axis_arg(axis, ndim): - """Return axis as either None or a tuple of normalized axes.""" - if isinstance(axis, ndarray): - axis = operator.index(axis) - - if axis is not None: - if not isinstance(axis, (list, tuple)): - axis = (axis,) - axis = _util.normalize_axis_tuple(axis, ndim) - return axis - - -def allow_only_single_axis(axis): - if axis is None: - return axis - if len(axis) != 1: - raise NotImplementedError("does not handle tuple axis") - return axis[0] - - -def to_tensors(*inputs): - """Convert all ndarrays from `inputs` to tensors.""" +def ndarrays_to_tensors(*inputs): + """Convert all ndarrays from `inputs` to tensors. (other things are intact) + """ return tuple([value.get() if isinstance(value, ndarray) else value for value in inputs]) -def float_or_default(dtype, self_dtype, enforce_float=False): - """dtype helper for reductions.""" - if dtype is None: - dtype = self_dtype - if dtype == _dtypes.dtype('bool'): - dtype = _dtypes.default_int_type() - if enforce_float: - if _dtypes.is_integer(dtype): - dtype = _dtypes.default_float_type() - torch_dtype = _dtypes.torch_dtype_from(dtype) - return torch_dtype +def to_tensors(*inputs): + """Convert all array_likes from `inputs` to tensors. + """ + return tuple(asarray(value).get() for value in inputs) + diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 0cd387f8..e01fd927 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -2,39 +2,18 @@ import torch -from . import _util +from ._detail import _util +from ._detail import _reductions from . import _helpers from . import _dtypes -from . import _ufunc_impl +from . import _unary_ufuncs +from . import _binary_ufuncs -NoValue = None -newaxis = None - - -def axis_out_keepdims_wrapper(func): - """`func` accepts an array-like as a 1st arg, returns a tensor. - - This decorator implements the generic handling of axis, out and keepdims - arguments for reduction functions. - """ - # XXX: move this out of _ndarray.py (circular imports) - @functools.wraps(func) - def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds): - arr = asarray(a) - axis = _helpers.standardize_axis_arg(axis, arr.ndim) - - if axis == (): - newshape = _util.expand_shape(arr.shape, axis=0) - arr = arr.reshape(newshape) - axis = (0,) +from ._decorators import emulate_out_arg, axis_keepdims_wrapper, dtype_to_torch - result = func(arr, axis=axis, *args, **kwds) +from ._decorators import NoValue - if keepdims: - result = _helpers.apply_keepdims(result, axis, arr.ndim) - return _helpers.result_or_out(result, out) - - return wrapped +newaxis = None ##################### ndarray class ########################### @@ -68,7 +47,7 @@ def ndim(self): @property def dtype(self): - return _dtypes.dtype_from_torch(self._tensor.dtype) + return _dtypes.dtype(self._tensor.dtype) @property def strides(self): @@ -127,31 +106,24 @@ def __str__(self): ### comparisons ### def __eq__(self, other): try: - return _ufunc_impl.equal(self, asarray(other)) - except RuntimeError: + return _binary_ufuncs.equal(self, other) + except (RuntimeError, TypeError): # Failed to convert other to array: definitely not equal. falsy = torch.full(self.shape, fill_value=False, dtype=bool) return asarray(falsy) def __neq__(self, other): try: - return _ufunc_impl.not_equal(self, asarray(other)) - except RuntimeError: + return _binary_ufuncs.not_equal(self, other) + except (RuntimeError, TypeError): # Failed to convert other to array: definitely not equal. falsy = torch.full(self.shape, fill_value=True, dtype=bool) return asarray(falsy) - def __gt__(self, other): - return _ufunc_impl.greater(self, asarray(other)) - - def __lt__(self, other): - return _ufunc_impl.less(self, asarray(other)) - - def __ge__(self, other): - return _ufunc_impl.greater_equal(self, asarray(other)) - - def __le__(self, other): - return _ufunc_impl.less_equal(self, asarray(other)) + __gt__ = _binary_ufuncs.greater + __lt__ = _binary_ufuncs.less + __ge__ = _binary_ufuncs.greater_equal + __le__ = _binary_ufuncs.less_equal def __bool__(self): try: @@ -167,10 +139,6 @@ def __index__(self): mesg = "only integer scalar arrays can be converted to a scalar index" raise TypeError(mesg) - # HACK : otherwise cannot check array.dtype in _dtypes.dict - def __hash__(self): - return id(self) - def __float__(self): return float(self._tensor) @@ -194,129 +162,81 @@ def __len__(self): ### arithmetic ### # add, self + other - def __add__(self, other): - return _ufunc_impl.add(self, asarray(other)) - - def __radd__(self, other): - return _ufunc_impl.add(self, asarray(other)) + __add__ = __radd__ = _binary_ufuncs.add def __iadd__(self, other): - return _ufunc_impl.add(self, asarray(other), out=self) + return _binary_ufuncs.add(self, other, out=self) # sub, self - other - def __sub__(self, other): - return _ufunc_impl.subtract(self, asarray(other)) - - def __rsub__(self, other): - return _ufunc_impl.subtract(self, asarray(other)) + __sub__ = __rsub__ = _binary_ufuncs.subtract def __isub__(self, other): - return _ufunc_impl.subtract(self, asarray(other), out=self) + return _binary_ufuncs.subtract(self, other, out=self) # mul, self * other - def __mul__(self, other): - return _ufunc_impl.multiply(self, asarray(other)) - - def __rmul__(self, other): - return _ufunc_impl.multiply(self, asarray(other)) + __mul__ = __rmul__ = _binary_ufuncs.multiply def __imul__(self, other): - return _ufunc_impl.multiply(self, asarray(other), out=self) + return _binary_ufuncs.multiply(self, other, out=self) # div, self / other - def __truediv__(self, other): - return _ufunc_impl.divide(self, asarray(other)) - - def __rtruediv__(self, other): - return _ufunc_impl.divide(self, asarray(other)) + __truediv__ = __rtruediv__ = _binary_ufuncs.divide def __itruediv__(self, other): - return _ufunc_impl.divide(self, asarray(other), out=self) + return _binary_ufuncs.divide(self, other, out=self) # floordiv, self // other - def __floordiv__(self, other): - return _ufunc_impl.floor_divide(self, asarray(other)) - - def __rfloordiv__(self, other): - return _ufunc_impl.floor_divide(self, asarray(other)) + __floordiv__ = __rfloordiv__ = _binary_ufuncs.floor_divide def __ifloordiv__(self, other): - return _ufunc_impl.floor_divide(self, asarray(other), out=self) + return _binary_ufuncs.floor_divide(self, other, out=self) # power, self**exponent - def __pow__(self, exponent): - return _ufunc_impl.float_power(self, asarray(exponent)) - - def __rpow__(self, exponent): - return _ufunc_impl.float_power(self, asarray(exponent)) + __pow__ = __rpow__ = _binary_ufuncs.float_power def __ipow__(self, exponent): - return _ufunc_impl.float_power(self, asarray(exponent), out=self) + return _binary_ufuncs.float_power(self, exponent, out=self) # remainder, self % other - def __mod__(self, other): - return _ufunc_impl.remainder(self, asarray(other)) - - def __rmod__(self, other): - return _ufunc_impl.remainder(self, asarray(other)) + __mod__ = __rmod__ = _binary_ufuncs.remainder def __imod__(self, other): - return _ufunc_impl.remainder(self, asarray(other), out=self) + return _binary_ufuncs.remainder(self, other, out=self) # bitwise ops # and, self & other - def __and__(self, other): - return _ufunc_impl.bitwise_and(self, asarray(other)) - - def __rand__(self, other): - return _ufunc_impl.bitwise_and(self, asarray(other)) + __and__ = __rand__ = _binary_ufuncs.bitwise_and def __iand__(self, other): - return _ufunc_impl.bitwise_and(self, asarray(other), out=self) + return _binary_ufuncs.bitwise_and(self, other, out=self) # or, self | other - def __or__(self, other): - return _ufunc_impl.bitwise_or(self, asarray(other)) - - def __ror__(self, other): - return _ufunc_impl.bitwise_or(self, asarray(other)) + __or__ = __ror__ = _binary_ufuncs.bitwise_or def __ior__(self, other): - return _ufunc_impl.bitwise_or(self, asarray(other), out=self) + return _binary_ufuncs.bitwise_or(self, other, out=self) # xor, self ^ other - def __xor__(self, other): - return _ufunc_impl.bitwise_xor(self, asarray(other)) - - def __rxor__(self, other): - return _ufunc_impl.bitwise_xor(self, asarray(other)) + __xor__ = __rxor__ = _binary_ufuncs.bitwise_xor def __ixor__(self, other): - return _ufunc_impl.bitwise_xor(self, asarray(other), out=self) + return _binary_ufuncs.bitwise_xor(self, other, out=self) # unary ops - def __invert__(self): - return _ufunc_impl.invert(self) - - def __abs__(self): - return _ufunc_impl.absolute(self) - - def __pos__(self): - return _ufunc_impl.positive(self) - - def __neg__(self): - return _ufunc_impl.negative(self) - + __invert__ = _unary_ufuncs.invert + __abs__ = _unary_ufuncs.absolute + __pos__ = _unary_ufuncs.positive + __neg__ = _unary_ufuncs.negative ### methods to match namespace functions @@ -329,17 +249,6 @@ def squeeze(self, axis=None): tensor = self._tensor.squeeze(axis) return ndarray._from_tensor_and_base(tensor, self) - @axis_out_keepdims_wrapper - def argmax(self, axis=None, out=None, *, keepdims=NoValue): - axis = _helpers.allow_only_single_axis(axis) - tensor = torch.argmax(self._tensor, axis) - return tensor - - @axis_out_keepdims_wrapper - def argmin(self, axis=None, out=None, *, keepdims=NoValue): - axis = _helpers.allow_only_single_axis(axis) - tensor = torch.argmin(self._tensor, axis) - return tensor def reshape(self, *shape, order='C'): newshape = shape[0] if len(shape) == 1 else shape @@ -369,131 +278,24 @@ def nonzero(self): tensor = self._tensor return tuple(asarray(_) for _ in tensor.nonzero(as_tuple=True)) - @axis_out_keepdims_wrapper - def any(self, axis=None, out=None, keepdims=NoValue, *, where=NoValue): - if where is not None: - raise NotImplementedError - - axis = _helpers.allow_only_single_axis(axis) - - if axis is None: - result = self._tensor.any() - else: - result = self._tensor.any(axis) - return result - - - @axis_out_keepdims_wrapper - def all(self, axis=None, out=None, keepdims=NoValue, *, where=NoValue): - if where is not None: - raise NotImplementedError - - axis = _helpers.allow_only_single_axis(axis) - - if axis is None: - result = self._tensor.all() - else: - result = self._tensor.all(axis) - return result - - - @axis_out_keepdims_wrapper - def max(self, axis=None, out=None, keepdims=NoValue, initial=NoValue, - where=NoValue): - if where is not None: - raise NotImplementedError - if initial is not None: - raise NotImplementedError - - result = self._tensor.amax(axis) - return result - - @axis_out_keepdims_wrapper - def min(self, axis=None, out=None, keepdims=NoValue, initial=NoValue, - where=NoValue): - if where is not None: - raise NotImplementedError - if initial is not None: - raise NotImplementedError - - result = self._tensor.amin(axis) - return result - - @axis_out_keepdims_wrapper - def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): - if where is not None: - raise NotImplementedError - - torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) - if axis is None: - result = self._tensor.mean(dtype=torch_dtype) - else: - result = self._tensor.mean(dtype=torch_dtype, dim=axis) - - return result - - - @axis_out_keepdims_wrapper - def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue, - initial=NoValue, where=NoValue): - if initial is not None or where is not None: - raise NotImplementedError - - torch_dtype = _helpers.float_or_default(dtype, self.dtype) - if axis is None: - result = self._tensor.sum(dtype=torch_dtype) - else: - result = self._tensor.sum(dtype=torch_dtype, dim=axis) - - return result - - @axis_out_keepdims_wrapper - def prod(self, axis=None, dtype=None, out=None, keepdims=NoValue, - initial=NoValue, where=NoValue): - if initial is not None or where is not None: - raise NotImplementedError - - axis = _helpers.allow_only_single_axis(axis) - - torch_dtype = _helpers.float_or_default(dtype, self.dtype) - if axis is None: - result = self._tensor.prod(dtype=torch_dtype) - else: - result = self._tensor.prod(dtype=torch_dtype, dim=axis) - - return result - - - @axis_out_keepdims_wrapper - def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, - where=NoValue): - if where is not None: - raise NotImplementedError - - torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) - tensor = self._tensor.to(torch_dtype) - - result = tensor.std(dim=axis, correction=ddof) - - return result - - @axis_out_keepdims_wrapper - def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, - where=NoValue): - if where is not None: - raise NotImplementedError - - torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) - tensor = self._tensor.to(torch_dtype) + argmin = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmin)) + argmax = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmax)) - result = tensor.var(dim=axis, correction=ddof) + any = emulate_out_arg(axis_keepdims_wrapper(_reductions.any)) + all = emulate_out_arg(axis_keepdims_wrapper(_reductions.all)) + max = emulate_out_arg(axis_keepdims_wrapper(_reductions.max)) + min = emulate_out_arg(axis_keepdims_wrapper(_reductions.min)) - return result + sum = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.sum))) + prod = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.prod))) + mean = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.mean))) + var = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.var))) + std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std))) ### indexing ### def __getitem__(self, *args, **kwds): - t_args = _helpers.to_tensors(*args) + t_args = _helpers.ndarrays_to_tensors(*args) return ndarray._from_tensor_and_base(self._tensor.__getitem__(*t_args, **kwds), self) def __setitem__(self, index, value): @@ -504,70 +306,40 @@ def __setitem__(self, index, value): # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. -def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0, +def array(obj, dtype=None, *, copy=True, order='K', subok=False, ndmin=0, like=None): _util.subok_not_ok(like, subok) if order != 'K': raise NotImplementedError # a happy path - if isinstance(object, ndarray): - if copy is False and dtype is None and ndmin <= object.ndim: - return object + if isinstance(obj, ndarray): + if copy is False and dtype is None and ndmin <= obj.ndim: + return obj # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists - if isinstance(object, (list, tuple)): + if isinstance(obj, (list, tuple)): a1 = [] - for elem in object: + for elem in obj: if isinstance(elem, ndarray): a1.append(elem.get().tolist()) else: a1.append(elem) - object = a1 - - # get the tensor from "object" - if isinstance(object, ndarray): - tensor = object._tensor - base = object - elif isinstance(object, torch.Tensor): - tensor = object - base = None - else: - tensor = torch.as_tensor(object) - base = None - - # At this point, `tensor.dtype` is the pytorch default. Our default may - # differ, so need to typecast. However, we cannot just do `tensor.to`, - # because if our desired dtype is wider then pytorch's, `tensor` - # may have lost precision: - - # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!) - - # Therefore, we treat `tensor.dtype` as a hint, and convert the - # original object *again*, this time with an explicit dtype. - dtyp = _dtypes.dtype_from_torch(tensor.dtype) - default = _dtypes.get_default_dtype_for(dtyp) - torch_dtype = _dtypes.torch_dtype_from(default) + obj = a1 - tensor = torch.as_tensor(object, dtype=torch_dtype) + # is obj an ndarray already? + base = None + if isinstance(obj, ndarray): + obj = obj._tensor + base = obj - # type cast if requested + # is a specific dtype requrested? + torch_dtype = None if dtype is not None: torch_dtype = _dtypes.torch_dtype_from(dtype) - tensor = tensor.to(torch_dtype) - base = None - - # adjust ndim if needed - ndim_extra = ndmin - tensor.ndim - if ndim_extra > 0: - tensor = tensor.view((1,)*ndim_extra + tensor.shape) - base = None - - # copy if requested - if copy: - tensor = tensor.clone() base = None + tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin) return ndarray._from_tensor_and_base(tensor, base) @@ -577,7 +349,6 @@ def asarray(a, dtype=None, order=None, *, like=None): return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0) - class asarray_replacer: def __init__(self, dispatch='one'): if dispatch not in ['one', 'two']: @@ -598,13 +369,15 @@ def wrapped(x, *args, **kwds): ###### dtype routines def can_cast(from_, to, casting='safe'): + # XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed. from_dtype = from_.dtype if isinstance(from_, ndarray) else _dtypes.dtype(from_) to_dtype = to.dtype if isinstance(to, ndarray) else _dtypes.dtype(to) - return _dtypes._can_cast_dict[casting][from_dtype.name][to_dtype.name] + return _dtypes.can_cast(from_dtype, to_dtype, casting) def result_type(*arrays_and_dtypes): + # XXX: clean up dtypes = [] from ._dtypes import issubclass_ @@ -612,7 +385,7 @@ def result_type(*arrays_and_dtypes): for entry in arrays_and_dtypes: if issubclass_(entry, _dtypes._scalar_types.generic): dtypes.append(_dtypes.dtype(entry)) - elif isinstance(entry, _dtypes.dtype): + elif isinstance(entry, _dtypes.DType): dtypes.append(entry) else: dtypes.append(asarray(entry).dtype) @@ -622,7 +395,7 @@ def result_type(*arrays_and_dtypes): return dtyp for curr in dtypes[1:]: - name = _dtypes._result_type_dict[dtyp.name][curr.name] + name = _dtypes._result_type_dict[dtyp.type.torch_dtype][curr.type.torch_dtype] dtyp = _dtypes.dtype(name) return dtyp diff --git a/torch_np/_scalar_types.py b/torch_np/_scalar_types.py deleted file mode 100644 index b97c1921..00000000 --- a/torch_np/_scalar_types.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Replicate the NumPy scalar type hierarchy -""" - -import abc -import torch - -class generic(abc.ABC): - @property - @abc.abstractmethod - def name(self): - return self.__class__.__name__ - - def __new__(self, value): - # - # Yes, a call to np.float32(4) produces a zero-dim array. - # - from . import _dtypes - from . import _ndarray - - torch_dtype = _dtypes.torch_dtype_from(self.name) - - if isinstance(value, str) and value in ['inf', 'nan']: - value = {'inf': torch.inf, 'nan': torch.nan}[value] - - if isinstance(value, _ndarray.ndarray): - tensor = value.get() - else: - try: - tensor = torch.as_tensor(value, dtype=torch_dtype) - except RuntimeError as e: - if "Overflow" in str(e): - raise OverflowError(e.args) - raise e - # - # With numpy: - # >>> a = np.ones(3) - # >>> np.float64(a) is a # True - # >>> np.float64(a[0]) is a[0] # False - # - # A reasonable assumption is that the second case is more common, - # and here we follow the second approach and create a new object - # *for all inputs*. - # - return _ndarray.ndarray._from_tensor_and_base(tensor, None) - - -##### these are abstract types - -class number(generic): - pass - - -class integer(number): - pass - - -class inexact(number): - pass - - -class signedinteger(integer): - pass - - -class unsignedinteger(integer): - pass - - -class floating(inexact): - pass - - -class complexfloating(inexact): - pass - - -# ##### concrete types - -# signed integers - -class int8(signedinteger): - name = 'int8' - - -class int16(signedinteger): - name = 'int16' - - -class int32(signedinteger): - name = 'int32' - - -class int64(signedinteger): - name = 'int64' - - -# unsigned integers - -class uint8(unsignedinteger): - name = 'uint8' - - -# floating point - -class float16(floating): - name = 'float16' - - -class float32(floating): - name = 'float32' - - -class float64(floating): - name = 'float64' - - -class complex64(complexfloating): - name = 'complex64' - - -class complex128(complexfloating): - name = 'complex128' - - -class bool_(generic): - name = 'bool' - - -# name aliases : FIXME (OS, bitness) -intp = int64 -int_ = int64 -intc = int32 - -byte = int8 -short = int16 -longlong = int64 # XXX: is this correct? - -ubyte = uint8 - -half = float16 -single = float32 -double = float64 -float_ = float64 - -csingle = complex64 -cdouble = complex128 - - -_typemap ={ - 'int8' : int8, - 'int16' : int16, - 'int32' : int32, - 'int64' : int64, - 'uint8' : uint8, - 'float16': float16, - 'float32': float32, - 'float64': float64, - 'complex64': complex64, - 'complex128': complex128, - 'bool': bool_ -} - - -# Replicate this -- yet another --- NumPy-defined way of grouping scalar types, -# cf tests/core/test_scalar_methods.py -sctypes = { - 'int': [int8, int16, int32, int64], - 'uint': [uint8,], - 'float': [float16, float32, float64], - 'complex': [complex64, complex128], - 'others': [bool], -} - - -__all__ = list(_typemap.keys()) -__all__.remove('bool') - -__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong', - 'ubyte', 'half', 'single', 'double', 'csingle', 'cdouble', 'float_'] -__all__ += ['sctypes'] -__all__ += ['generic', 'number', - 'integer', 'signedinteger', 'unsignedinteger', - 'inexact', 'floating', 'complexfloating'] diff --git a/torch_np/_ufunc_impl.py b/torch_np/_ufunc_impl.py deleted file mode 100644 index 894ace34..00000000 --- a/torch_np/_ufunc_impl.py +++ /dev/null @@ -1,156 +0,0 @@ -import torch - -from . import _util -from . import _helpers - -def deco_binary_ufunc(torch_func): - """Common infra for binary ufuncs: receive arrays, sort out type casting, - broadcasting, out array handling etc, and delegate to the - pytorch function for actual work, then wrap the results into an array. - - x1, x2 are arrays! array_like -> array conversion is the caller responsibility. - """ - def wrapped(x1, x2, /, out=None, *, where=True, - casting='same_kind', order='K', dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or not where: - raise NotImplementedError - - # XXX: dtype=... parameter - if dtype is not None: - raise NotImplementedError - - arrays = (x1, x2) - tensors = _helpers.cast_and_broadcast(arrays, out, casting) - - result = torch_func(*tensors) - - return _helpers.result_or_out(result, out) - return wrapped - - -def deco_unary_ufunc(torch_func): - # TODO: deduplicate with `deco_binary_ufunc` above. Need to figure out the - # effect of the `dtype` parameter, does it differ between unary and binary ufuncs. - def wrapped(x1, /, out=None, *, where=True, - casting='same_kind', order='K', dtype=None, subok=False, **kwds): - _util.subok_not_ok(subok=subok) - if order != 'K' or not where: - raise NotImplementedError - - # XXX: dtype=... parameter - if dtype is not None: - raise NotImplementedError - - arrays = (x1, ) - tensors = _helpers.cast_and_broadcast(arrays, out, casting) - - result = torch_func(*tensors) - - return _helpers.result_or_out(result, out) - return wrapped - - - - -# binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py -# And edited manually! np.equal <--> torch.eq, not torch.equal -add = deco_binary_ufunc(torch.add) -arctan2 = deco_binary_ufunc(torch.arctan2) -bitwise_and = deco_binary_ufunc(torch.bitwise_and) -bitwise_or = deco_binary_ufunc(torch.bitwise_or) -bitwise_xor = deco_binary_ufunc(torch.bitwise_xor) -copysign = deco_binary_ufunc(torch.copysign) -divide = deco_binary_ufunc(torch.divide) -equal = deco_binary_ufunc(torch.eq) -float_power = deco_binary_ufunc(torch.float_power) -floor_divide = deco_binary_ufunc(torch.floor_divide) -fmax = deco_binary_ufunc(torch.fmax) -fmin = deco_binary_ufunc(torch.fmin) -fmod = deco_binary_ufunc(torch.fmod) -gcd = deco_binary_ufunc(torch.gcd) -greater = deco_binary_ufunc(torch.greater) -greater_equal = deco_binary_ufunc(torch.greater_equal) -heaviside = deco_binary_ufunc(torch.heaviside) -hypot = deco_binary_ufunc(torch.hypot) -lcm = deco_binary_ufunc(torch.lcm) -ldexp = deco_binary_ufunc(torch.ldexp) -left_shift = deco_binary_ufunc(torch.bitwise_left_shift) -less = deco_binary_ufunc(torch.less) -less_equal = deco_binary_ufunc(torch.less_equal) -logaddexp = deco_binary_ufunc(torch.logaddexp) -logaddexp2 = deco_binary_ufunc(torch.logaddexp2) -logical_and = deco_binary_ufunc(torch.logical_and) -logical_or = deco_binary_ufunc(torch.logical_or) -logical_xor = deco_binary_ufunc(torch.logical_xor) -matmul = deco_binary_ufunc(torch.matmul) -maximum = deco_binary_ufunc(torch.maximum) -minimum = deco_binary_ufunc(torch.minimum) -remainder = deco_binary_ufunc(torch.remainder) -multiply = deco_binary_ufunc(torch.multiply) -nextafter = deco_binary_ufunc(torch.nextafter) -not_equal = deco_binary_ufunc(torch.not_equal) -power = deco_binary_ufunc(torch.pow) -remainder = deco_binary_ufunc(torch.remainder) -right_shift = deco_binary_ufunc(torch.bitwise_right_shift) -subtract = deco_binary_ufunc(torch.subtract) -divide = deco_binary_ufunc(torch.divide) - - - -# unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py -absolute = deco_unary_ufunc(torch.absolute) -#absolute = deco_unary_ufunc(torch.absolute) -arccos = deco_unary_ufunc(torch.arccos) -arccosh = deco_unary_ufunc(torch.arccosh) -arcsin = deco_unary_ufunc(torch.arcsin) -arcsinh = deco_unary_ufunc(torch.arcsinh) -arctan = deco_unary_ufunc(torch.arctan) -arctanh = deco_unary_ufunc(torch.arctanh) -ceil = deco_unary_ufunc(torch.ceil) -conjugate = deco_unary_ufunc(torch.conj_physical) -#conjugate = deco_unary_ufunc(torch.conj_physical) -cos = deco_unary_ufunc(torch.cos) -cosh = deco_unary_ufunc(torch.cosh) -deg2rad = deco_unary_ufunc(torch.deg2rad) -degrees = deco_unary_ufunc(torch.rad2deg) -exp = deco_unary_ufunc(torch.exp) -exp2 = deco_unary_ufunc(torch.exp2) -expm1 = deco_unary_ufunc(torch.expm1) -fabs = deco_unary_ufunc(torch.absolute) -floor = deco_unary_ufunc(torch.floor) -isfinite = deco_unary_ufunc(torch.isfinite) -isinf = deco_unary_ufunc(torch.isinf) -isnan = deco_unary_ufunc(torch.isnan) -log = deco_unary_ufunc(torch.log) -log10 = deco_unary_ufunc(torch.log10) -log1p = deco_unary_ufunc(torch.log1p) -log2 = deco_unary_ufunc(torch.log2) -logical_not = deco_unary_ufunc(torch.logical_not) -negative = deco_unary_ufunc(torch.negative) -rad2deg = deco_unary_ufunc(torch.rad2deg) -radians = deco_unary_ufunc(torch.deg2rad) -reciprocal = deco_unary_ufunc(torch.reciprocal) -rint = deco_unary_ufunc(torch.round) -sign = deco_unary_ufunc(torch.sign) -signbit = deco_unary_ufunc(torch.signbit) -sin = deco_unary_ufunc(torch.sin) -sinh = deco_unary_ufunc(torch.sinh) -sqrt = deco_unary_ufunc(torch.sqrt) -square = deco_unary_ufunc(torch.square) -tan = deco_unary_ufunc(torch.tan) -tanh = deco_unary_ufunc(torch.tanh) -trunc = deco_unary_ufunc(torch.trunc) - -invert = deco_unary_ufunc(torch.bitwise_not) - -# special cases: torch does not export these names -def _cbrt(x): - return torch.pow(x, 1/3) - -def _positive(x): - return +x - -cbrt = deco_unary_ufunc(_cbrt) -positive = deco_unary_ufunc(_positive) - diff --git a/torch_np/_unary_ufuncs.py b/torch_np/_unary_ufuncs.py index 657e03ff..d902dea6 100644 --- a/torch_np/_unary_ufuncs.py +++ b/torch_np/_unary_ufuncs.py @@ -1,24 +1,7 @@ -import functools -import torch +from ._decorators import deco_unary_ufunc_from_impl +from ._detail import _ufunc_impl -from . import _util - -from ._ndarray import asarray -from . import _dtypes -from . import _helpers - -from . import _ufunc_impl - -__all__ = ['abs', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'asarray', 'cbrt', 'ceil', 'conj', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc', 'invert'] - - - -def deco_unary_ufunc_from_impl(impl_func): - @functools.wraps(impl_func) - def wrapped(x1, *args, **kwds): - x1_array = asarray(x1) - return impl_func(x1_array, *args, **kwds) - return wrapped +__all__ = ['abs', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'cbrt', 'ceil', 'conj', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc', 'invert'] absolute = deco_unary_ufunc_from_impl(_ufunc_impl.absolute) diff --git a/torch_np/_util.py b/torch_np/_util.py deleted file mode 100644 index e96e491a..00000000 --- a/torch_np/_util.py +++ /dev/null @@ -1,100 +0,0 @@ -import operator - - -# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504 -def is_sequence(seq): - if isinstance(seq, str): - return False - try: - len(seq) - except Exception: - return False - return True - - -def subok_not_ok(like=None, subok=False): - if like is not None: - raise ValueError("like=... parameter is not supported.") - if subok: - raise ValueError("subok parameter is not supported.") - - -class AxisError(ValueError, IndexError): - pass - - -class UFuncTypeError(TypeError, RuntimeError): - pass - - -# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h -def normalize_axis_index(ax, ndim, argname=None): - if ax < -ndim or ax >= ndim: - raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}") - if ax < 0: - ax += ndim - return ax - - -# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378 -def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): - """ - Normalizes an axis argument into a tuple of non-negative integer axes. - This handles shorthands such as ``1`` and converts them to ``(1,)``, - as well as performing the handling of negative indices covered by - `normalize_axis_index`. - By default, this forbids axes from being specified multiple times. - Used internally by multi-axis-checking logic. - .. versionadded:: 1.13.0 - Parameters - ---------- - axis : int, iterable of int - The un-normalized index or indices of the axis. - ndim : int - The number of dimensions of the array that `axis` should be normalized - against. - argname : str, optional - A prefix to put before the error message, typically the name of the - argument. - allow_duplicate : bool, optional - If False, the default, disallow an axis from being specified twice. - Returns - ------- - normalized_axes : tuple of int - The normalized axis index, such that `0 <= normalized_axis < ndim` - Raises - ------ - AxisError - If any axis provided is out of range - ValueError - If an axis is repeated - See also - -------- - normalize_axis_index : normalizing a single scalar axis - """ - # Optimization to speed-up the most common cases. - if type(axis) not in (tuple, list): - try: - axis = [operator.index(axis)] - except TypeError: - pass - # Going via an iterator directly is slower than via list comprehension. - axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) - if not allow_duplicate and len(set(axis)) != len(axis): - if argname: - raise ValueError('repeated axis in `{}` argument'.format(argname)) - else: - raise ValueError('repeated axis') - return axis - - -def expand_shape(arr_shape, axis): - # taken from numpy 1.23.x, expand_dims function - if type(axis) not in (list, tuple): - axis = (axis,) - out_ndim = len(axis) + len(arr_shape) - axis = normalize_axis_tuple(axis, out_ndim) - shape_it = iter(arr_shape) - shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] - return shape - diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 045a3089..a8910ee9 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -7,13 +7,17 @@ import torch -from . import _util +from ._detail import _util +from ._detail import _reductions + from . import _dtypes from . import _helpers from ._ndarray import ndarray, asarray, array, asarray_replacer, newaxis from ._ndarray import can_cast, result_type + + # Things to decide on (punt for now) # # 1. Q: What are the return types of wrapper functions: plain torch.Tensors or @@ -194,7 +198,7 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None): dtype = _dtypes.default_int_type() dtype = result_type(start, stop, step, dtype) torch_dtype = _dtypes.torch_dtype_from(dtype) - start, stop, step = _helpers.to_tensors(start, stop, step) + start, stop, step = _helpers.ndarrays_to_tensors(start, stop, step) try: return asarray(torch.arange(start, stop, step, dtype=torch_dtype)) @@ -217,7 +221,7 @@ def empty_like(prototype, dtype=None, order='K', subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != 'K': raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) result = torch.empty_like(prototype, dtype=torch_dtype) if shape is not None: result = result.reshape(shape) @@ -239,7 +243,7 @@ def full_like(a, fill_value, dtype=None, order='K', subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != 'K': raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) result = torch.full_like(a, fill_value, dtype=torch_dtype) if shape is not None: result = result.reshape(shape) @@ -259,7 +263,7 @@ def ones_like(a, dtype=None, order='K', subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != 'K': raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) result = torch.ones_like(a, dtype=torch_dtype) if shape is not None: result = result.reshape(shape) @@ -280,7 +284,7 @@ def zeros_like(a, dtype=None, order='K', subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != 'K': raise NotImplementedError - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) result = torch.zeros_like(a, dtype=torch_dtype) if shape is not None: result = result.reshape(shape) @@ -326,32 +330,33 @@ def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None): def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): - if out is not None: - if dtype is not None: - # mimic numpy - raise TypeError("concatenate() only takes `out` or `dtype` as an " - "argument, but both were provided.") - if not isinstance(out, ndarray): - raise ValueError("'out' must be an array") if ar_tuple == (): # XXX: RuntimeError in torch, ValueError in numpy raise ValueError("need at least one array to concatenate") - # make sure inputs are arrays - arrays = tuple(asarray(ar) for ar in ar_tuple) + tensors = _helpers.to_tensors(*ar_tuple) # np.concatenate ravels if axis=None - arrays, axis = _helpers.axis_none_ravel(*arrays, axis=axis) + tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) + + if out is not None: + if not isinstance(out, ndarray): + raise ValueError("'out' must be an array") + + if dtype is not None: + # mimic numpy + raise TypeError("concatenate() only takes `out` or `dtype` as an " + "argument, but both were provided.") # figure out the type of the inputs and outputs if out is None and dtype is None: out_dtype = None - tensors = tuple(ar.get() for ar in arrays) else: - out_dtype = _dtypes.dtype(dtype) if dtype is not None else out.dtype + out_dtype = out.dtype if dtype is None else _dtypes.dtype(dtype) + out_dtype = out_dtype.type.torch_dtype # cast input arrays if necessary; do not broadcast them agains `out` - tensors = _helpers.cast_dont_broadcast(arrays, out_dtype, casting) + tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting) try: result = torch.cat(tensors, axis) @@ -496,21 +501,9 @@ def argwhere(a): return asarray(torch.argwhere(tensor)) -def abs(a): - # FIXME: should go the other way, together with other ufuncs - arr = asarray(a) - return a.__abs__() - -from ._ndarray import axis_out_keepdims_wrapper - -@axis_out_keepdims_wrapper -def count_nonzero(a, axis=None, *, keepdims=False): - # XXX: this all should probably be generalized to a sum(a != 0, dtype=bool) - try: - tensor = a.get().count_nonzero(axis) - except RuntimeError: - raise ValueError - return tensor +from ._ndarray import axis_keepdims_wrapper +from ._decorators import emulate_out_arg +count_nonzero = emulate_out_arg(axis_keepdims_wrapper(_reductions.count_nonzero)) @asarray_replacer() @@ -577,9 +570,6 @@ def tri(N, M=None, k=0, dtype=float, *, like=None): return asarray(tensor) ###### reductions - -# YYY: pattern : argmax, argmin - def argmax(a, axis=None, out=None, *, keepdims=NoValue): arr = asarray(a) return arr.argmax(axis=axis, out=out, keepdims=keepdims) @@ -618,8 +608,6 @@ def any(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue): return arr.any(axis=axis, out=out, keepdims=keepdims, where=where) -# YYY: pattern: dtype kwarg, None not accepted - def mean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): arr = asarray(a) return arr.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) @@ -647,6 +635,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=N arr = asarray(a) return arr.std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where) + def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): arr = asarray(a) return arr.var(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where) @@ -654,7 +643,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=N @asarray_replacer() def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): - if where is not None: + if where is not NoValue: raise NotImplementedError if dtype is None: dtype = a.dtype @@ -758,8 +747,11 @@ def isscalar(a): def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): - a = asarray(a).get() - b = asarray(a).get() + a, b = _helpers.to_tensors(a, b) + dtype = result_type(a, b) + torch_dtype = dtype.type.torch_dtype + a = a.to(torch_dtype) + b = b.to(torch_dtype) return asarray(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) ###### mapping from numpy API objects to wrappers from this module ###### diff --git a/torch_np/testing/utils.py b/torch_np/testing/utils.py index 5525e8e8..ec837ba8 100644 --- a/torch_np/testing/utils.py +++ b/torch_np/testing/utils.py @@ -186,7 +186,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True): return assert_array_equal(actual, desired, err_msg, verbose) msg = build_err_msg([actual, desired], err_msg, verbose=verbose) - if isinstance(actual, np.dtype) and isinstance(desired, np.dtype): + if isinstance(actual, np.DType) and isinstance(desired, np.DType): return actual == desired # Handle complex numbers: separate into real/imag to handle diff --git a/torch_np/tests/test_ndarray_methods.py b/torch_np/tests/test_ndarray_methods.py index 2f039844..04b26237 100644 --- a/torch_np/tests/test_ndarray_methods.py +++ b/torch_np/tests/test_ndarray_methods.py @@ -279,14 +279,18 @@ def test_output_shape(self, method): # Check some simple shape mismatches out = np.ones(11, dtype=np.int_) - assert_raises(ValueError, arg_method, -1, out) + + with assert_raises(ValueError): + arg_method(-1, out=out) out = np.ones((2, 5), dtype=np.int_) - assert_raises(ValueError, arg_method, -1, out) + with assert_raises(ValueError): + arg_method(-1, out=out) # these could be relaxed possibly (used to allow even the previous) out = np.ones((1, 10), dtype=np.int_) - assert_raises(ValueError, arg_method, -1, out) + with assert_raises(ValueError): + arg_method(-1, out=out) out = np.ones(10, dtype=np.int_) arg_method(-1, out=out) @@ -311,12 +315,6 @@ def test_np_vs_ndarray(self, arr_method, np_method): a = np.arange(6).reshape((2, 3)) arg_method = getattr(a, arr_method) - # check positional args - out1 = np.zeros(2, dtype=int) - out2 = np.zeros(2, dtype=int) - assert_equal(arg_method(1, out1), np_method(a, 1, out2)) - assert_equal(out1, out2) - # check keyword args out1 = np.zeros(3, dtype=int) out2 = np.zeros(3, dtype=int) @@ -324,6 +322,21 @@ def test_np_vs_ndarray(self, arr_method, np_method): np_method(a, out=out2, axis=0)) assert_equal(out1, out2) + @pytest.mark.xfail(reason="out=... as a positional arg") + @pytest.mark.parametrize('arr_method, np_method', + [('argmax', np.argmax), + ('argmin', np.argmin)]) + def test_np_vs_ndarray_positional(self, arr_method, np_method): + a = np.arange(6).reshape((2, 3)) + arg_method = getattr(a, arr_method) + + # check positional args + out1 = np.zeros(2, dtype=int) + out2 = np.zeros(2, dtype=int) + assert_equal(arg_method(1, out1), np_method(a, 1, out2)) + assert_equal(out1, out2) + + class TestArgmax: usg_data = [ diff --git a/torch_np/tests/test_reductions.py b/torch_np/tests/test_reductions.py index b996f0d5..14f870e4 100644 --- a/torch_np/tests/test_reductions.py +++ b/torch_np/tests/test_reductions.py @@ -5,7 +5,7 @@ from torch_np.testing import (assert_equal, assert_array_equal, assert_allclose, assert_almost_equal) -import torch_np._util as _util +import torch_np._detail._util as _util class TestNonzeroAndCountNonzero: