ENH: introduce NEP 50 "weak scalars" #140

Merged: 17 commits, May 19, 2023
87 changes: 87 additions & 0 deletions torch_np/_dtypes_impl.py
@@ -49,3 +49,90 @@ def result_type_impl(*tensors):
dtyp = _cd._result_type_dict[dtyp][curr.dtype]

return dtyp


# ### NEP 50 helpers ###

SCALAR_TYPES = {int, bool, float, complex}


def _dtype_for_scalar(py_type):
return {
bool: torch.bool,
int: torch.int64,
float: torch.float64,
complex: torch.complex128,
}[py_type]


def _category(dtype):
return {
torch.bool: 0,
# int
torch.uint8: 1,
torch.int8: 1,
torch.int16: 1,
torch.int32: 1,
torch.int64: 1,
# float
torch.float16: 2,
torch.float32: 2,
torch.float64: 2,
# complex
torch.complex64: 3,
torch.complex128: 3,
}[dtype]


def nep50_to_tensors(x1, x2, handle_weaks):
"""If either of inputs is a python scalar, type-promote with NEP 50."""

def to_tensor(scalar, dtype=None):
if dtype is None:
dtype = _dtype_for_scalar(type(scalar))
dtype = get_default_dtype_for(dtype)
return torch.as_tensor(scalar, dtype=dtype)

x1_is_weak = not isinstance(x1, torch.Tensor)
x2_is_weak = not isinstance(x2, torch.Tensor)
if not handle_weaks or (x1_is_weak and x2_is_weak):
x1 = to_tensor(x1) if x1_is_weak else x1
x2 = to_tensor(x2) if x2_is_weak else x2
return x1, x2

# scalar <op> tensor: NEP 50
assert x1_is_weak != x2_is_weak

weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)

# find the dtype corresponding to the weak scalar's Python type
weak_dtype = _dtype_for_scalar(type(weak))

cat_weak = _category(weak_dtype)
cat_not_weak = _category(not_weak.dtype)

dt = not_weak.dtype if cat_weak <= cat_not_weak else None

# special-case complex + float32
if weak_dtype.is_complex and not_weak.dtype == torch.float32:
dt = torch.complex64

# detect overflows: in PyTorch, uint8(-1) wraps around to 255,
# while NEP 50 mandates an exception.
#
# Note that we only check whether each operand of the binop fits the
# target dtype, not whether the result does. Consider, e.g. `uint8(100) + 200`:
# both operands fit in uint8, but the result overflows and wraps around.
# NumPy emits a RuntimeWarning here; PyTorch does not, and neither do we.
if cat_weak == 1 and cat_not_weak == 1:
# integers
iinfo = torch.iinfo(not_weak.dtype)
if not (iinfo.min <= weak <= iinfo.max):
raise OverflowError(
f"Python integer {weak} out of bounds for {not_weak.dtype}"
)

# finally, can make `weak` into a 0D tensor
weak = to_tensor(weak, dt)

return (weak, not_weak) if x1_is_weak else (not_weak, weak)
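
For context, a minimal sketch of how nep50_to_tensors behaves, assuming the module imports as below; torch_np's default dtypes are configurable, so the float dtype shown is the float64 default:

import torch
from torch_np._dtypes_impl import nep50_to_tensors

t = torch.ones(3, dtype=torch.uint8)

# a weak Python int adopts the tensor's dtype (same category)
_, two = nep50_to_tensors(t, 2, True)
print(two.dtype)  # torch.uint8

# a weak Python float outranks the integer category, so it becomes
# the default float dtype instead
_, half = nep50_to_tensors(t, 0.5, True)
print(half.dtype)  # torch.float64 under the default configuration

# an out-of-range Python int raises instead of wrapping around
try:
    nep50_to_tensors(t, 300, True)
except OverflowError as exc:
    print(exc)  # Python integer 300 out of bounds for torch.uint8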
12 changes: 11 additions & 1 deletion torch_np/_normalizations.py
@@ -9,9 +9,12 @@

import torch

from . import _dtypes, _util
from . import _dtypes, _dtypes_impl, _util

ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDarray")
@@ -43,6 +46,12 @@ def normalize_array_like(x, parm=None):
return asarray(x).tensor


def normalize_array_like_or_scalar(x, parm=None):
if type(x) in _dtypes_impl.SCALAR_TYPES:
return x
return normalize_array_like(x, parm)


def normalize_optional_array_like(x, parm=None):
# This explicit normalizer is needed because otherwise normalize_array_like
# does not run for a parameter annotated as Optional[ArrayLike]
@@ -109,6 +118,7 @@ def normalize_casting(arg, parm=None):

normalizers = {
"ArrayLike": normalize_array_like,
"Union[ArrayLike, Scalar]": normalize_array_like_or_scalar,
"Optional[ArrayLike]": normalize_optional_array_like,
"Sequence[ArrayLike]": normalize_seq_array_like,
"Optional[NDArray]": normalize_ndarray,
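For reference, the new pass-through sketched outside the normalizer machinery; normalize_array_like_or_scalar_demo is a hypothetical stand-in, and only SCALAR_TYPES comes from this diff:

import torch
from torch_np._dtypes_impl import SCALAR_TYPES

def normalize_array_like_or_scalar_demo(x):
    # Python bool/int/float/complex pass through unchanged, staying "weak" ...
    if type(x) in SCALAR_TYPES:
        return x
    # ... while anything array-like is coerced to a tensor
    return torch.as_tensor(x)

print(normalize_array_like_or_scalar_demo(3))       # 3, still a weak Python int
print(normalize_array_like_or_scalar_demo([1, 2]))  # tensor([1, 2])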
92 changes: 66 additions & 26 deletions torch_np/_ufuncs.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from typing import Optional, Union

import torch

@@ -11,19 +11,11 @@
DTypeLike,
NotImplementedType,
OutArray,
Scalar,
normalizer,
)


def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
if dtype is None:
dtype = _dtypes_impl.result_type_impl(*tensors)

tensors = _util.typecast_tensors(tensors, dtype, casting)

return tensors


def _ufunc_postprocess(result, out, casting):
if out is not None:
result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
@@ -40,6 +32,36 @@ def _ufunc_postprocess(result, out, casting):
]


NEP50_FUNCS = (
"add",
"subtract",
"multiply",
"floor_divide",
"true_divide",
"divide",
"remainder",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"hypot",
"arctan2",
"logaddexp",
"logaddexp2",
"heaviside",
"copysign",
"fmax",
"minimum",
"fmin",
"maximum",
"fmod",
"gcd",
"lcm",
"pow",
)


def deco_binary_ufunc(torch_func):
"""Common infra for binary ufuncs.

@@ -49,8 +71,8 @@ def deco_binary_ufunc(torch_func):

@normalizer
def wrapped(
x1: ArrayLike,
x2: ArrayLike,
x1: Union[ArrayLike, Scalar],
x2: Union[ArrayLike, Scalar],
/,
out: Optional[OutArray] = None,
*,
@@ -62,13 +84,28 @@ def wrapped(
signature=None,
extobj=None,
):
tensors = _ufunc_preprocess(
(x1, x2), where, casting, order, dtype, subok, signature, extobj
)
result = torch_func(*tensors)

result = _ufunc_postprocess(result, out, casting)
return result
if dtype is not None:

def cast(x, dtype):
if isinstance(x, torch.Tensor):
return _util.typecast_tensors((x,), dtype, casting)[0]
else:
return torch.as_tensor(x, dtype=dtype)

x1 = cast(x1, dtype)
x2 = cast(x2, dtype)
elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
else:
x1, x2 = _dtypes_impl.nep50_to_tensors(
x1, x2, torch_func.__name__ in NEP50_FUNCS
)

result = torch_func(x1, x2)

return _ufunc_postprocess(result, out, casting)

wrapped.__qualname__ = torch_func.__name__
wrapped.__name__ = torch_func.__name__
@@ -80,6 +117,7 @@ def wrapped(
# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
#
@normalizer
def matmul(
@@ -97,10 +135,12 @@ def matmul(
axes: NotImplementedType = None,
axis: NotImplementedType = None,
):
tensors = _ufunc_preprocess(
(x1, x2), True, casting, order, dtype, subok, signature, extobj
)
result = _binary_ufuncs_impl.matmul(*tensors)

if dtype is None:
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

result = _binary_ufuncs_impl.matmul(x1, x2)

result = _ufunc_postprocess(result, out, casting)
return result
@@ -140,11 +180,11 @@ def divmod(
else:
out1, out2 = out

tensors = _ufunc_preprocess(
(x1, x2), True, casting, order, dtype, subok, signature, extobj
)
if dtype is None:
dtype = _dtypes_impl.result_type_impl(x1, x2)
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

quot, rem = _binary_ufuncs_impl.divmod(*tensors)
quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

quot = _ufunc_postprocess(quot, out1, casting)
rem = _ufunc_postprocess(rem, out2, casting)
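Taken together, the decorated binary ufuncs now pick one of three paths: an explicit dtype= argument wins outright, two tensors go through result_type_impl as before, and a tensor/scalar pair is handled by nep50_to_tensors. A hypothetical user-level session, assuming torch_np is used as a NumPy drop-in with its float64 default:

import torch_np as np

a = np.arange(3, dtype=np.uint8)
print(np.add(a, 2).dtype)    # uint8: the weak Python int adopts a's dtype
print(np.add(a, 2.0).dtype)  # float64: the float category outranks integers
np.add(a, 300)               # raises OverflowError: 300 is out of bounds for uint8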
15 changes: 8 additions & 7 deletions torch_np/_util.py
@@ -182,6 +182,10 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
Coerce to this torch dtype
copy : bool
Copy or not
ndmin : int
The result has at least this many dimensions
is_weak : bool
Whether obj is a weakly typed Python scalar.

Returns
-------
@@ -198,14 +202,11 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
tensor = obj
else:
tensor = torch.as_tensor(obj)
base = None

# At this point, `tensor.dtype` is the pytorch default. Our default may
# differ, so need to typecast. However, we cannot just do `tensor.to`,
# because if our desired dtype is wider then pytorch's, `tensor`
# may have lost precision:

# int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
# tensor.dtype is the pytorch default, typically float32. If obj's elements
# are not exactly representable in float32, we've lost precision:
# >>> torch.as_tensor(1e12).item() - 1e12
# -4096.0

# Therefore, we treat `tensor.dtype` as a hint, and convert the
# original object *again*, this time with an explicit dtype.
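A standalone illustration of the precision pitfall this comment documents, assuming stock PyTorch with its float32 default; the numbers match the docstring example:

import torch

obj = 1e12
hint = torch.as_tensor(obj)  # float32 by default; nearest representable value is 4096 away
print(hint.item() - obj)     # -4096.0

# the fix described above: convert the original object again, with an explicit dtype
exact = torch.as_tensor(obj, dtype=torch.float64)
print(exact.item() - obj)    # 0.0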
28 changes: 15 additions & 13 deletions torch_np/tests/numpy_tests/core/test_scalarmath.py
@@ -481,6 +481,21 @@ def test_numpy_scalar_relational_operators(self):
assert_(not np.array(1, dtype=dt1)[()] < np.array(0, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

#Signed integers and floats
for dt1 in 'bhl' + np.typecodes['Float']:
assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))

for dt2 in 'bhl' + np.typecodes['Float']:
assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

def test_numpy_scalar_relational_operators_2(self):
#Unsigned integers
for dt1 in 'B':
assert_(-1 < np.array(1, dtype=dt1)[()], "type %s failed" % (dt1,))
@@ -496,19 +511,6 @@ def test_numpy_scalar_relational_operators(self):
assert_(np.array(1, dtype=dt1)[()] != np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

#Signed integers and floats
for dt1 in 'bhl' + np.typecodes['Float']:
assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))

for dt2 in 'bhl' + np.typecodes['Float']:
assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))
assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
"type %s and %s failed" % (dt1, dt2))

def test_scalar_comparison_to_none(self):
# Scalars should just return False and not give a warnings.
4 changes: 2 additions & 2 deletions torch_np/tests/numpy_tests/fft/test_pocketfft.py
@@ -44,8 +44,8 @@ def test_fft(self):

np.random.seed(1234)
x = random(30) + 1j*random(30)
assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5)
assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5)
assert_allclose(fft1(x), np.fft.fft(x), atol=3e-5)
assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=3e-5)
assert_allclose(fft1(x) / np.sqrt(30),
np.fft.fft(x, norm="ortho"), atol=5e-6)
assert_allclose(fft1(x) / 30.,