
Commit 773d772

Merge pull request #107 from Quansight-Labs/dtypes_in_ufuncs
BUG: fix dtype handling in ufuncs
2 parents 4aa60de + 99bdf8c
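In effect, `dtype=...` now selects the working dtype of the computation, and `out=...` only receives a cast of the finished result. A minimal sketch of the behavior this change targets, adapted from the test updated below (`torch_np` imported as `np`, matching the test suite's convention):

import torch_np as np

out64 = np.empty(2, dtype=np.float64)
r = np.add([1.0, 2.0], 1.0e-15, dtype=np.float32, out=out64)

# dtype=float32 is the working dtype, so the 1e-15 term is lost to float32
# precision; only the finished result is cast to out's float64.
assert (r == [1, 2]).all()
assert r.dtype == np.float64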

File tree

4 files changed: +48 -63 lines changed

torch_np/_helpers.py

Lines changed: 0 additions & 27 deletions
@@ -3,33 +3,6 @@
 from ._detail import _dtypes_impl, _util
 
 
-def ufunc_preprocess(
-    tensors, out, where, casting, order, dtype, subok, signature, extobj
-):
-    """
-    Notes
-    -----
-    The `out` array broadcasts `tensors`, but not vice versa.
-    """
-    # internal preprocessing or args in ufuncs (cf _unary_ufuncs, _binary_ufuncs)
-    if order != "K" or not where or signature or extobj:
-        raise NotImplementedError
-
-    # dtype of the result: depends on both dtype=... and out=... arguments
-    if dtype is None:
-        out_dtype = None if out is None else out.dtype.torch_dtype
-    else:
-        out_dtype = (
-            dtype
-            if out is None
-            else _dtypes_impl.result_type_impl([dtype, out.dtype.torch_dtype])
-        )
-
-    if out_dtype:
-        tensors = _util.typecast_tensors(tensors, out_dtype, casting)
-    return tensors
-
-
 def ndarrays_to_tensors(*inputs):
     """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
     from ._ndarray import asarray, ndarray

torch_np/_ufuncs.py

Lines changed: 43 additions & 23 deletions
@@ -3,8 +3,29 @@
 import torch
 
 from . import _binary_ufuncs_impl, _helpers, _unary_ufuncs_impl
+from ._detail import _dtypes_impl, _util
 from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
 
+
+def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
+    if order != "K" or not where or signature or extobj:
+        raise NotImplementedError
+
+    if dtype is None:
+        dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+
+    tensors = _util.typecast_tensors(tensors, dtype, casting)
+
+    return tensors
+
+
+def _ufunc_postprocess(result, out, casting):
+    if out is not None:
+        (result,) = _util.typecast_tensors((result,), out.dtype.torch_dtype, casting)
+        result = torch.broadcast_to(result, out.shape)
+    return result
+
+
 # ############# Binary ufuncs ######################
 
 _binary = [
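Together, the two new helpers split every ufunc call into a dtype-resolution step and an out-handling step. A rough pure-torch analogue of the pipeline, for orientation only: `torch.result_type` stands in here for `_dtypes_impl.result_type_impl`, and the `casting`-rule checks done by `_util.typecast_tensors` are omitted.

import torch

def add_like_ufunc(x1, x2, out=None, dtype=None):
    if dtype is None:
        # preprocess: promote a working dtype from the inputs
        dtype = torch.result_type(x1, x2)
    x1, x2 = x1.to(dtype), x2.to(dtype)  # cast inputs up front
    result = torch.add(x1, x2)
    if out is not None:
        # postprocess: cast to out's dtype, then broadcast to out's shape
        result = torch.broadcast_to(result.to(out.dtype), out.shape)
    return result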
@@ -35,16 +56,12 @@ def wrapped(
         signature=None,
         extobj=None,
     ):
-        tensors = _helpers.ufunc_preprocess(
-            (x1, x2), out, where, casting, order, dtype, subok, signature, extobj
+        tensors = _ufunc_preprocess(
+            (x1, x2), where, casting, order, dtype, subok, signature, extobj
         )
-        # now broadcast input tensors against the out=... array
-        if out is not None:
-            # XXX: need to filter out noop broadcasts if t.shape == out.shape?
-            shape = out.shape
-            tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
-
         result = torch_func(*tensors)
+
+        result = _ufunc_postprocess(result, out, casting)
         return result
 
     wrapped.__qualname__ = torch_func.__name__
@@ -54,8 +71,9 @@ def wrapped(
 
 
 #
-# matmul is special in that its `out=...` array does not broadcast x1 and x2.
-# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
+# matmul's signature is _slightly_ different from other ufuncs:
+#  - no where=...
+#  - additional axis=..., axes=...
 #
 @normalizer
 def matmul(
@@ -73,17 +91,21 @@ def matmul(
     axes=None,
     axis=None,
 ):
-    tensors = _helpers.ufunc_preprocess(
-        (x1, x2), out, True, casting, order, dtype, subok, signature, extobj
+    tensors = _ufunc_preprocess(
+        (x1, x2), True, casting, order, dtype, subok, signature, extobj
     )
     if axis is not None or axes is not None:
         raise NotImplementedError
 
-    # NB: do not broadcast input tensors against the out=... array
     result = _binary_ufuncs_impl.matmul(*tensors)
+
+    result = _ufunc_postprocess(result, out, casting)
     return result
 
 
+#
+# nin=2, nout=2
+#
 def divmod(
     x1: ArrayLike,
     x2: ArrayLike,
@@ -110,12 +132,14 @@ def divmod(
     if out1.shape != out2.shape or out1.dtype != out2.dtype:
         raise ValueError("out1, out2 must be compatible")
 
-    tensors = _helpers.ufunc_preprocess(
-        (x1, x2), out, True, casting, order, dtype, subok, signature, extobj
+    tensors = _ufunc_preprocess(
+        (x1, x2), True, casting, order, dtype, subok, signature, extobj
     )
 
-    result = _binary_ufuncs_impl.divmod(*tensors)
+    quot, rem = _binary_ufuncs_impl.divmod(*tensors)
 
+    quot = _ufunc_postprocess(quot, out1, casting)
+    rem = _ufunc_postprocess(rem, out2, casting)
     return quot, rem
 
 
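With two outputs, quotient and remainder are now postprocessed independently, each against its own out array. The NumPy semantics this mirrors (plain NumPy shown; multi-output ufuncs take out as a tuple):

import numpy as np

out1 = np.empty(3, dtype=np.int64)
out2 = np.empty(3, dtype=np.int64)
quot, rem = np.divmod([5, 6, 7], 3, out=(out1, out2))
# nout=2: quotient and remainder land in their own out arrays
assert (quot == [1, 2, 2]).all() and (rem == [2, 0, 1]).all()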
@@ -167,15 +191,11 @@ def wrapped(
         signature=None,
         extobj=None,
     ):
-        tensors = _helpers.ufunc_preprocess(
-            (x,), out, where, casting, order, dtype, subok, signature, extobj
+        tensors = _ufunc_preprocess(
+            (x,), where, casting, order, dtype, subok, signature, extobj
         )
-        # now broadcast the input tensor against the out=... array
-        if out is not None:
-            # XXX: need to filter out noop broadcasts if t.shape == out.shape?
-            shape = out.shape
-            tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
         result = torch_func(*tensors)
+        result = _ufunc_postprocess(result, out, casting)
         return result
 
     wrapped.__qualname__ = torch_func.__name__
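The unary wrapper now routes through the same pre/postprocess pair as the binary one. A hedged usage sketch (`np.sin` and `np.empty` as exposed by `torch_np`):

import torch_np as np

out = np.empty(3)
r = np.sin([0.0, 1.0, 2.0], out=out)  # same cast-then-broadcast out handling
assert r.dtype == out.dtype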

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 2 additions & 2 deletions
@@ -2907,7 +2907,7 @@ def test_inplace(self):
         b = np.array([3])
         c = (a * a) / b
 
-        assert_almost_equal(c, 25 / 3)
+        assert_almost_equal(c, 25 / 3, decimal=5)
         assert_equal(a, 5)
         assert_equal(b, 3)
 
@@ -5577,7 +5577,7 @@ def test_empty_out(self):
         out = np.ones((1, 1, 1))
         assert self.matmul(arr, arr).shape == (0, 1, 1)
 
-        with pytest.raises(ValueError, match="Bad size of the out array"):  # match=r"non-broadcastable"):
+        with pytest.raises((RuntimeError, ValueError)):
             self.matmul(arr, arr, out=out)
 
     def test_out_contiguous(self):
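The loosened assertion presumably accounts for the torch backend surfacing a `RuntimeError` where NumPy raises a `ValueError` with a specific message; `pytest.raises` accepts a tuple of exception types, so either backend passes:

import pytest

with pytest.raises((RuntimeError, ValueError)):
    raise RuntimeError("either exception type satisfies the check")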

torch_np/tests/test_ufuncs_basic.py

Lines changed: 3 additions & 11 deletions
@@ -105,12 +105,6 @@ def test_x_and_out_broadcast(self, ufunc):
     (np.add, operator.__add__, operator.__iadd__),
     (np.subtract, operator.__sub__, operator.__isub__),
     (np.multiply, operator.__mul__, operator.__imul__),
-    (np.divide, operator.__truediv__, operator.__itruediv__),
-    (np.floor_divide, operator.__floordiv__, operator.__ifloordiv__),
-    (np.float_power, operator.__pow__, operator.__ipow__),
-    ## (np.remainder, operator.__mod__, operator.__imod__),  # does not handle complex
-    # remainder vs fmod?
-    # pow vs power vs float_power
 ]
 
 ufuncs_with_dunders = [ufunc for ufunc, _, _ in ufunc_op_iop_numeric]
@@ -409,13 +403,11 @@ def test_binary_ufunc_dtype_and_out(self):
         assert (r32 == [1, 2]).all()
         assert r32.dtype == np.float32
 
-        # NB: this test differs from numpy: in numpy, r.dtype is float64
-        # but the precision is lost, r == [1, 2].
-        # I *guess* numpy casts inputs to the dtype=... value, performs calculations,
-        # and then casts the result back to out.dtype.
+        # dtype is float32, so computation is in float32: precision loss
+        # the result is then cast to float64
         out64 = np.empty(2, dtype=np.float64)
         r = np.add([1.0, 2.0], 1.0e-15, dtype=np.float32, out=out64)
-        assert (r != [1, 2]).all()
+        assert (r == [1, 2]).all()
         assert r.dtype == np.float64
 
         # Internal computations are in float64, but the final cast to out.dtype
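The flipped assertion is plain float arithmetic: 1e-15 sits far below float32's machine epsilon (~1.2e-7) but above float64's (~2.2e-16), so the increment survives only when the computation runs in float64:

import numpy as np

assert np.float32(1.0) + np.float32(1.0e-15) == np.float32(1.0)  # lost in float32
assert np.float64(1.0) + np.float64(1.0e-15) != np.float64(1.0)  # visible in float64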
