ENH: introduce NEP 50 "weak scalars" #140

Merged · 17 commits · May 19, 2023
79 changes: 79 additions & 0 deletions torch_np/_dtypes_impl.py
@@ -49,3 +49,82 @@ def result_type_impl(*tensors):
dtyp = _cd._result_type_dict[dtyp][curr.dtype]

return dtyp


# ### NEP 50 helpers ###

SCALAR_TYPES = (int, bool, float, complex)


def _dtype_for_scalar(py_type):
return {
bool: torch.bool,
int: torch.int64,
float: torch.float64,
complex: torch.complex128,
}[py_type]


categories = [
(torch.bool,),
(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
(torch.float16, torch.float32, torch.float64),
(torch.complex64, torch.complex128),
]


def category(dtyp):
for j, cat in enumerate(categories):
if dtyp in cat:
return j
raise ValueError(f"unknown dtype {dtyp}")


dtype_for_cat = {0: torch.bool, 1: torch.int64, 2: torch.float64, 3: torch.complex128}


def nep50_to_tensors(x1, x2):
"""If either of inputs is a python scalar, type-promote with NEP 50.

NB: NEP 50 mandates RuntimeWarnings on some overflows. We do not emit them:
we either raise an OverflowError or silently do the computation.
"""

x1_type, x2_type = type(x1), type(x2)
x1_is_weak = x1_type in SCALAR_TYPES
x2_is_weak = x2_type in SCALAR_TYPES
if x1_is_weak and x2_is_weak:
# two scalars: promote
x1 = torch.as_tensor(x1, dtype=_dtype_for_scalar(x1_type))
x2 = torch.as_tensor(x2, dtype=_dtype_for_scalar(x2_type))
return x1, x2
elif not (x1_is_weak or x2_is_weak):
# two tensors: nothing to do here
return x1, x2
else:
        # one scalar and one tensor: NEP 50 weak promotion
weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)

# find the dtype for the weak's type
weak_dtype = _dtype_for_scalar(type(weak))

cat_weak = category(weak_dtype)
cat_not_weak = category(not_weak.dtype)

dt = not_weak.dtype if cat_weak <= cat_not_weak else dtype_for_cat[cat_weak]

# special-case complex + float32
if weak_dtype.is_complex and not_weak.dtype == torch.float32:
dt = torch.complex64

        # finally, cast `weak` into a 0D tensor
weak_ = torch.as_tensor(weak, dtype=dt)

# detect uint overflow: in PyTorch, uint8(-1) wraps around to 255,
# while NEP50 mandates an exception.
if weak_.dtype == torch.uint8 and weak_.item() != weak:
            raise OverflowError(
                f"Python integer {weak} out of bounds for {weak_.dtype}"
            )

        return (weak_, not_weak) if x1_is_weak else (not_weak, weak_)

Review thread on the uint8 overflow check above:

Collaborator: Does this happen just for uint8 or for any int dtype?

Collaborator: Also, using .item() is not kosher. Let's do 0 <= weak < 2**8 before doing the as_tensor.

Author: This is a bit different: checking the weak scalar's value does not detect uint8(100) + 200. However, numpy warns rather than raises, so we shouldn't raise either. As discussed, this PR now does what numpy does, sans RuntimeWarnings.

Collaborator (@lezcano, May 18, 2023): Also, it's a bit annoying that this check is only done for ints. If it were done for all dtypes, we could create the tensors with torch.full, which does check whether the number fits in the given type.

Reply: I think that is also a valid choice; numpy already gives a customizable warning anyway if you overflow to inf, so this just seemed the easier/OK thing. ints don't overflow gracefully, though...
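For illustration, a minimal usage sketch of nep50_to_tensors as added above (the torch_np._dtypes_impl import path follows this diff; each case mirrors one branch of the function):

import torch
from torch_np._dtypes_impl import nep50_to_tensors

# weak python int + uint8 tensor: same category, so the tensor dtype wins
w, t = nep50_to_tensors(1, torch.tensor([1], dtype=torch.uint8))
assert w.dtype == t.dtype == torch.uint8

# weak python float + int64 tensor: float is the higher category -> float64
w, t = nep50_to_tensors(1.5, torch.tensor([1], dtype=torch.int64))
assert w.dtype == torch.float64 and t.dtype == torch.int64

# complex scalar + float32 tensor: special-cased to complex64
w, t = nep50_to_tensors(1j, torch.tensor([1.0], dtype=torch.float32))
assert w.dtype == torch.complex64

# a python int that does not fit into uint8 raises instead of wrapping
try:
    nep50_to_tensors(300, torch.tensor([1], dtype=torch.uint8))
except OverflowError as e:
    print(e)  # Python integer 300 out of bounds for torch.uint8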
12 changes: 11 additions & 1 deletion torch_np/_normalizations.py
@@ -9,9 +9,12 @@

import torch

from . import _dtypes, _util
from . import _dtypes, _dtypes_impl, _util

ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDarray")
@@ -43,6 +46,12 @@ def normalize_array_like(x, parm=None):
return asarray(x).tensor


def normalize_array_like_or_scalar(x, parm=None):
if type(x) in _dtypes_impl.SCALAR_TYPES:
return x
return normalize_array_like(x, parm)


def normalize_optional_array_like(x, parm=None):
# This explicit normalizer is needed because otherwise normalize_array_like
# does not run for a parameter annotated as Optional[ArrayLike]
@@ -109,6 +118,7 @@ def normalize_casting(arg, parm=None):

normalizers = {
"ArrayLike": normalize_array_like,
"ArrayLike | Scalar": normalize_array_like_or_scalar,
"Optional[ArrayLike]": normalize_optional_array_like,
"Sequence[ArrayLike]": normalize_seq_array_like,
"Optional[NDArray]": normalize_ndarray,
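A quick behavioral sketch of the new normalizer (illustrative only; it relies just on the function added above): python scalars pass through untouched so the ufunc layer can apply NEP 50 promotion to them later, while everything else is coerced to a tensor.

from torch_np._normalizations import normalize_array_like_or_scalar

# python scalars pass through as-is
assert normalize_array_like_or_scalar(3) == 3

# everything else is coerced to a torch.Tensor via asarray(x).tensor
t = normalize_array_like_or_scalar([1, 2, 3])
print(type(t))  # <class 'torch.Tensor'>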
45 changes: 22 additions & 23 deletions torch_np/_ufuncs.py
@@ -11,19 +11,11 @@
DTypeLike,
NotImplementedType,
OutArray,
Scalar,
normalizer,
)


def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
if dtype is None:
dtype = _dtypes_impl.result_type_impl(*tensors)

tensors = _util.typecast_tensors(tensors, dtype, casting)

return tensors


def _ufunc_postprocess(result, out, casting):
if out is not None:
result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
@@ -49,8 +41,8 @@ def deco_binary_ufunc(torch_func):

@normalizer
def wrapped(
x1: ArrayLike,
x2: ArrayLike,
x1: ArrayLike | Scalar,
x2: ArrayLike | Scalar,
/,
out: Optional[OutArray] = None,
*,
@@ -62,10 +54,14 @@ def wrapped(
signature=None,
extobj=None,
):
tensors = _ufunc_preprocess(
(x1, x2), where, casting, order, dtype, subok, signature, extobj
)
result = torch_func(*tensors)

x1, x2 = _dtypes_impl.nep50_to_tensors(x1, x2)

if dtype is None:
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

result = torch_func(x1, x2)

result = _ufunc_postprocess(result, out, casting)
return result
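The net effect of this rewritten flow matches the NEP 50 examples tabulated in the tests below; a short sketch (dtype spellings as used in those tests):

import torch_np as tnp

a = tnp.array([1.0], tnp.float32)

# a bare python int is a weak scalar: the array dtype wins
print((a + 3).dtype)             # float32

# a typed int64 scalar takes part in promotion: float32 + int64 -> float64
print((a + tnp.int64(3)).dtype)  # float64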
@@ -80,6 +76,7 @@ def wrapped(
# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
#
@normalizer
def matmul(
@@ -97,10 +94,12 @@ def matmul(
axes: NotImplementedType = None,
axis: NotImplementedType = None,
):
tensors = _ufunc_preprocess(
(x1, x2), True, casting, order, dtype, subok, signature, extobj
)
result = _binary_ufuncs_impl.matmul(*tensors)

if dtype is None:
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

result = _binary_ufuncs_impl.matmul(x1, x2)

result = _ufunc_postprocess(result, out, casting)
return result
@@ -140,11 +139,11 @@ def divmod(
else:
out1, out2 = out

tensors = _ufunc_preprocess(
(x1, x2), True, casting, order, dtype, subok, signature, extobj
)
if dtype is None:
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

quot, rem = _binary_ufuncs_impl.divmod(*tensors)
quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

quot = _ufunc_postprocess(quot, out1, casting)
rem = _ufunc_postprocess(rem, out2, casting)
15 changes: 8 additions & 7 deletions torch_np/_util.py
@@ -182,6 +182,10 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
Coerce to this torch dtype
copy : bool
Copy or not
ndmin : int
        The result has at least this many dimensions
is_weak : bool
Whether obj is a weakly typed python scalar.

Returns
-------
@@ -198,14 +202,11 @@
tensor = obj
else:
tensor = torch.as_tensor(obj)
base = None

# At this point, `tensor.dtype` is the pytorch default. Our default may
# differ, so need to typecast. However, we cannot just do `tensor.to`,
# because if our desired dtype is wider then pytorch's, `tensor`
# may have lost precision:

# int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
# tensor.dtype is the pytorch default, typically float32. If obj's elements
# are not exactly representable in float32, we've lost precision:
# >>> torch.as_tensor(1e12).item() - 1e12
# -4096.0

# Therefore, we treat `tensor.dtype` as a hint, and convert the
# original object *again*, this time with an explicit dtype.
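The pitfall the rewritten comment describes is easy to reproduce directly (assuming torch's default dtype, float32):

import torch

# a large python float silently loses precision at the default dtype:
t = torch.as_tensor(1e12)
print(t.dtype)            # torch.float32
print(t.item() - 1e12)    # -4096.0

# re-converting the original object with an explicit wider dtype avoids this:
t64 = torch.as_tensor(1e12, dtype=torch.float64)
print(t64.item() - 1e12)  # 0.0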
29 changes: 16 additions & 13 deletions torch_np/tests/numpy_tests/core/test_scalarmath.py
@@ -481,6 +481,22 @@ def test_numpy_scalar_relational_operators(self):
assert_(not np.array(1, dtype=dt1)[()] < np.array(0, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

#Signed integers and floats
for dt1 in 'bhl' + np.typecodes['Float']:
assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))

for dt2 in 'bhl' + np.typecodes['Float']:
assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

@pytest.mark.xfail(reason="NEP50")
def test_numpy_scalar_relational_operators_2(self):
#Unsigned integers
for dt1 in 'B':
assert_(-1 < np.array(1, dtype=dt1)[()], "type %s failed" % (dt1,))
Expand All @@ -496,19 +512,6 @@ def test_numpy_scalar_relational_operators(self):
assert_(np.array(1, dtype=dt1)[()] != np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

#Signed integers and floats
for dt1 in 'bhl' + np.typecodes['Float']:
assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))

for dt2 in 'bhl' + np.typecodes['Float']:
assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

def test_scalar_comparison_to_none(self):
# Scalars should just return False and not give a warnings.
36 changes: 24 additions & 12 deletions torch_np/tests/test_nep50_examples.py
@@ -1,6 +1,18 @@
"""Test examples for NEP 50."""

from torch_np import array, float32, float64, inf, int64, uint8
import torch_np as tnp
from torch_np import (
array,
bool_,
complex64,
complex128,
float32,
float64,
inf,
int16,
int64,
uint8,
)
from torch_np.testing import assert_allclose

uint16 = uint8 # can be anything here, see below
@@ -36,24 +48,24 @@
"uint8(1) + 300": (int64(301), Exception),
"uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning
"float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7]
# "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50?
"array([0.1], float32) == 0.1": (
array([False]),
unchanged,
), # XXX: a typo in NEP50?
"array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
"array([1.], float32) + 3": (array([4.0], float32), unchanged),
"array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
# additional examples from the NEP text
"int16(2) + 2": (int64(4), int16(4)),
"int16(4) + 4j": (complex128(4 + 4j), unchanged),
"float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)),
"bool_(True) + 1": (int64(2), unchanged),
"True + uint8(2)": (uint8(3), unchanged),
}


fails = [
"uint8(1) + 2",
"array([1], uint8) + 1",
"array([1], uint8) + 200",
"array([1], uint8) + array(1, int64)",
"array([100], uint8) + 200",
"array([1], uint8) + 300",
"uint8(1) + 300",
"uint8(100) + 200",
"float32(1) + 3e100",
"array([1.], float32) + 3",
"array([0.1], float32) == 0.1", # TODO: fix the example
]


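For context, a hypothetical runner for the examples table above (the actual parametrized test lives outside this hunk; examples, fails, and unchanged are module-level names in test_nep50_examples.py, so this sketch assumes it sits in that file):

import pytest
import torch_np as tnp
from torch_np.testing import assert_allclose

@pytest.mark.parametrize("example", examples)
def test_nep50_example(example):
    if example in fails:
        pytest.xfail(reason="NEP 50 corner case")
    old, new = examples[example]
    # `unchanged` marks entries where numpy and NEP 50 agree
    expected = old if new is unchanged else new
    if isinstance(expected, type) and issubclass(expected, Exception):
        with pytest.raises(expected):
            eval(example, dict(tnp.__dict__))
    else:
        assert_allclose(eval(example, dict(tnp.__dict__)), expected)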