Commit 0bc1374

ENH: dtype=... ufunc argument
The implementation is a bit simpler than numpy's: we do not have the notion of ufunc loop types (`np.add.types` etc.), so we simply cast the input tensors to `result_type(dtype, out.dtype)` and ask pytorch to do the computation in that dtype.
1 parent e532a07 commit 0bc1374
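
As a quick illustration of the rule described above (a sketch, not code from this commit; `torch.promote_types` stands in for this repo's `_dtypes_impl.result_type_impl`):

import torch

# How the computation dtype is chosen, per the commit message:
#   dtype=None, out=None  -> no forced dtype (usual type promotion applies)
#   dtype=None, out given -> out.dtype
#   dtype given, out=None -> dtype
#   both given            -> result_type(dtype, out.dtype)
assert torch.promote_types(torch.float32, torch.float64) == torch.float64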

File tree

4 files changed: +78 -90 lines changed


torch_np/_detail/_util.py

Lines changed: 1 addition & 47 deletions
@@ -137,7 +137,7 @@ def axis_none_ravel(*tensors, axis=None):
     return tensors, axis


-def cast_dont_broadcast(tensors, target_dtype, casting):
+def typecast_tensors(tensors, target_dtype, casting):
     """Dtype-cast tensors to target_dtype.

     Parameters
@@ -170,52 +170,6 @@ def cast_dont_broadcast(tensors, target_dtype, casting):
     return tuple(cast_tensors)


-def cast_and_broadcast(tensors, out_param, casting):
-    """
-    Parameters
-    ----------
-    tensors : iterable
-        tuple or list of torch.Tensors to broadcast/typecast
-    target_dtype : a torch.dtype object
-        The torch dtype to cast all tensors to
-    target_shape : tuple
-        The tensor shape to broadcast all `tensors` to
-    casting : str
-        The casting mode, see `np.can_cast`
-
-    Returns
-    -------
-    a tuple of torch.Tensors with dtype being the PyTorch counterpart
-    of the `target_dtype` and `target_shape`
-    """
-    if out_param is None:
-        return tensors
-
-    target_dtype, target_shape = out_param
-
-    can_cast = _dtypes_impl.can_cast_impl
-
-    processed_tensors = []
-    for tensor in tensors:
-        # check dtypes of x and out
-        if not can_cast(tensor.dtype, target_dtype, casting=casting):
-            raise TypeError(
-                f"Cannot cast array data from {tensor.dtype} to"
-                f" {target_dtype} according to the rule '{casting}'"
-            )
-
-        # cast arr if needed
-        tensor = cast_if_needed(tensor, target_dtype)
-
-        # `out` broadcasts `tensor`
-        if tensor.shape != target_shape:
-            tensor = torch.broadcast_to(tensor, target_shape)
-
-        processed_tensors.append(tensor)
-
-    return tuple(processed_tensors)
-
-
 def axis_expand_func(func, tensor, axis, *args, **kwds):
     """Generically handle axis arguments in reductions."""
     if axis is not None:
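
The body of typecast_tensors is not shown in this hunk; judging from its docstring and from the casting loop visible in the deleted cast_and_broadcast, it plausibly reduces to a cast-only loop along these lines (a hedged sketch: np.can_cast stands in for the repo's can_cast_impl, and _np_dtype is a hypothetical torch-to-numpy dtype bridge):

import numpy as np
import torch

def _np_dtype(torch_dtype):
    # hypothetical bridge for the sketch: torch.float32 -> np.dtype("float32")
    return np.dtype(str(torch_dtype).removeprefix("torch."))

def typecast_tensors_sketch(tensors, target_dtype, casting):
    cast_tensors = []
    for tensor in tensors:
        # reject casts that the numpy casting rule forbids
        if not np.can_cast(_np_dtype(tensor.dtype), _np_dtype(target_dtype),
                           casting=casting):
            raise TypeError(
                f"Cannot cast array data from {tensor.dtype} to"
                f" {target_dtype} according to the rule '{casting}'"
            )
        # cast only; broadcasting now happens in the caller
        cast_tensors.append(tensor.to(target_dtype))
    return tuple(cast_tensors)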

torch_np/_detail/implementations.py

Lines changed: 2 additions & 2 deletions
@@ -308,7 +308,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
         out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])

     # cast input arrays if necessary; do not broadcast them agains `out`
-    tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
+    tensors = _util.typecast_tensors(tensors, out_dtype, casting)

     return tensors

@@ -497,7 +497,7 @@ def bincount(x, /, weights=None, minlength=0):
         x = x.new_empty(0, dtype=int)

     int_dtype = _dtypes_impl.default_int_dtype
-    (x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
+    (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe")

     result = torch.bincount(x, weights, minlength)
     return result
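
A note on the casting="safe" argument above: under numpy's casting rules, "safe" admits widening integer casts but rejects float-to-integer ones, so bincount on float input raises TypeError rather than silently truncating. For instance:

import numpy as np

assert np.can_cast(np.int32, np.int64, casting="safe")
assert not np.can_cast(np.float64, np.int64, casting="safe")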

torch_np/_helpers.py

Lines changed: 22 additions & 41 deletions
@@ -1,57 +1,38 @@
 import torch

-from . import _dtypes
-from ._detail import _util
-
-
-def cast_and_broadcast(tensors, out, casting):
-    """Cast dtypes of arrays to out.dtype and broadcast if needed.
-
-    Parameters
-    ----------
-    arrays : sequence of arrays
-        Each element is broadcast against `out` and typecast to out.dtype
-    out : the "output" array
-        Not modified.
-    casting : str
-        One of numpy casting modes
-
-    Returns
-    -------
-    tensors : tuple of Tensors
-        Each tensor is dtype-cast and broadcast agains `out`, as needed
-
-    Notes
-    -----
-    The `out` arrays broadcasts and dtype-casts `arrays`, but not vice versa.
-
-    """
-    if out is None:
-        return tensors
-    else:
-        tensors = _util.cast_and_broadcast(
-            tensors, out.dtype.type.torch_dtype, out.shape, casting
-        )
-
-        return tuple(tensors)
+from ._detail import _dtypes_impl, _util


 def ufunc_preprocess(
     tensors, out, where, casting, order, dtype, subok, signature, extobj
 ):
+    """
+    Notes
+    -----
+    The `out` array broadcasts `tensors`, but not vice versa.
+    """
     # internal preprocessing or args in ufuncs (cf _unary_ufuncs, _binary_ufuncs)
     if order != "K" or not where or signature or extobj:
         raise NotImplementedError

-    # XXX: dtype=... parameter
-    if dtype is not None:
-        raise NotImplementedError
+    # dtype of the result: depends on both dtype=... and out=... arguments
+    if dtype is None:
+        out_dtype = None if out is None else out.dtype.torch_dtype
+    else:
+        out_dtype = (
+            dtype
+            if out is None
+            else _dtypes_impl.result_type_impl([dtype, out.dtype.torch_dtype])
+        )

-    out_shape_dtype = None
-    if out is not None:
-        out_shape_dtype = (out.get().dtype, out.get().shape)
+    if out_dtype:
+        tensors = _util.typecast_tensors(tensors, out_dtype, casting)

-    tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)
+    # now broadcast input tensors against the out=... array
+    if out is not None:
+        # XXX: need to filter out noop broadcasts if t.shape == out.shape?
+        shape = out.shape
+        tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)

     return tensors
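
Taken together, the new ufunc_preprocess makes two passes over the inputs: typecast to the resolved dtype, then broadcast against out. A rough end-to-end sketch of that flow (an illustration under stated assumptions, not the committed code; plain torch calls stand in for the repo's wrappers, and preprocess_sketch is a hypothetical name):

import torch

def preprocess_sketch(tensors, out, dtype):
    # resolve the computation dtype, mirroring the branch in the diff above
    if dtype is None:
        out_dtype = None if out is None else out.dtype
    else:
        out_dtype = dtype if out is None else torch.promote_types(dtype, out.dtype)

    # pass 1: dtype-cast only (the committed code routes through
    # _util.typecast_tensors, which also enforces the casting rule)
    if out_dtype is not None:
        tensors = tuple(t.to(out_dtype) for t in tensors)

    # pass 2: broadcast every input against the shape of the out= array
    if out is not None:
        tensors = tuple(torch.broadcast_to(t, out.shape) for t in tensors)
    return tensors

out = torch.empty(2, 3, dtype=torch.float64)
a, b = torch.ones(3, dtype=torch.float32), torch.ones(2, 1)
a2, b2 = preprocess_sketch((a, b), out, dtype=torch.float32)
assert a2.dtype == b2.dtype == torch.float64 and a2.shape == (2, 3)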

torch_np/tests/test_ufuncs_basic.py

Lines changed: 53 additions & 0 deletions
@@ -371,3 +371,56 @@ def test_other_array_bcast(self, ufunc, op, iop):
         if result_op.dtype != result_ufunc.dtype:
             pytest.xfail(reason="prob need weak type promotion (scalars)")
         assert result_op.dtype == result_ufunc.dtype
+
+
+class TestUfuncDtypeKwd:
+    def test_binary_ufunc_dtype(self):
+
+        # default computation uses float64:
+        r64 = np.add(1, 1e-15)
+        assert r64.dtype == "float64"
+        assert r64 - 1 > 0
+
+        # force the float32 dtype: loss of precision
+        r32 = np.add(1, 1e-15, dtype="float32")
+        assert r32.dtype == "float32"
+        assert r32 == 1
+
+        # casting of floating inputs to booleans
+        with assert_raises(TypeError):
+            np.add(1.0, 1e-15, dtype=bool)
+
+        # now force the cast
+        rb = np.add(1.0, 1e-15, dtype=bool, casting="unsafe")
+        assert rb.dtype == bool
+
+    def test_binary_ufunc_dtype_and_out(self):
+
+        # all in float64: no precision loss
+        out64 = np.empty(2, dtype=np.float64)
+        r64 = np.add([1.0, 2.0], 1.0e-15, out=out64)
+
+        assert (r64 != [1.0, 2.0]).all()
+        assert r64.dtype == np.float64
+
+        # all in float32: loss of precision, result is float32
+        out32 = np.empty(2, dtype=np.float32)
+        r32 = np.add([1.0, 2.0], 1.0e-15, dtype=np.float32, out=out32)
+        assert (r32 == [1, 2]).all()
+        assert r32.dtype == np.float32
+
+        # NB: this test differs from numpy: in numpy, r.dtype is float64
+        # but the precision is lost, r == [1, 2].
+        # I *guess* numpy casts inputs to the dtype=... value, performs calculations,
+        # and then casts the result back to out.dtype.
+        out64 = np.empty(2, dtype=np.float64)
+        r = np.add([1.0, 2.0], 1.0e-15, dtype=np.float32, out=out64)
+        assert (r != [1, 2]).all()
+        assert r.dtype == np.float64
+
+        # Internal computations are in float64, but the final cast to out.dtype
+        # truncates the precision => precision loss.
+        out32 = np.empty(2, dtype=np.float32)
+        r = np.add([1.0, 2.0], 1.0e-15, dtype=np.float64, out=out32)
+        assert (r == [1, 2]).all()
+        assert r.dtype == np.float32
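
Why these tests probe with 1e-15: it sits between the machine epsilons of the two float types, so 1 + 1e-15 survives rounding in float64 but collapses back to 1.0 in float32. A standalone check of that arithmetic:

import numpy as np

# 1e-15 lies between float64 eps (~2.2e-16) and float32 eps (~1.2e-7)
assert np.finfo(np.float64).eps < 1e-15 < np.finfo(np.float32).eps
assert np.float64(1) + np.float64(1e-15) > 1    # representable in float64
assert np.float32(1) + np.float32(1e-15) == 1   # rounds away in float32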
