diff --git a/torch_np/_dtypes_impl.py b/torch_np/_dtypes_impl.py index f21b83be..fb53f721 100644 --- a/torch_np/_dtypes_impl.py +++ b/torch_np/_dtypes_impl.py @@ -49,3 +49,90 @@ def result_type_impl(*tensors): dtyp = _cd._result_type_dict[dtyp][curr.dtype] return dtyp + + +# ### NEP 50 helpers ### + +SCALAR_TYPES = {int, bool, float, complex} + + +def _dtype_for_scalar(py_type): + return { + bool: torch.bool, + int: torch.int64, + float: torch.float64, + complex: torch.complex128, + }[py_type] + + +def _category(dtype): + return { + torch.bool: 0, + # int + torch.uint8: 1, + torch.int8: 1, + torch.int16: 1, + torch.int32: 1, + torch.int64: 1, + # float + torch.float16: 2, + torch.float32: 2, + torch.float64: 2, + # complex + torch.complex64: 3, + torch.complex128: 3, + }[dtype] + + +def nep50_to_tensors(x1, x2, handle_weaks): + """If either of inputs is a python scalar, type-promote with NEP 50.""" + + def to_tensor(scalar, dtype=None): + if dtype is None: + dtype = _dtype_for_scalar(type(scalar)) + dtype = get_default_dtype_for(dtype) + return torch.as_tensor(scalar, dtype=dtype) + + x1_is_weak = not isinstance(x1, torch.Tensor) + x2_is_weak = not isinstance(x2, torch.Tensor) + if not handle_weaks or (x1_is_weak and x2_is_weak): + x1 = to_tensor(x1) if x1_is_weak else x1 + x2 = to_tensor(x2) if x2_is_weak else x2 + return x1, x2 + + # scalar tensor: NEP 50 + assert x1_is_weak != x2_is_weak + + weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1) + + # find the dtype for the weak's type + weak_dtype = _dtype_for_scalar(type(weak)) + + cat_weak = _category(weak_dtype) + cat_not_weak = _category(not_weak.dtype) + + dt = not_weak.dtype if cat_weak <= cat_not_weak else None + + # special-case complex + float32 + if weak_dtype.is_complex and not_weak.dtype == torch.float32: + dt = torch.complex64 + + # detect overflows: in PyTorch, uint8(-1) wraps around to 255, + # while NEP50 mandates an exception. + # + # Note that we only check if each element of the binop overflows, + # not the result. Consider, e.g. `uint8(100) + 200`. Operands are OK + # in uint8, but the result overflows and wrap around 255. + # Numpy emits a RuntimeWarning, PyTorch does not, and we do not either. + if cat_weak == 1 and cat_not_weak == 1: + # integers + iinfo = torch.iinfo(not_weak.dtype) + if not (iinfo.min <= weak <= iinfo.max): + raise OverflowError( + f"Python integer {weak} out of bounds for {not_weak.dtype}" + ) + + # finally, can make `weak` into a 0D tensor + weak = to_tensor(weak, dt) + + return (weak, not_weak) if x1_is_weak else (not_weak, weak) diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py index 9ab8932b..a52176ad 100644 --- a/torch_np/_normalizations.py +++ b/torch_np/_normalizations.py @@ -9,9 +9,12 @@ import torch -from . import _dtypes, _util +from . import _dtypes, _dtypes_impl, _util ArrayLike = typing.TypeVar("ArrayLike") +Scalar = typing.Union[int, float, complex, bool] +ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar] + DTypeLike = typing.TypeVar("DTypeLike") AxisLike = typing.TypeVar("AxisLike") NDArray = typing.TypeVar("NDarray") @@ -43,6 +46,12 @@ def normalize_array_like(x, parm=None): return asarray(x).tensor +def normalize_array_like_or_scalar(x, parm=None): + if type(x) in _dtypes_impl.SCALAR_TYPES: + return x + return normalize_array_like(x, parm) + + def normalize_optional_array_like(x, parm=None): # This explicit normalizer is needed because otherwise normalize_array_like # does not run for a parameter annotated as Optional[ArrayLike] @@ -109,6 +118,7 @@ def normalize_casting(arg, parm=None): normalizers = { "ArrayLike": normalize_array_like, + "Union[ArrayLike, Scalar]": normalize_array_like_or_scalar, "Optional[ArrayLike]": normalize_optional_array_like, "Sequence[ArrayLike]": normalize_seq_array_like, "Optional[NDArray]": normalize_ndarray, diff --git a/torch_np/_ufuncs.py b/torch_np/_ufuncs.py index 8021d64a..e6c983a7 100644 --- a/torch_np/_ufuncs.py +++ b/torch_np/_ufuncs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Union import torch @@ -11,19 +11,11 @@ DTypeLike, NotImplementedType, OutArray, + Scalar, normalizer, ) -def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj): - if dtype is None: - dtype = _dtypes_impl.result_type_impl(*tensors) - - tensors = _util.typecast_tensors(tensors, dtype, casting) - - return tensors - - def _ufunc_postprocess(result, out, casting): if out is not None: result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting) @@ -40,6 +32,36 @@ def _ufunc_postprocess(result, out, casting): ] +NEP50_FUNCS = ( + "add", + "subtract", + "multiply", + "floor_divide", + "true_divide", + "divide", + "remainder", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "hypot", + "arctan2", + "logaddexp", + "logaddexp2", + "heaviside", + "copysign", + "fmax", + "minimum", + "fmin", + "maximum", + "fmod", + "gcd", + "lcm", + "pow", +) + + def deco_binary_ufunc(torch_func): """Common infra for binary ufuncs. @@ -49,8 +71,8 @@ def deco_binary_ufunc(torch_func): @normalizer def wrapped( - x1: ArrayLike, - x2: ArrayLike, + x1: Union[ArrayLike, Scalar], + x2: Union[ArrayLike, Scalar], /, out: Optional[OutArray] = None, *, @@ -62,13 +84,28 @@ def wrapped( signature=None, extobj=None, ): - tensors = _ufunc_preprocess( - (x1, x2), where, casting, order, dtype, subok, signature, extobj - ) - result = torch_func(*tensors) - result = _ufunc_postprocess(result, out, casting) - return result + if dtype is not None: + + def cast(x, dtype): + if isinstance(x, torch.Tensor): + return _util.typecast_tensors((x,), dtype, casting)[0] + else: + return torch.as_tensor(x, dtype=dtype) + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + else: + x1, x2 = _dtypes_impl.nep50_to_tensors( + x1, x2, torch_func.__name__ in NEP50_FUNCS + ) + + result = torch_func(x1, x2) + + return _ufunc_postprocess(result, out, casting) wrapped.__qualname__ = torch_func.__name__ wrapped.__name__ = torch_func.__name__ @@ -80,6 +117,7 @@ def wrapped( # matmul's signature is _slightly_ different from other ufuncs: # - no where=... # - additional axis=..., axes=... +# - no NEP50 scalars in or out # @normalizer def matmul( @@ -97,10 +135,12 @@ def matmul( axes: NotImplementedType = None, axis: NotImplementedType = None, ): - tensors = _ufunc_preprocess( - (x1, x2), True, casting, order, dtype, subok, signature, extobj - ) - result = _binary_ufuncs_impl.matmul(*tensors) + + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + + result = _binary_ufuncs_impl.matmul(x1, x2) result = _ufunc_postprocess(result, out, casting) return result @@ -140,11 +180,11 @@ def divmod( else: out1, out2 = out - tensors = _ufunc_preprocess( - (x1, x2), True, casting, order, dtype, subok, signature, extobj - ) + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) - quot, rem = _binary_ufuncs_impl.divmod(*tensors) + quot, rem = _binary_ufuncs_impl.divmod(x1, x2) quot = _ufunc_postprocess(quot, out1, casting) rem = _ufunc_postprocess(rem, out2, casting) diff --git a/torch_np/_util.py b/torch_np/_util.py index fc7651dd..d3154d55 100644 --- a/torch_np/_util.py +++ b/torch_np/_util.py @@ -182,6 +182,10 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): Coerce to this torch dtype copy : bool Copy or not + ndmin : int + The results as least this many dimensions + is_weak : bool + Whether obj is a weakly typed python scalar. Returns ------- @@ -198,14 +202,11 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): tensor = obj else: tensor = torch.as_tensor(obj) - base = None - # At this point, `tensor.dtype` is the pytorch default. Our default may - # differ, so need to typecast. However, we cannot just do `tensor.to`, - # because if our desired dtype is wider then pytorch's, `tensor` - # may have lost precision: - - # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!) + # tensor.dtype is the pytorch default, typically float32. If obj's elements + # are not exactly representable in float32, we've lost precision: + # >>> torch.as_tensor(1e12).item() - 1e12 + # -4096.0 # Therefore, we treat `tensor.dtype` as a hint, and convert the # original object *again*, this time with an explicit dtype. diff --git a/torch_np/tests/numpy_tests/core/test_scalarmath.py b/torch_np/tests/numpy_tests/core/test_scalarmath.py index 13b10405..7c06e1ca 100644 --- a/torch_np/tests/numpy_tests/core/test_scalarmath.py +++ b/torch_np/tests/numpy_tests/core/test_scalarmath.py @@ -481,6 +481,21 @@ def test_numpy_scalar_relational_operators(self): assert_(not np.array(1, dtype=dt1)[()] < np.array(0, dtype=dt2)[()], "type %s and %s failed" % (dt1, dt2)) + #Signed integers and floats + for dt1 in 'bhl' + np.typecodes['Float']: + assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) + assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) + assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) + + for dt2 in 'bhl' + np.typecodes['Float']: + assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()], + "type %s and %s failed" % (dt1, dt2)) + assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()], + "type %s and %s failed" % (dt1, dt2)) + assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()], + "type %s and %s failed" % (dt1, dt2)) + + def test_numpy_scalar_relational_operators_2(self): #Unsigned integers for dt1 in 'B': assert_(-1 < np.array(1, dtype=dt1)[()], "type %s failed" % (dt1,)) @@ -496,19 +511,6 @@ def test_numpy_scalar_relational_operators(self): assert_(np.array(1, dtype=dt1)[()] != np.array(-1, dtype=dt2)[()], "type %s and %s failed" % (dt1, dt2)) - #Signed integers and floats - for dt1 in 'bhl' + np.typecodes['Float']: - assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) - assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) - assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,)) - - for dt2 in 'bhl' + np.typecodes['Float']: - assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()], - "type %s and %s failed" % (dt1, dt2)) - assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()], - "type %s and %s failed" % (dt1, dt2)) - assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()], - "type %s and %s failed" % (dt1, dt2)) def test_scalar_comparison_to_none(self): # Scalars should just return False and not give a warnings. diff --git a/torch_np/tests/numpy_tests/fft/test_pocketfft.py b/torch_np/tests/numpy_tests/fft/test_pocketfft.py index c3534eb2..6ca7f3ab 100644 --- a/torch_np/tests/numpy_tests/fft/test_pocketfft.py +++ b/torch_np/tests/numpy_tests/fft/test_pocketfft.py @@ -44,8 +44,8 @@ def test_fft(self): np.random.seed(1234) x = random(30) + 1j*random(30) - assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5) - assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5) + assert_allclose(fft1(x), np.fft.fft(x), atol=3e-5) + assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=3e-5) assert_allclose(fft1(x) / np.sqrt(30), np.fft.fft(x, norm="ortho"), atol=5e-6) assert_allclose(fft1(x) / 30., diff --git a/torch_np/tests/test_nep50_examples.py b/torch_np/tests/test_nep50_examples.py index 37eb2adc..b87f3d64 100644 --- a/torch_np/tests/test_nep50_examples.py +++ b/torch_np/tests/test_nep50_examples.py @@ -1,6 +1,27 @@ """Test examples for NEP 50.""" -from torch_np import array, float32, float64, inf, int64, uint8 +import itertools + +try: + import numpy as _np + + HAVE_NUMPY = True +except ImportError: + HAVE_NUMPY = False + +import torch_np as tnp +from torch_np import ( + array, + bool_, + complex64, + complex128, + float32, + float64, + inf, + int16, + int64, + uint8, +) from torch_np.testing import assert_allclose uint16 = uint8 # can be anything here, see below @@ -36,24 +57,24 @@ "uint8(1) + 300": (int64(301), Exception), "uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning "float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7] - # "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50? + "array([0.1], float32) == 0.1": ( + array([False]), + unchanged, + ), # XXX: a typo in NEP50? "array([0.1], float32) == float64(0.1)": (array([True]), array([False])), "array([1.], float32) + 3": (array([4.0], float32), unchanged), "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)), + # additional examples from the NEP text + "int16(2) + 2": (int64(4), int16(4)), + "int16(4) + 4j": (complex128(4 + 4j), unchanged), + "float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)), + "bool_(True) + 1": (int64(2), unchanged), + "True + uint8(2)": (uint8(3), unchanged), } fails = [ - "uint8(1) + 2", - "array([1], uint8) + 1", - "array([1], uint8) + 200", - "array([1], uint8) + array(1, int64)", - "array([100], uint8) + 200", - "array([1], uint8) + 300", - "uint8(1) + 300", - "uint8(100) + 200", - "float32(1) + 3e100", - "array([1.], float32) + 3", + "array([0.1], float32) == 0.1", # TODO: fix the example ] @@ -77,3 +98,114 @@ def test_nep50_exceptions(example): assert_allclose(result, new, atol=1e-16) assert result.dtype == new.dtype + + +# ### Directly compare to numpy ### + +weaks = (True, 1, 2.0, 3j) +non_weaks = ( + tnp.asarray(True), + tnp.uint8(1), + tnp.int8(1), + tnp.int32(1), + tnp.int64(1), + tnp.float32(1), + tnp.float64(1), + tnp.complex64(1), + tnp.complex128(1), +) +if HAVE_NUMPY: + dtypes = ( + None, + _np.bool_, + _np.uint8, + _np.int8, + _np.int32, + _np.int64, + _np.float32, + _np.float64, + _np.complex64, + _np.complex128, + ) +else: + dtypes = (None,) + + +@pytest.mark.skipif(not HAVE_NUMPY, reason="NumPy not found") +@pytest.mark.parametrize( + "scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes) +) +def test_direct_compare(scalar, array, dtype): + # compare to NumPy w/ NEP 50. + try: + state = _np._get_promotion_state() + _np._set_promotion_state("weak") + + if dtype is not None: + kwargs = {"dtype": dtype} + try: + result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs) + except Exception: + return + + kwargs = {} + if dtype is not None: + kwargs = {"dtype": getattr(tnp, dtype.__name__)} + result = tnp.add(scalar, array, **kwargs).tensor.numpy() + assert result.dtype == result_numpy.dtype + assert result == result_numpy + + finally: + _np._set_promotion_state(state) + + +# ufunc name: [array.dtype] +corners = { + "true_divide": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "divide": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "arctan2": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "copysign": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "heaviside": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "ldexp": ["bool_", "uint8", "int8", "int16", "int32", "int64"], + "power": ["uint8"], + "nextafter": ["float32"], +} + + +@pytest.mark.skipif(not HAVE_NUMPY, reason="NumPy not found") +@pytest.mark.parametrize("name", tnp._ufuncs._binary) +@pytest.mark.parametrize("scalar, array", itertools.product(weaks, non_weaks)) +def test_compare_ufuncs(name, scalar, array): + + if name in corners and ( + array.dtype.name in corners[name] + or tnp.asarray(scalar).dtype.name in corners[name] + ): + return pytest.skip(f"{name}(..., dtype=array.dtype)") + + try: + state = _np._get_promotion_state() + _np._set_promotion_state("weak") + + if name in ["matmul", "modf", "divmod"]: + return + ufunc = getattr(tnp, name) + ufunc_numpy = getattr(_np, name) + + try: + result = ufunc(scalar, array) + except RuntimeError: + # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc + result = None + + try: + result_numpy = ufunc_numpy(scalar, array.tensor.numpy()) + except TypeError: + # TypeError: ufunc 'hypot' not supported for the input types + result_numpy = None + + if result is not None and result_numpy is not None: + assert result.tensor.numpy().dtype == result_numpy.dtype + + finally: + _np._set_promotion_state(state) diff --git a/torch_np/tests/test_ufuncs_basic.py b/torch_np/tests/test_ufuncs_basic.py index 9fcb0ae8..02c8b312 100644 --- a/torch_np/tests/test_ufuncs_basic.py +++ b/torch_np/tests/test_ufuncs_basic.py @@ -380,10 +380,6 @@ def test_binary_ufunc_dtype(self): assert r32.dtype == "float32" assert r32 == 1 - # casting of floating inputs to booleans - with assert_raises(TypeError): - np.add(1.0, 1e-15, dtype=bool) - # now force the cast rb = np.add(1.0, 1e-15, dtype=bool, casting="unsafe") assert rb.dtype == bool