
Commit 0d117e0

ev-brlezcano and Mario Lezcano Casado authored
ENH: introduce NEP 50 "weak scalars" (#140)
Co-authored-by: Mario Lezcano Casado <[email protected]>
1 parent 133c367 commit 0d117e0
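
In NEP 50 terms, a bare Python scalar is "weak": in a mixed scalar/tensor operation it adapts to the tensor operand's dtype instead of forcing a wider result type, and a Python int that does not fit the tensor's integer dtype raises OverflowError instead of silently wrapping. A minimal sketch of the semantics this commit targets (the `torch_np` import name and dtype spellings are assumptions, not part of the diff):

    import torch_np as np

    t = np.asarray([10.0], dtype=np.float32)
    (t + 2.0).dtype          # float32: the weak Python float adapts to the tensor

    u = np.asarray([10], dtype=np.uint8)
    (u + 2).dtype            # uint8: the weak Python int adapts as well
    u + 300                  # OverflowError: 300 is out of bounds for uint8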

File tree

8 files changed: +333 −65 lines changed


torch_np/_dtypes_impl.py

Lines changed: 87 additions & 0 deletions
@@ -49,3 +49,90 @@ def result_type_impl(*tensors):
         dtyp = _cd._result_type_dict[dtyp][curr.dtype]

     return dtyp
+
+
+# ### NEP 50 helpers ###
+
+SCALAR_TYPES = {int, bool, float, complex}
+
+
+def _dtype_for_scalar(py_type):
+    return {
+        bool: torch.bool,
+        int: torch.int64,
+        float: torch.float64,
+        complex: torch.complex128,
+    }[py_type]
+
+
+def _category(dtype):
+    return {
+        torch.bool: 0,
+        # int
+        torch.uint8: 1,
+        torch.int8: 1,
+        torch.int16: 1,
+        torch.int32: 1,
+        torch.int64: 1,
+        # float
+        torch.float16: 2,
+        torch.float32: 2,
+        torch.float64: 2,
+        # complex
+        torch.complex64: 3,
+        torch.complex128: 3,
+    }[dtype]
+
+
+def nep50_to_tensors(x1, x2, handle_weaks):
+    """If either of the inputs is a python scalar, type-promote with NEP 50."""
+
+    def to_tensor(scalar, dtype=None):
+        if dtype is None:
+            dtype = _dtype_for_scalar(type(scalar))
+            dtype = get_default_dtype_for(dtype)
+        return torch.as_tensor(scalar, dtype=dtype)
+
+    x1_is_weak = not isinstance(x1, torch.Tensor)
+    x2_is_weak = not isinstance(x2, torch.Tensor)
+    if not handle_weaks or (x1_is_weak and x2_is_weak):
+        x1 = to_tensor(x1) if x1_is_weak else x1
+        x2 = to_tensor(x2) if x2_is_weak else x2
+        return x1, x2
+
+    # scalar <op> tensor: NEP 50
+    assert x1_is_weak != x2_is_weak
+
+    weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)
+
+    # find the dtype for the weak's type
+    weak_dtype = _dtype_for_scalar(type(weak))
+
+    cat_weak = _category(weak_dtype)
+    cat_not_weak = _category(not_weak.dtype)
+
+    dt = not_weak.dtype if cat_weak <= cat_not_weak else None
+
+    # special-case complex + float32
+    if weak_dtype.is_complex and not_weak.dtype == torch.float32:
+        dt = torch.complex64
+
+    # detect overflows: in PyTorch, uint8(-1) wraps around to 255,
+    # while NEP 50 mandates an exception.
+    #
+    # Note that we only check whether each operand of the binop overflows,
+    # not the result. Consider, e.g., `uint8(100) + 200`: both operands fit
+    # in uint8, but the result overflows and wraps around to 255.
+    # NumPy emits a RuntimeWarning here; PyTorch does not, and neither do we.
+    if cat_weak == 1 and cat_not_weak == 1:
+        # integers
+        iinfo = torch.iinfo(not_weak.dtype)
+        if not (iinfo.min <= weak <= iinfo.max):
+            raise OverflowError(
+                f"Python integer {weak} out of bounds for {not_weak.dtype}"
+            )
+
+    # finally, can make `weak` into a 0D tensor
+    weak = to_tensor(weak, dt)
+
+    return (weak, not_weak) if x1_is_weak else (not_weak, weak)
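
To make the promotion rules above concrete, here is a small illustrative sketch of how `nep50_to_tensors` behaves, assuming the module is importable as `torch_np._dtypes_impl`:

    import torch
    from torch_np import _dtypes_impl

    x = torch.ones(3, dtype=torch.float32)

    # weak float vs. float32 tensor: same category, so the tensor dtype wins
    w, t = _dtypes_impl.nep50_to_tensors(2.0, x, True)
    print(w.dtype)   # torch.float32

    # weak complex vs. float32 tensor: the special case gives complex64
    w, t = _dtypes_impl.nep50_to_tensors(1j, x, True)
    print(w.dtype)   # torch.complex64

    # out-of-range python int vs. uint8 tensor: OverflowError, per NEP 50
    u = torch.ones(3, dtype=torch.uint8)
    _dtypes_impl.nep50_to_tensors(300, u, True)   # raises OverflowError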

torch_np/_normalizations.py

Lines changed: 11 additions & 1 deletion
@@ -9,9 +9,12 @@

 import torch

-from . import _dtypes, _util
+from . import _dtypes, _dtypes_impl, _util

 ArrayLike = typing.TypeVar("ArrayLike")
+Scalar = typing.Union[int, float, complex, bool]
+ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]
+
 DTypeLike = typing.TypeVar("DTypeLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
@@ -43,6 +46,12 @@ def normalize_array_like(x, parm=None):
     return asarray(x).tensor


+def normalize_array_like_or_scalar(x, parm=None):
+    if type(x) in _dtypes_impl.SCALAR_TYPES:
+        return x
+    return normalize_array_like(x, parm)
+
+
 def normalize_optional_array_like(x, parm=None):
     # This explicit normalizer is needed because otherwise normalize_array_like
     # does not run for a parameter annotated as Optional[ArrayLike]
@@ -109,6 +118,7 @@ def normalize_casting(arg, parm=None):

 normalizers = {
     "ArrayLike": normalize_array_like,
+    "Union[ArrayLike, Scalar]": normalize_array_like_or_scalar,
     "Optional[ArrayLike]": normalize_optional_array_like,
     "Sequence[ArrayLike]": normalize_seq_array_like,
     "Optional[NDArray]": normalize_ndarray,

torch_np/_ufuncs.py

Lines changed: 66 additions & 26 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Optional
+from typing import Optional, Union

 import torch

@@ -11,19 +11,11 @@
     DTypeLike,
     NotImplementedType,
     OutArray,
+    Scalar,
     normalizer,
 )


-def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
-    if dtype is None:
-        dtype = _dtypes_impl.result_type_impl(*tensors)
-
-    tensors = _util.typecast_tensors(tensors, dtype, casting)
-
-    return tensors
-
-
 def _ufunc_postprocess(result, out, casting):
     if out is not None:
         result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
@@ -40,6 +32,36 @@ def _ufunc_postprocess(result, out, casting):
 ]


+NEP50_FUNCS = (
+    "add",
+    "subtract",
+    "multiply",
+    "floor_divide",
+    "true_divide",
+    "divide",
+    "remainder",
+    "bitwise_and",
+    "bitwise_or",
+    "bitwise_xor",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "hypot",
+    "arctan2",
+    "logaddexp",
+    "logaddexp2",
+    "heaviside",
+    "copysign",
+    "fmax",
+    "minimum",
+    "fmin",
+    "maximum",
+    "fmod",
+    "gcd",
+    "lcm",
+    "pow",
+)
+
+
 def deco_binary_ufunc(torch_func):
     """Common infra for binary ufuncs.

@@ -49,8 +71,8 @@ def deco_binary_ufunc(torch_func):

     @normalizer
     def wrapped(
-        x1: ArrayLike,
-        x2: ArrayLike,
+        x1: Union[ArrayLike, Scalar],
+        x2: Union[ArrayLike, Scalar],
         /,
         out: Optional[OutArray] = None,
         *,
@@ -62,13 +84,28 @@ def wrapped(
         signature=None,
         extobj=None,
     ):
-        tensors = _ufunc_preprocess(
-            (x1, x2), where, casting, order, dtype, subok, signature, extobj
-        )
-        result = torch_func(*tensors)

-        result = _ufunc_postprocess(result, out, casting)
-        return result
+        if dtype is not None:
+
+            def cast(x, dtype):
+                if isinstance(x, torch.Tensor):
+                    return _util.typecast_tensors((x,), dtype, casting)[0]
+                else:
+                    return torch.as_tensor(x, dtype=dtype)
+
+            x1 = cast(x1, dtype)
+            x2 = cast(x2, dtype)
+        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
+            dtype = _dtypes_impl.result_type_impl(x1, x2)
+            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
+        else:
+            x1, x2 = _dtypes_impl.nep50_to_tensors(
+                x1, x2, torch_func.__name__ in NEP50_FUNCS
+            )
+
+        result = torch_func(x1, x2)
+
+        return _ufunc_postprocess(result, out, casting)

     wrapped.__qualname__ = torch_func.__name__
     wrapped.__name__ = torch_func.__name__
@@ -80,6 +117,7 @@ def wrapped(
 # matmul's signature is _slightly_ different from other ufuncs:
 # - no where=...
 # - additional axis=..., axes=...
+# - no NEP50 scalars in or out
 #
 @normalizer
 def matmul(
@@ -97,10 +135,12 @@ def matmul(
     axes: NotImplementedType = None,
     axis: NotImplementedType = None,
 ):
-    tensors = _ufunc_preprocess(
-        (x1, x2), True, casting, order, dtype, subok, signature, extobj
-    )
-    result = _binary_ufuncs_impl.matmul(*tensors)
+
+    if dtype is None:
+        dtype = _dtypes_impl.result_type_impl(x1, x2)
+    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
+
+    result = _binary_ufuncs_impl.matmul(x1, x2)

     result = _ufunc_postprocess(result, out, casting)
     return result
@@ -140,11 +180,11 @@ def divmod(
     else:
         out1, out2 = out

-    tensors = _ufunc_preprocess(
-        (x1, x2), True, casting, order, dtype, subok, signature, extobj
-    )
+    if dtype is None:
+        dtype = _dtypes_impl.result_type_impl(x1, x2)
+    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

-    quot, rem = _binary_ufuncs_impl.divmod(*tensors)
+    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

     quot = _ufunc_postprocess(quot, out1, casting)
     rem = _ufunc_postprocess(rem, out2, casting)
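
The rewritten `wrapped` thus picks one of three promotion paths: an explicit `dtype=` casts both operands outright, two tensors go through `result_type_impl`, and a scalar/tensor mix goes through `nep50_to_tensors`. A sketch of the visible effect (import name and dtype spellings assumed, as before):

    import torch_np as np

    t = np.asarray([1.0, 2.0], dtype=np.float32)

    (t + t).dtype                             # float32: tensor <op> tensor
    (t + 1.0).dtype                           # float32: weak scalar adapts (NEP 50)
    np.add(t, 1.0, dtype=np.float64).dtype    # float64: explicit dtype wins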

torch_np/_util.py

Lines changed: 8 additions & 7 deletions
@@ -175,6 +175,10 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
         Coerce to this torch dtype
     copy : bool
         Copy or not
+    ndmin : int
+        The result has at least this many dimensions
+    is_weak : bool
+        Whether obj is a weakly typed python scalar.

     Returns
     -------
@@ -191,14 +195,11 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
         tensor = obj
     else:
         tensor = torch.as_tensor(obj)
-        base = None

-    # At this point, `tensor.dtype` is the pytorch default. Our default may
-    # differ, so need to typecast. However, we cannot just do `tensor.to`,
-    # because if our desired dtype is wider then pytorch's, `tensor`
-    # may have lost precision:
-
-    # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
+    # tensor.dtype is the pytorch default, typically float32. If obj's elements
+    # are not exactly representable in float32, we've lost precision:
+    # >>> torch.as_tensor(1e12).item() - 1e12
+    # -4096.0

     # Therefore, we treat `tensor.dtype` as a hint, and convert the
     # original object *again*, this time with an explicit dtype.
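
The rewritten comment documents a real pitfall: converting through the float32 default first and only then widening would bake the rounding error in. The figure quoted in the diff is easy to reproduce:

    import torch

    # 1e12 is not exactly representable in float32, the pytorch default:
    torch.as_tensor(1e12).item() - 1e12                        # -4096.0

    # converting the original object with an explicit wide dtype is lossless:
    torch.as_tensor(1e12, dtype=torch.float64).item() - 1e12   # 0.0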

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 15 additions & 13 deletions
@@ -481,6 +481,21 @@ def test_numpy_scalar_relational_operators(self):
             assert_(not np.array(1, dtype=dt1)[()] < np.array(0, dtype=dt2)[()],
                 "type %s and %s failed" % (dt1, dt2))

+        #Signed integers and floats
+        for dt1 in 'bhl' + np.typecodes['Float']:
+            assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
+            assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
+            assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
+
+            for dt2 in 'bhl' + np.typecodes['Float']:
+                assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
+                    "type %s and %s failed" % (dt1, dt2))
+                assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
+                    "type %s and %s failed" % (dt1, dt2))
+                assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
+                    "type %s and %s failed" % (dt1, dt2))
+
+    def test_numpy_scalar_relational_operators_2(self):
         #Unsigned integers
         for dt1 in 'B':
             assert_(-1 < np.array(1, dtype=dt1)[()], "type %s failed" % (dt1,))
@@ -496,19 +511,6 @@ def test_numpy_scalar_relational_operators(self):
             assert_(np.array(1, dtype=dt1)[()] != np.array(-1, dtype=dt2)[()],
                 "type %s and %s failed" % (dt1, dt2))

-        #Signed integers and floats
-        for dt1 in 'bhl' + np.typecodes['Float']:
-            assert_(1 > np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
-            assert_(not 1 < np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
-            assert_(-1 == np.array(-1, dtype=dt1)[()], "type %s failed" % (dt1,))
-
-            for dt2 in 'bhl' + np.typecodes['Float']:
-                assert_(np.array(1, dtype=dt1)[()] > np.array(-1, dtype=dt2)[()],
-                    "type %s and %s failed" % (dt1, dt2))
-                assert_(not np.array(1, dtype=dt1)[()] < np.array(-1, dtype=dt2)[()],
-                    "type %s and %s failed" % (dt1, dt2))
-                assert_(np.array(-1, dtype=dt1)[()] == np.array(-1, dtype=dt2)[()],
-                    "type %s and %s failed" % (dt1, dt2))

     def test_scalar_comparison_to_none(self):
         # Scalars should just return False and not give a warnings.

torch_np/tests/numpy_tests/fft/test_pocketfft.py

Lines changed: 2 additions & 2 deletions
@@ -44,8 +44,8 @@ def test_fft(self):

        np.random.seed(1234)
        x = random(30) + 1j*random(30)
-       assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5)
-       assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5)
+       assert_allclose(fft1(x), np.fft.fft(x), atol=3e-5)
+       assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=3e-5)
        assert_allclose(fft1(x) / np.sqrt(30),
                        np.fft.fft(x, norm="ortho"), atol=5e-6)
        assert_allclose(fft1(x) / 30.,
