Skip to content

Commit ec2a7be

Browse files
committed
MAINT: divmod + framework for out=(array, array)
1 parent eae5033 commit ec2a7be

File tree

7 files changed

+37
-89
lines changed

7 files changed

+37
-89
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55
from . import _helpers
66
from ._detail import _binary_ufuncs
7-
from ._normalizations import ArrayLike, DTypeLike, NDArray, OutArray, SubokLike, normalizer
7+
from ._normalizations import (
8+
ArrayLike,
9+
DTypeLike,
10+
NDArray,
11+
OutArray,
12+
SubokLike,
13+
normalizer,
14+
)
815

916
__all__ = [
1017
name
@@ -93,53 +100,6 @@ def matmul(
93100
vars()[name] = decorated
94101

95102

96-
# a stub implementation of divmod, should be improved after
97-
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
98-
#
99-
# Implementation details: we just call two ufuncs which have been created
100-
# just above, for x1 // x2 and x1 % x2.
101-
# This means we are normalizing x1, x2 in each of the ufuncs --- note that there
102-
# is no @normalizer on divmod.
103-
104-
105-
def divmod(
106-
x1,
107-
x2,
108-
/,
109-
out=None,
110-
*,
111-
where=True,
112-
casting="same_kind",
113-
order="K",
114-
dtype=None,
115-
subok: SubokLike = False,
116-
signature=None,
117-
extobj=None,
118-
):
119-
out1, out2 = None, None
120-
if out is not None:
121-
out1, out2 = out
122-
123-
kwds = dict(
124-
where=where,
125-
casting=casting,
126-
order=order,
127-
dtype=dtype,
128-
subok=subok,
129-
signature=signature,
130-
extobj=extobj,
131-
)
132-
133-
# NB: use local names for
134-
quot = floor_divide(x1, x2, out=out1, **kwds)
135-
rem = remainder(x1, x2, out=out2, **kwds)
136-
137-
quot = _helpers.result_or_out(quot.tensor, out1)
138-
rem = _helpers.result_or_out(rem.tensor, out2)
139-
140-
return quot, rem
141-
142-
143103
def modf(x, /, *args, **kwds):
144104
quot, rem = divmod(x, 1, *args, **kwds)
145105
return rem, quot

torch_np/_detail/_binary_ufuncs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,9 @@ def matmul(x, y):
7070
result = result.to(dtype)
7171

7272
return result
73+
74+
75+
# a stub implementation of divmod, should be improved after
76+
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
77+
def divmod(x, y):
78+
return x // y, x % y

torch_np/_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def cumprod(
449449
cumproduct = cumprod
450450

451451

452-
@normalizer
452+
@normalizer(promote_scalar_result=True)
453453
def quantile(
454454
a: ArrayLike,
455455
q: ArrayLike,

torch_np/_helpers.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,39 +30,6 @@ def ufunc_preprocess(
3030
return tensors
3131

3232

33-
# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###
34-
35-
36-
def result_or_out(result_tensor, out_array=None, promote_scalar=False):
37-
"""A helper for returns with out= argument.
38-
39-
If `promote_scalar is True`, then:
40-
if result_tensor.numel() == 1 and out is zero-dimensional,
41-
result_tensor is placed into the out array.
42-
This weirdness is used e.g. in `np.percentile`
43-
"""
44-
if out_array is not None:
45-
if result_tensor.shape != out_array.shape:
46-
can_fit = result_tensor.numel() == 1 and out_array.ndim == 0
47-
if promote_scalar and can_fit:
48-
result_tensor = result_tensor.squeeze()
49-
else:
50-
raise ValueError(
51-
f"Bad size of the out array: out.shape = {out_array.shape}"
52-
f" while result.shape = {result_tensor.shape}."
53-
)
54-
out_tensor = out_array.tensor
55-
out_tensor.copy_(result_tensor)
56-
return out_array
57-
else:
58-
from ._ndarray import ndarray
59-
60-
return ndarray(result_tensor)
61-
62-
63-
# ### Various ways of converting array-likes to tensors ###
64-
65-
6633
def ndarrays_to_tensors(*inputs):
6734
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
6835
from ._ndarray import asarray, ndarray

torch_np/_normalizations.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def maybe_normalize(arg, parm, return_on_failure=_sentinel):
9898
raise exc from None
9999

100100

101-
102101
# ### Return value helpers ###
103102

104103

@@ -123,7 +122,6 @@ def copy_to_out(result_tensor, out_array, promote_scalar_result=False):
123122
return out_array
124123

125124

126-
127125
def wrap_tensors(result, out, promote_scalar_result=False):
128126
from ._ndarray import ndarray
129127

@@ -133,9 +131,11 @@ def wrap_tensors(result, out, promote_scalar_result=False):
133131
else:
134132
return copy_to_out(result, out, promote_scalar_result)
135133
elif isinstance(result, (tuple, list)):
136-
return type(result)(
137-
ndarray(x) if isinstance(x, torch.Tensor) else x for x in result
138-
)
134+
seq = type(result)
135+
if out is None:
136+
return seq(wrap_tensors(x, out) for x in result)
137+
else:
138+
return seq(wrap_tensors(x, out=o) for x, o in zip(result, out))
139139

140140
return result
141141

@@ -151,6 +151,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
151151

152152
# ### The main decorator to normalize arguments / postprocess the output ###
153153

154+
154155
def normalizer(_func=None, *, return_on_failure=_sentinel, promote_scalar_result=False):
155156
def normalizer_inner(func):
156157
@functools.wraps(func)

torch_np/_unary_ufuncs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55
from . import _helpers
66
from ._detail import _unary_ufuncs
7-
from ._normalizations import ArrayLike, DTypeLike, NDArray, OutArray, SubokLike, normalizer
7+
from ._normalizations import (
8+
ArrayLike,
9+
DTypeLike,
10+
NDArray,
11+
OutArray,
12+
SubokLike,
13+
normalizer,
14+
)
815

916
__all__ = [
1017
name for name in dir(_unary_ufuncs) if not name.startswith("_") and name != "torch"

torch_np/_wrapper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212
from . import _dtypes, _funcs, _helpers
1313
from ._detail import _dtypes_impl, _util
1414
from ._ndarray import asarray
15-
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, OutArray, normalizer
15+
from ._normalizations import (
16+
ArrayLike,
17+
DTypeLike,
18+
NDArray,
19+
OutArray,
20+
SubokLike,
21+
normalizer,
22+
)
1623

1724
NoValue = _util.NoValue
1825

0 commit comments

Comments
 (0)