
Commit dfc1db5

MAINT: rework result_type_impl to accept *tensors, not dtypes (#139)
1 parent 789ec48 commit dfc1db5

9 files changed, +61 −60 lines changed

torch_np/_binary_ufuncs_impl.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def matmul(x, y):
     # - RuntimeError: expected scalar type Int but found Double
     # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
     # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
-    dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
+    dtype = _dtypes_impl.result_type_impl(x, y)
     is_bool = dtype == torch.bool
     is_half = dtype == torch.float16
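The promoted dtype is now computed from the tensors directly; the comments above explain why matmul then dodges dtypes that addmm lacks CPU kernels for. For illustration, a minimal sketch of that cast-up-and-back pattern in plain torch, with torch.promote_types standing in for torch_np's promotion table; the function name and the fallback dtypes are assumptions, not the repo's actual choices:

    import torch

    def matmul_sketch(x, y):
        # Promote the inputs -- plays the role of result_type_impl(x, y).
        dtype = torch.promote_types(x.dtype, y.dtype)

        # Work around missing CPU kernels (the RuntimeErrors quoted above)
        # by computing in a wider dtype; these fallbacks are illustrative.
        work_dtype = dtype
        if dtype == torch.bool:
            work_dtype = torch.uint8
        elif dtype == torch.float16:
            work_dtype = torch.float32

        result = torch.matmul(x.to(work_dtype), y.to(work_dtype))
        return result.to(dtype)  # cast back to the promoted dtype

    a = torch.ones(2, 2, dtype=torch.float16)
    print(matmul_sketch(a, a).dtype)  # torch.float16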

torch_np/_dtypes.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@

 import torch

-from . import _dtypes_impl
+from . import _dtypes_impl, _util

 # more __all__ manipulations at the bottom
 __all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]

@@ -34,7 +34,7 @@ def __new__(self, value):
             tensor = value.tensor
         else:
             try:
-                tensor = torch.as_tensor(value, dtype=self.torch_dtype)
+                tensor = _util._coerce_to_tensor(value, dtype=self.torch_dtype)
             except RuntimeError as e:
                 if "Overflow" in str(e):
                     raise OverflowError(e.args)

torch_np/_dtypes_impl.py

Lines changed: 5 additions & 5 deletions
@@ -39,13 +39,13 @@ def can_cast_impl(from_torch_dtype, to_torch_dtype, casting):
     return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype]


-def result_type_impl(dtypes):
+def result_type_impl(*tensors):
     # NB: torch dtypes here
-    dtyp = dtypes[0]
-    if len(dtypes) == 1:
+    dtyp = tensors[0].dtype
+    if len(tensors) == 1:
         return dtyp

-    for curr in dtypes[1:]:
-        dtyp = _cd._result_type_dict[dtyp][curr]
+    for curr in tensors[1:]:
+        dtyp = _cd._result_type_dict[dtyp][curr.dtype]

     return dtyp
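This is the heart of the change: the helper now takes any number of tensors and chains pairwise lookups in _result_type_dict over their dtypes. A minimal usage sketch, assuming a checkout where torch_np is importable:

    import torch
    from torch_np import _dtypes_impl

    a = torch.ones(3, dtype=torch.int8)
    b = torch.ones(3, dtype=torch.float32)

    # Old call style: result_type_impl((a.dtype, b.dtype))
    # New call style: pass the tensors; .dtype is read internally.
    dt = _dtypes_impl.result_type_impl(a, b)
    print(dt)  # torch.float32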

torch_np/_funcs_impl.py

Lines changed: 15 additions & 18 deletions
@@ -91,7 +91,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
         # figure out the type of the inputs and outputs
         out_dtype = out.dtype.torch_dtype if dtype is None else dtype
     else:
-        out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+        out_dtype = _dtypes_impl.result_type_impl(*tensors)

     # cast input arrays if necessary; do not broadcast them agains `out`
     tensors = _util.typecast_tensors(tensors, out_dtype, casting)

@@ -354,9 +354,11 @@ def arange(
     # the dtype of the result
     if dtype is None:
         dtype = _dtypes_impl.default_dtypes.int_dtype
-    dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
-    dt_list.append(dtype)
-    target_dtype = _dtypes_impl.result_type_impl(dt_list)
+    # XXX: default values do not get normalized
+    start, stop, step = (_util._coerce_to_tensor(x) for x in (start, stop, step))
+
+    dummy = torch.empty(1, dtype=dtype)
+    target_dtype = _dtypes_impl.result_type_impl(start, stop, step, dummy)

     # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
     work_dtype = torch.float64 if target_dtype.is_complex else target_dtype

@@ -571,7 +573,7 @@ def cov(


 def _conv_corr_impl(a, v, mode):
-    dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
+    dt = _dtypes_impl.result_type_impl(a, v)
     a = _util.cast_if_needed(a, dt)
     v = _util.cast_if_needed(v, dt)

@@ -857,15 +859,14 @@ def nanpercentile():


 def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)
-    result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
-    return result
+    return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


 def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)
     return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

@@ -1175,7 +1176,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
     if t_b.ndim > 1:
         t_b = t_b.flatten()

-    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
+    dtype = _dtypes_impl.result_type_impl(t_a, t_b)
     is_half = dtype == torch.float16
     is_bool = dtype == torch.bool

@@ -1202,15 +1203,15 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]

-    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    target_dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, target_dtype)
     b = _util.cast_if_needed(b, target_dtype)

     return torch.tensordot(a, b, dims=axes)


 def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)

@@ -1222,7 +1223,7 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


 def inner(a: ArrayLike, b: ArrayLike, /):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     is_half = dtype == torch.float16
     is_bool = dtype == torch.bool

@@ -1284,11 +1285,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
     subscripts, array_operands = operands[0], operands[1:]

     tensors = [normalize_array_like(op) for op in array_operands]
-    target_dtype = (
-        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
-        if dtype is None
-        else dtype
-    )
+    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

     # work around 'bmm' not implemented for 'Half' etc
     is_half = target_dtype == torch.float16
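Since the helper no longer accepts bare dtypes, arange above smuggles the requested dtype into promotion as a one-element dummy tensor. The trick in isolation, using stock torch with torch.promote_types standing in for torch_np's promotion table:

    import torch

    # arange(0, 10.0, 1) with a requested dtype of int64:
    start, stop, step = (torch.as_tensor(x) for x in (0, 10.0, 1))
    dummy = torch.empty(1, dtype=torch.int64)  # carries the dtype as a tensor

    # Chained pairwise promotion over all four tensors:
    target = start.dtype
    for t in (stop, step, dummy):
        target = torch.promote_types(target, t.dtype)
    print(target)  # torch.float32 -- the float `stop` wins over the int64 dummy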

torch_np/_ndarray.py

Lines changed: 8 additions & 5 deletions
@@ -518,11 +518,14 @@ def can_cast(from_, to, casting="safe"):


 def result_type(*arrays_and_dtypes):
-    dtypes = []
-
+    tensors = []
     for entry in arrays_and_dtypes:
-        dty = _extract_dtype(entry)
-        dtypes.append(dty.torch_dtype)
+        try:
+            t = asarray(entry).tensor
+        except ((RuntimeError, ValueError, TypeError)):
+            dty = _dtypes.dtype(entry)
+            t = torch.empty(1, dtype=dty.torch_dtype)
+        tensors.append(t)

-    torch_dtype = _dtypes_impl.result_type_impl(dtypes)
+    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
     return _dtypes.dtype(torch_dtype)
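At the user-facing level, result_type still accepts a mix of arrays and dtype specifiers: anything asarray rejects is assumed to be a dtype and replaced by a one-element dummy tensor. Illustrative usage, assuming torch_np is importable and exposes the usual NumPy-style top-level names:

    import torch_np as np

    a = np.ones(3, dtype="int8")
    b = np.ones(3, dtype="float32")

    print(np.result_type(a, b))        # float32: both arguments coerce via asarray
    print(np.result_type(a, "int32"))  # int32: the string takes the dummy-tensor path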

torch_np/_reductions.py

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ def average(
         weights = weights.swapaxes(-1, axis)

     # do the work
-    result_dtype = _dtypes_impl.result_type_impl([a.dtype, weights.dtype])
+    result_dtype = _dtypes_impl.result_type_impl(a, weights)
     numerator = sum(a * weights, axis, dtype=result_dtype)
     wsum = sum(weights, axis, dtype=result_dtype)
     result = numerator / wsum

torch_np/_ufuncs.py

Lines changed: 5 additions & 6 deletions
@@ -17,7 +17,7 @@

 def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
     if dtype is None:
-        dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+        dtype = _dtypes_impl.result_type_impl(*tensors)

     tensors = _util.typecast_tensors(tensors, dtype, casting)

@@ -26,7 +26,7 @@ def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):

 def _ufunc_postprocess(result, out, casting):
     if out is not None:
-        (result,) = _util.typecast_tensors((result,), out.dtype.torch_dtype, casting)
+        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
         result = torch.broadcast_to(result, out.shape)
     return result

@@ -198,10 +198,9 @@ def wrapped(
         signature=None,
         extobj=None,
     ):
-        tensors = _ufunc_preprocess(
-            (x,), where, casting, order, dtype, subok, signature, extobj
-        )
-        result = torch_func(*tensors)
+        if dtype is not None:
+            x = _util.typecast_tensor(x, dtype, casting)
+        result = torch_func(x)
         result = _ufunc_postprocess(result, out, casting)
         return result
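With a single input there is nothing to promote against, so the unary wrapper drops _ufunc_preprocess entirely: promotion is skipped and only an explicit dtype= triggers a cast. A reduced sketch of the resulting wrapper shape; the factory name is hypothetical and a plain .to() stands in for the checked _util.typecast_tensor:

    import torch

    def make_unary_ufunc(torch_func):
        # Reduced version of the `wrapped` closure above: no promotion
        # step, just an optional cast before calling the torch function.
        def wrapped(x, dtype=None, casting="same_kind"):
            if dtype is not None:
                x = x.to(dtype)  # stand-in for _util.typecast_tensor
            return torch_func(x)
        return wrapped

    sin = make_unary_ufunc(torch.sin)
    x = torch.tensor([0, 1], dtype=torch.int64)
    print(sin(x, dtype=torch.float64))  # cast first, then apply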

torch_np/_util.py

Lines changed: 23 additions & 21 deletions
@@ -135,37 +135,40 @@ def axis_none_flatten(*tensors, axis=None):
     return tensors, axis


-def typecast_tensors(tensors, target_dtype, casting):
-    """Dtype-cast tensors to target_dtype.
+def typecast_tensor(t, target_dtype, casting):
+    """Dtype-cast tensor to target_dtype.

     Parameters
     ----------
-    tensors : iterable
-        tuple or list of torch.Tensors to typecast
-    target_dtype : torch dtype object, optional
+    t : torch.Tensor
+        The tensor to cast
+    target_dtype : torch dtype object
         The array dtype to cast all tensors to
     casting : str
         The casting mode, see `np.can_cast`

-    Returns
-    -------
-    a tuple of torch.Tensors with dtype being the PyTorch counterpart
-    of the `target_dtype`
+    Returns
+    -------
+    `torch.Tensor` of the `target_dtype` dtype
+
+    Raises
+    ------
+    ValueError
+        if the argument cannot be cast according to the `casting` rule
+
     """
-    # check if we can dtype-cast all arguments
-    cast_tensors = []
     can_cast = _dtypes_impl.can_cast_impl

-    for tensor in tensors:
-        if not can_cast(tensor.dtype, target_dtype, casting=casting):
-            raise TypeError(
-                f"Cannot cast array data from {tensor.dtype} to"
-                f" {target_dtype} according to the rule '{casting}'"
-            )
-        tensor = cast_if_needed(tensor, target_dtype)
-        cast_tensors.append(tensor)
+    if not can_cast(t.dtype, target_dtype, casting=casting):
+        raise TypeError(
+            f"Cannot cast array data from {t.dtype} to"
+            f" {target_dtype} according to the rule '{casting}'"
+        )
+    return cast_if_needed(t, target_dtype)

-    return tuple(cast_tensors)

+def typecast_tensors(tensors, target_dtype, casting):
+    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)


 def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):

@@ -193,7 +196,6 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
     """
     if isinstance(obj, torch.Tensor):
         tensor = obj
-        base = None
     else:
         tensor = torch.as_tensor(obj)
         base = None
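The refactor makes typecast_tensor the single-tensor primitive and typecast_tensors a one-liner over it. Its contract is easy to demonstrate with numpy's casting rules standing in for can_cast_impl (the real code consults torch_np's own casting table; the function name here is hypothetical):

    import numpy as np
    import torch

    def typecast_tensor_sketch(t, target_dtype, casting):
        # np.can_cast stands in for _dtypes_impl.can_cast_impl here.
        name = str(t.dtype).replace("torch.", "")
        if not np.can_cast(name, target_dtype, casting=casting):
            raise TypeError(
                f"Cannot cast array data from {t.dtype} to"
                f" {target_dtype} according to the rule '{casting}'"
            )
        return t.to(getattr(torch, target_dtype))

    t = torch.ones(3, dtype=torch.float64)
    print(typecast_tensor_sketch(t, "float32", "same_kind").dtype)  # torch.float32
    typecast_tensor_sketch(t, "int32", "safe")  # raises TypeError: 'safe' forbids it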

torch_np/linalg.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def _atleast_float_1(a):


 def _atleast_float_2(a, b):
-    dtyp = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtyp = _dtypes_impl.result_type_impl(a, b)
     if not (dtyp.is_floating_point or dtyp.is_complex):
         dtyp = _dtypes_impl.default_dtypes.float_dtype
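linalg's helper follows the same pattern: promote the pair, then force a floating (or complex) dtype, since the LAPACK-backed routines need one. A sketch in plain torch; the function name is hypothetical and the float64 default is an assumption mirroring torch_np's NumPy-style default float dtype:

    import torch

    def atleast_float_2_sketch(a, b, default_float=torch.float64):
        dt = torch.promote_types(a.dtype, b.dtype)  # stands in for result_type_impl(a, b)
        if not (dt.is_floating_point or dt.is_complex):
            dt = default_float  # integer/bool inputs get upgraded for linalg
        return a.to(dt), b.to(dt)

    a = torch.ones(2, 2, dtype=torch.int32)
    b = torch.ones(2, 2, dtype=torch.int64)
    a, b = atleast_float_2_sketch(a, b)
    print(a.dtype, b.dtype)  # torch.float64 torch.float64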
