Skip to content

Commit 19e96c2

Browse files
committed
MAINT: stop attaching things to Tensors, move the logic to binary ufuncs
1 parent 3656962 commit 19e96c2

File tree

6 files changed

+93
-140
lines changed

6 files changed

+93
-140
lines changed

torch_np/_dtypes_impl.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ def result_type_impl(*tensors):
5353

5454
# ### NEP 50 helpers ###
5555

56+
SCALAR_TYPES = (int, bool, float, complex)
57+
58+
59+
def _dtype_for_scalar(py_type):
60+
return {
61+
bool: torch.bool,
62+
int: torch.int64,
63+
float: torch.float64,
64+
complex: torch.complex128,
65+
}[py_type]
66+
67+
5668
categories = [
5769
(torch.bool,),
5870
(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
@@ -69,3 +81,37 @@ def category(dtyp):
6981

7082

7183
dtype_for_cat = {0: torch.bool, 1: torch.int64, 2: torch.float64, 3: torch.complex128}
84+
85+
86+
def nep50_to_tensors(x1, x2):
    """Apply NEP 50 promotion to the two operands of a binary ufunc.

    Python scalars (int/bool/float/complex) are "weakly typed"; tensors keep
    their dtype. Returns the pair as tensors, in the original argument order.
    Relies on module-level `category`, `dtype_for_cat` and `_dtype_for_scalar`.
    """
    x1_type, x2_type = type(x1), type(x2)
    x1_is_weak = x1_type in SCALAR_TYPES
    x2_is_weak = x2_type in SCALAR_TYPES
    if x1_is_weak and x2_is_weak:
        # two scalars: each gets its own default dtype
        x1 = torch.as_tensor(x1, dtype=_dtype_for_scalar(x1_type))
        x2 = torch.as_tensor(x2, dtype=_dtype_for_scalar(x2_type))
        return x1, x2
    elif not (x1_is_weak or x2_is_weak):
        # two tensors: nothing to do here
        return x1, x2
    else:
        # scalar <op> tensor: the NEP 50 weak-promotion rules apply
        weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)

        # find the dtype for the weak's type
        weak_dtype = _dtype_for_scalar(type(weak))

        cat_weak = category(weak_dtype)
        cat_not_weak = category(not_weak.dtype)

        # same-or-lower category: the tensor's dtype wins; otherwise promote
        # to the default dtype of the scalar's category.
        dt = not_weak.dtype if cat_weak <= cat_not_weak else dtype_for_cat[cat_weak]

        # special-case complex + float32: stay in single precision
        if weak_dtype.is_complex and not_weak.dtype == torch.float32:
            dt = torch.complex64

        # finally, cast `weak` into a 0D tensor of the chosen dtype
        weak = torch.as_tensor(weak, dtype=dt)

        return (weak, not_weak) if x1_is_weak else (not_weak, weak)

torch_np/_ndarray.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -453,17 +453,7 @@ def __dlpack_device__(self):
453453
# The rest goes through asarray (preferred) or array.
454454

455455

456-
def _array(
457-
obj,
458-
dtype=None,
459-
*,
460-
copy=True,
461-
order="K",
462-
subok=False,
463-
ndmin=0,
464-
like=None,
465-
is_weak=False,
466-
):
456+
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
467457
_util.subok_not_ok(like, subok)
468458
if order != "K":
469459
raise NotImplementedError
@@ -496,35 +486,12 @@ def _array(
496486
if dtype is not None:
497487
torch_dtype = _dtypes.dtype(dtype).torch_dtype
498488

499-
tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin, is_weak)
489+
tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
500490
return ndarray(tensor)
501491

502492

503-
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
504-
# The result of the public `np.array(obj)` is not weakly typed.
505-
return _array(
506-
obj,
507-
dtype,
508-
copy=copy,
509-
order=order,
510-
subok=subok,
511-
ndmin=ndmin,
512-
like=like,
513-
is_weak=False,
514-
)
515-
516-
517-
def _asarray(a, dtype=None, order="K", *, like=None, is_weak=False):
518-
return _array(
519-
a, dtype=dtype, order=order, like=like, copy=False, ndmin=0, is_weak=is_weak
520-
)
521-
522-
523493
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert `a` to an ndarray, avoiding a copy when possible.

    Thin wrapper over `array` with copy=False and ndmin=0; the result of the
    public `np.asarray(obj)` is not weakly typed.
    """
    array_kwargs = dict(dtype=dtype, order=order, like=like, copy=False, ndmin=0)
    return array(a, **array_kwargs)
528495

529496

530497
def from_dlpack(x, /):

torch_np/_normalizations.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
import torch
1111

12-
from . import _dtypes, _util
12+
from . import _dtypes, _dtypes_impl, _util
1313

1414
# Sentinel annotation types: the `normalizer` machinery dispatches on the
# *string* form of these annotations (see the `normalizers` dict below).
ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
# `ArrayLike | Scalar` annotations let binary ufuncs receive Python scalars
# untouched, for NEP 50 weak-scalar promotion downstream.
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
# NOTE(review): the TypeVar name string is "NDarray" while the variable is
# NDArray — looks like a typo, but harmless since only the variable is used.
NDArray = typing.TypeVar("NDarray")
@@ -38,15 +41,15 @@
3841

3942

4043
def normalize_array_like(x, parm=None):
    """Normalize an array-like argument into a bare torch.Tensor."""
    # Local import — presumably breaks an import cycle with _ndarray; confirm.
    from ._ndarray import asarray

    tensor = asarray(x).tensor
    return tensor
def normalize_array_like_or_scalar(x, parm=None):
    """Like normalize_array_like, but let Python scalars pass through as-is.

    Scalars stay un-tensorified here so the binary ufuncs can apply NEP 50
    weak-scalar promotion later.
    """
    is_weak_scalar = type(x) in _dtypes_impl.SCALAR_TYPES
    return x if is_weak_scalar else normalize_array_like(x, parm)
5053

5154

5255
def normalize_optional_array_like(x, parm=None):
@@ -115,6 +118,7 @@ def normalize_casting(arg, parm=None):
115118

116119
normalizers = {
117120
"ArrayLike": normalize_array_like,
121+
"ArrayLike | Scalar": normalize_array_like_or_scalar,
118122
"Optional[ArrayLike]": normalize_optional_array_like,
119123
"Sequence[ArrayLike]": normalize_seq_array_like,
120124
"Optional[NDArray]": normalize_ndarray,

torch_np/_ufuncs.py

Lines changed: 22 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -11,48 +11,11 @@
1111
DTypeLike,
1212
NotImplementedType,
1313
OutArray,
14+
Scalar,
1415
normalizer,
1516
)
1617

1718

18-
def _ufunc_preprocess(
19-
tensors, where, casting, order, dtype, subok, signature, extobj, scalars=False
20-
):
21-
22-
if scalars:
23-
# if one of the original inputs is a weak scalar, activate the NEP 50 dance
24-
# XXX: this is only needed for binops
25-
x1, x2 = tensors
26-
x1_is_weak = getattr(x1, "is_weakly_typed", False)
27-
x2_is_weak = getattr(x2, "is_weakly_typed", False)
28-
if x1_is_weak != x2_is_weak:
29-
# scalar <op> array: NEP50; nothing to do otherwise
30-
weak, non_weak = (x1, x2) if x1_is_weak else (x2, x1)
31-
32-
cat_weak = _dtypes_impl.category(weak.dtype)
33-
cat_non_weak = _dtypes_impl.category(non_weak.dtype)
34-
35-
dt_weak = (
36-
non_weak.dtype
37-
if cat_weak <= cat_non_weak
38-
else _dtypes_impl.dtype_for_cat[cat_weak]
39-
)
40-
41-
# special-case complex + float32
42-
if weak.dtype.is_complex and non_weak.dtype == torch.float32:
43-
dt_weak = torch.complex64
44-
45-
weak = _util.cast_if_needed(weak, dt_weak)
46-
tensors = (weak, non_weak) if x1_is_weak else (non_weak, weak)
47-
48-
if dtype is None:
49-
dtype = _dtypes_impl.result_type_impl(*tensors)
50-
51-
tensors = _util.typecast_tensors(tensors, dtype, casting)
52-
53-
return tensors
54-
55-
5619
def _ufunc_postprocess(result, out, casting):
5720
if out is not None:
5821
result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
@@ -78,8 +41,8 @@ def deco_binary_ufunc(torch_func):
7841

7942
@normalizer
8043
def wrapped(
81-
x1: ArrayLike,
82-
x2: ArrayLike,
44+
x1: ArrayLike | Scalar,
45+
x2: ArrayLike | Scalar,
8346
/,
8447
out: Optional[OutArray] = None,
8548
*,
@@ -91,21 +54,16 @@ def wrapped(
9154
signature=None,
9255
extobj=None,
9356
):
94-
tensors = _ufunc_preprocess(
95-
(x1, x2),
96-
where,
97-
casting,
98-
order,
99-
dtype,
100-
subok,
101-
signature,
102-
extobj,
103-
scalars=True,
104-
)
105-
result = torch_func(*tensors)
10657

107-
result = _ufunc_postprocess(result, out, casting)
58+
x1, x2 = _dtypes_impl.nep50_to_tensors(x1, x2)
59+
60+
if dtype is None:
61+
dtype = _dtypes_impl.result_type_impl(x1, x2)
62+
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
10863

64+
result = torch_func(x1, x2)
65+
66+
result = _ufunc_postprocess(result, out, casting)
10967
return result
11068

11169
wrapped.__qualname__ = torch_func.__name__
@@ -118,6 +76,7 @@ def wrapped(
11876
# matmul's signature is _slightly_ different from other ufuncs:
11977
# - no where=...
12078
# - additional axis=..., axes=...
79+
# - no NEP50 scalars in or out
12180
#
12281
@normalizer
12382
def matmul(
@@ -135,10 +94,12 @@ def matmul(
13594
axes: NotImplementedType = None,
13695
axis: NotImplementedType = None,
13796
):
138-
tensors = _ufunc_preprocess(
139-
(x1, x2), True, casting, order, dtype, subok, signature, extobj
140-
)
141-
result = _binary_ufuncs_impl.matmul(*tensors)
97+
98+
if dtype is None:
99+
dtype = _dtypes_impl.result_type_impl(x1, x2)
100+
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
101+
102+
result = _binary_ufuncs_impl.matmul(x1, x2)
142103

143104
result = _ufunc_postprocess(result, out, casting)
144105
return result
@@ -178,11 +139,11 @@ def divmod(
178139
else:
179140
out1, out2 = out
180141

181-
tensors = _ufunc_preprocess(
182-
(x1, x2), True, casting, order, dtype, subok, signature, extobj
183-
)
142+
if dtype is None:
143+
dtype = _dtypes_impl.result_type_impl(x1, x2)
144+
x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
184145

185-
quot, rem = _binary_ufuncs_impl.divmod(*tensors)
146+
quot, rem = _binary_ufuncs_impl.divmod(x1, x2)
186147

187148
quot = _ufunc_postprocess(quot, out1, casting)
188149
rem = _ufunc_postprocess(rem, out2, casting)

torch_np/_util.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,7 @@ def typecast_tensors(tensors, target_dtype, casting):
171171
return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)
172172

173173

174-
def _dtype_for_scalar(py_type):
175-
return {bool: torch.bool, int: torch.int64, float: torch.float64, complex: torch.complex128}[py_type]
176-
177-
178-
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0, is_weak=False):
174+
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
179175
"""The core logic of the array(...) function.
180176
181177
Parameters
@@ -205,22 +201,17 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0, is_weak=False):
205201
if isinstance(obj, torch.Tensor):
206202
tensor = obj
207203
else:
208-
if is_weak:
209-
# obj is a python scalar
210-
dtype = dtype or _dtype_for_scalar(obj_type)
211-
tensor = torch.as_tensor(obj, dtype=dtype)
212-
else:
213-
tensor = torch.as_tensor(obj)
204+
tensor = torch.as_tensor(obj)
214205

