Skip to content

Commit c81f677

Browse files
committed
MAINT: postprocess out= returns in normalizer
1 parent 74523fa commit c81f677

9 files changed

+71
-142
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def wrapped(
4444
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
4545

4646
result = torch_func(*tensors)
47-
return _helpers.result_or_out(result, out)
47+
return result
4848

4949
return wrapped
5050

@@ -77,7 +77,7 @@ def matmul(
7777

7878
# NB: do not broadcast input tensors against the out=... array
7979
result = _binary_ufuncs.matmul(*tensors)
80-
return _helpers.result_or_out(result, out)
80+
return result
8181

8282

8383
#
@@ -93,53 +93,6 @@ def matmul(
9393
vars()[name] = decorated
9494

9595

96-
# Stub np.divmod; revisit once
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch.
#
# How it works: we simply delegate to the floor_divide and remainder
# ufuncs generated just above (x1 // x2 and x1 % x2 respectively).
# Each of those ufuncs normalizes x1 and x2 itself, which is why this
# function deliberately carries no @normalizer of its own.


def divmod(
    x1,
    x2,
    /,
    out=None,
    *,
    where=True,
    casting="same_kind",
    order="K",
    dtype=None,
    subok: SubokLike = False,
    signature=None,
    extobj=None,
):
    """Return (x1 // x2, x1 % x2), honoring the ufunc out= pair."""
    # out=, when given, is a pair of arrays: one for the quotient,
    # one for the remainder.
    if out is None:
        out1 = out2 = None
    else:
        out1, out2 = out

    ufunc_kwargs = {
        "where": where,
        "casting": casting,
        "order": order,
        "dtype": dtype,
        "subok": subok,
        "signature": signature,
        "extobj": extobj,
    }

    # Delegate to the generated ufuncs; each normalizes its own inputs.
    quot = floor_divide(x1, x2, out=out1, **ufunc_kwargs)
    rem = remainder(x1, x2, out=out2, **ufunc_kwargs)

    # Route each result through its own out= array (or wrap it fresh).
    quot = _helpers.result_or_out(quot.tensor, out1)
    rem = _helpers.result_or_out(rem.tensor, out2)

    return quot, rem
141-
142-
14396
def modf(x, /, *args, **kwds):
    """Split x into fractional and whole parts, returned as (frac, whole)."""
    whole_part, frac_part = divmod(x, 1, *args, **kwds)
    return frac_part, whole_part

torch_np/_decorators.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

torch_np/_detail/_binary_ufuncs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,9 @@ def matmul(x, y):
7070
result = result.to(dtype)
7171

7272
return result
73+
74+
75+
# a stub implementation of divmod, should be improved after
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
def divmod(x, y):
    """Return the (floor-quotient, remainder) pair of x divided by y."""
    quotient = x // y
    remainder = x % y
    return quotient, remainder

torch_np/_funcs.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def clip(
4545
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
4646
# one of them to be None. Follow the more lax version.
4747
result = _impl.clip(a, min, max)
48-
return _helpers.result_or_out(result, out)
48+
return result
4949

5050

5151
@normalizer
@@ -80,7 +80,7 @@ def trace(
8080
out: Optional[NDArray] = None,
8181
):
8282
result = _impl.trace(a, offset, axis1, axis2, dtype)
83-
return _helpers.result_or_out(result, out)
83+
return result
8484

8585

8686
@normalizer
@@ -135,7 +135,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
135135
@normalizer
136136
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
137137
result = _impl.dot(a, b)
138-
return _helpers.result_or_out(result, out)
138+
return result
139139

140140

141141
# ### sort and partition ###
@@ -234,7 +234,7 @@ def imag(a: ArrayLike):
234234
@normalizer
235235
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
236236
result = _impl.round(a, decimals)
237-
return _helpers.result_or_out(result, out)
237+
return result
238238

239239

240240
around = round_
@@ -257,7 +257,7 @@ def sum(
257257
result = _impl.sum(
258258
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
259259
)
260-
return _helpers.result_or_out(result, out)
260+
return result
261261

262262

263263
@normalizer
@@ -273,7 +273,7 @@ def prod(
273273
result = _impl.prod(
274274
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
275275
)
276-
return _helpers.result_or_out(result, out)
276+
return result
277277

278278

279279
product = prod
@@ -290,7 +290,7 @@ def mean(
290290
where=NoValue,
291291
):
292292
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
293-
return _helpers.result_or_out(result, out)
293+
return result
294294

295295

296296
@normalizer
@@ -307,7 +307,7 @@ def var(
307307
result = _impl.var(
308308
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
309309
)
310-
return _helpers.result_or_out(result, out)
310+
return result
311311

312312

313313
@normalizer
@@ -324,7 +324,7 @@ def std(
324324
result = _impl.std(
325325
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
326326
)
327-
return _helpers.result_or_out(result, out)
327+
return result
328328

329329

330330
@normalizer
@@ -336,7 +336,7 @@ def argmin(
336336
keepdims=NoValue,
337337
):
338338
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
339-
return _helpers.result_or_out(result, out)
339+
return result
340340

341341

342342
@normalizer
@@ -348,7 +348,7 @@ def argmax(
348348
keepdims=NoValue,
349349
):
350350
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
351-
return _helpers.result_or_out(result, out)
351+
return result
352352

353353

354354
@normalizer
@@ -361,7 +361,7 @@ def amax(
361361
where=NoValue,
362362
):
363363
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
364-
return _helpers.result_or_out(result, out)
364+
return result
365365

366366

367367
max = amax
@@ -377,7 +377,7 @@ def amin(
377377
where=NoValue,
378378
):
379379
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
380-
return _helpers.result_or_out(result, out)
380+
return result
381381

382382

383383
min = amin
@@ -388,7 +388,7 @@ def ptp(
388388
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
389389
):
390390
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
391-
return _helpers.result_or_out(result, out)
391+
return result
392392

393393

394394
@normalizer
@@ -401,7 +401,7 @@ def all(
401401
where=NoValue,
402402
):
403403
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
404-
return _helpers.result_or_out(result, out)
404+
return result
405405

406406

407407
@normalizer
@@ -414,7 +414,7 @@ def any(
414414
where=NoValue,
415415
):
416416
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
417-
return _helpers.result_or_out(result, out)
417+
return result
418418

419419

420420
@normalizer
@@ -431,7 +431,7 @@ def cumsum(
431431
out: Optional[NDArray] = None,
432432
):
433433
result = _impl.cumsum(a, axis=axis, dtype=dtype)
434-
return _helpers.result_or_out(result, out)
434+
return result
435435

436436

437437
@normalizer
@@ -442,13 +442,13 @@ def cumprod(
442442
out: Optional[NDArray] = None,
443443
):
444444
result = _impl.cumprod(a, axis=axis, dtype=dtype)
445-
return _helpers.result_or_out(result, out)
445+
return result
446446

447447

448448
cumproduct = cumprod
449449

450450

451-
@normalizer
451+
@normalizer(promote_scalar_result=True)
452452
def quantile(
453453
a: ArrayLike,
454454
q: ArrayLike,
@@ -469,10 +469,10 @@ def quantile(
469469
keepdims=keepdims,
470470
interpolation=interpolation,
471471
)
472-
return _helpers.result_or_out(result, out, promote_scalar=True)
472+
return result
473473

474474

475-
@normalizer
475+
@normalizer(promote_scalar_result=True)
476476
def percentile(
477477
a: ArrayLike,
478478
q: ArrayLike,
@@ -493,7 +493,7 @@ def percentile(
493493
keepdims=keepdims,
494494
interpolation=interpolation,
495495
)
496-
return _helpers.result_or_out(result, out, promote_scalar=True)
496+
return result
497497

498498

499499
def median(

torch_np/_helpers.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,39 +30,6 @@ def ufunc_preprocess(
3030
return tensors
3131

3232

33-
# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###
34-
35-
36-
def result_or_out(result_tensor, out_array=None, promote_scalar=False):
    """A helper for returns with out= argument.

    If `promote_scalar is True`, then:
       if result_tensor.numel() == 1 and out is zero-dimensional,
       result_tensor is placed into the out array.
    This weirdness is used e.g. in `np.percentile`
    """
    if out_array is None:
        # No out= array was given: wrap the tensor in a fresh ndarray.
        from ._ndarray import ndarray

        return ndarray(result_tensor)

    if result_tensor.shape != out_array.shape:
        # Shapes disagree. The only allowed mismatch is a single-element
        # result going into a 0-dim out array, and only when the caller
        # opted in via promote_scalar.
        fits_as_scalar = result_tensor.numel() == 1 and out_array.ndim == 0
        if promote_scalar and fits_as_scalar:
            result_tensor = result_tensor.squeeze()
        else:
            raise ValueError(
                f"Bad size of the out array: out.shape = {out_array.shape}"
                f" while result.shape = {result_tensor.shape}."
            )

    # Copy the computed values into the caller-supplied array in place.
    out_array.tensor.copy_(result_tensor)
    return out_array
61-
62-
63-
# ### Various ways of converting array-likes to tensors ###
64-
65-
6633
def ndarrays_to_tensors(*inputs):
6734
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
6835
from ._ndarray import asarray, ndarray

0 commit comments

Comments
 (0)