215-
# tensor.dtype is the pytorch default, typically float32. If obj's elements
216-
# are not exactly representable in float32, we've lost precision:
217-
# >>> torch.as_tensor(1e12).item() - 1e12
218-
# -4096.0
206+
# tensor.dtype is the pytorch default, typically float32. If obj's elements
207+
# are not exactly representable in float32, we've lost precision:
208+
# >>> torch.as_tensor(1e12).item() - 1e12
209+
# -4096.0
219210

220-
# Therefore, we treat `tensor.dtype` as a hint, and convert the
221-
# original object *again*, this time with an explicit dtype.
222-
torch_dtype = _dtypes_impl.get_default_dtype_for(tensor.dtype)
223-
tensor = torch.as_tensor(obj, dtype=torch_dtype)
211+
# Therefore, we treat `tensor.dtype` as a hint, and convert the
212+
# original object *again*, this time with an explicit dtype.
213+
torch_dtype = _dtypes_impl.get_default_dtype_for(tensor.dtype)
214+
tensor = torch.as_tensor(obj, dtype=torch_dtype)
224215

225216
# type cast if requested
226217
tensor = cast_if_needed(tensor, dtype)
@@ -234,6 +225,4 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0, is_weak=False):
234225
if copy:
235226
tensor = tensor.clone()
236227

237-
# Attach the flag *to the tensor* (will be used after normalizations)
238-
tensor.is_weakly_typed = is_weak
239228
return tensor

torch_np/tests/test_nep50_examples.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,3 @@ def test_nep50_exceptions(example):
9191

9292
assert_allclose(result, new, atol=1e-16)
9393
assert result.dtype == new.dtype
94-
95-
96-
class TestScalarsWeakTyping:
97-
def test_asarray_scalars(self):
98-
assert tnp.asarray(3).tensor.is_weakly_typed is False
99-
100-
def test_asarray_asarray_scalars(self):
101-
a = tnp.asarray(3)
102-
assert tnp.asarray(a).tensor.is_weakly_typed is False
103-
104-
def test_scalar_scalar(self):
105-
a = tnp.uint8(3)
106-
is_weakly_typed = getattr(a.tensor, "is_weakly_typed", False)
107-
assert is_weakly_typed is False

0 commit comments

Comments
 (0)