Skip to content

Commit 4f39fa7

Browse files
committed
MAINT: address review comments
1 parent e3fdc3b commit 4f39fa7

File tree

6 files changed

+65
-85
lines changed

6 files changed

+65
-85
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._decorators import deco_ufunc_from_impl
1+
from ._decorators import deco_binary_ufunc_from_impl
22
from ._detail import _ufunc_impl
33

44
#
@@ -10,44 +10,44 @@
1010
#
1111

1212
# the list is autogenerated, cf autogen/gen_ufunc_2.py
13-
add = deco_ufunc_from_impl(_ufunc_impl.add)
14-
arctan2 = deco_ufunc_from_impl(_ufunc_impl.arctan2)
15-
bitwise_and = deco_ufunc_from_impl(_ufunc_impl.bitwise_and)
16-
bitwise_or = deco_ufunc_from_impl(_ufunc_impl.bitwise_or)
17-
bitwise_xor = deco_ufunc_from_impl(_ufunc_impl.bitwise_xor)
18-
copysign = deco_ufunc_from_impl(_ufunc_impl.copysign)
19-
divide = deco_ufunc_from_impl(_ufunc_impl.divide)
20-
equal = deco_ufunc_from_impl(_ufunc_impl.equal)
21-
float_power = deco_ufunc_from_impl(_ufunc_impl.float_power)
22-
floor_divide = deco_ufunc_from_impl(_ufunc_impl.floor_divide)
23-
fmax = deco_ufunc_from_impl(_ufunc_impl.fmax)
24-
fmin = deco_ufunc_from_impl(_ufunc_impl.fmin)
25-
fmod = deco_ufunc_from_impl(_ufunc_impl.fmod)
26-
gcd = deco_ufunc_from_impl(_ufunc_impl.gcd)
27-
greater = deco_ufunc_from_impl(_ufunc_impl.greater)
28-
greater_equal = deco_ufunc_from_impl(_ufunc_impl.greater_equal)
29-
heaviside = deco_ufunc_from_impl(_ufunc_impl.heaviside)
30-
hypot = deco_ufunc_from_impl(_ufunc_impl.hypot)
31-
lcm = deco_ufunc_from_impl(_ufunc_impl.lcm)
32-
ldexp = deco_ufunc_from_impl(_ufunc_impl.ldexp)
33-
left_shift = deco_ufunc_from_impl(_ufunc_impl.left_shift)
34-
less = deco_ufunc_from_impl(_ufunc_impl.less)
35-
less_equal = deco_ufunc_from_impl(_ufunc_impl.less_equal)
36-
logaddexp = deco_ufunc_from_impl(_ufunc_impl.logaddexp)
37-
logaddexp2 = deco_ufunc_from_impl(_ufunc_impl.logaddexp2)
38-
logical_and = deco_ufunc_from_impl(_ufunc_impl.logical_and)
39-
logical_or = deco_ufunc_from_impl(_ufunc_impl.logical_or)
40-
logical_xor = deco_ufunc_from_impl(_ufunc_impl.logical_xor)
41-
matmul = deco_ufunc_from_impl(_ufunc_impl.matmul)
42-
maximum = deco_ufunc_from_impl(_ufunc_impl.maximum)
43-
minimum = deco_ufunc_from_impl(_ufunc_impl.minimum)
44-
remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
45-
multiply = deco_ufunc_from_impl(_ufunc_impl.multiply)
46-
nextafter = deco_ufunc_from_impl(_ufunc_impl.nextafter)
47-
not_equal = deco_ufunc_from_impl(_ufunc_impl.not_equal)
48-
power = deco_ufunc_from_impl(_ufunc_impl.power)
49-
remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
50-
right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift)
51-
subtract = deco_ufunc_from_impl(_ufunc_impl.subtract)
52-
divide = deco_ufunc_from_impl(_ufunc_impl.divide)
13+
add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
14+
arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
15+
bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
16+
bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
17+
bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
18+
copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
19+
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
20+
equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
21+
float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
22+
floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
23+
fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
24+
fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
25+
fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
26+
gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
27+
greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
28+
greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
29+
heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
30+
hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
31+
lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
32+
ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
33+
left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
34+
less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
35+
less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
36+
logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
37+
logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
38+
logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
39+
logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
40+
logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
41+
matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
42+
maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
43+
minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
44+
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
45+
multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
46+
nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
47+
not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
48+
power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
49+
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
50+
right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
51+
subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
52+
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
5353

torch_np/_decorators.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,11 @@ def wrapped(*args, dtype=None, **kwds):
2222

2323

2424
def emulate_out_arg(func):
25-
"""Simulate the out=... handling *for functions which do not need it*.
25+
"""Simulate the out=... handling: move the result tensor to the out array.
2626
2727
With this decorator, the inner function just does not see the out array.
2828
"""
2929
def wrapped(*args, out=None, **kwds):
30-
from ._ndarray import ndarray
31-
if out is not None:
32-
if not isinstance(out, ndarray):
33-
raise TypeError("Return arrays must be of ArrayType")
3430
result_tensor = func(*args, **kwds)
3531
return _helpers.result_or_out(result_tensor, out)
3632

@@ -45,10 +41,7 @@ def out_shape_dtype(func):
4541
and pass these through.
4642
"""
4743
def wrapped(*args, out=None, **kwds):
48-
from ._ndarray import ndarray
4944
if out is not None:
50-
if not isinstance(out, ndarray):
51-
raise TypeError("Return arrays must be of ArrayType")
5245
kwds.update({'out_shape_dtype': (out.get().dtype, out.get().shape)})
5346
result_tensor = func(*args, **kwds)
5447
return _helpers.result_or_out(result_tensor, out)
@@ -70,7 +63,7 @@ def wrapped(x1, *args, **kwds):
7063

7164
# TODO: deduplicate with _ndarray/asarray_replacer,
7265
# and _wrapper/concatenate et al
73-
def deco_ufunc_from_impl(impl_func):
66+
def deco_binary_ufunc_from_impl(impl_func):
7467
@functools.wraps(impl_func)
7568
@dtype_to_torch
7669
@out_shape_dtype
@@ -98,7 +91,7 @@ def axis_keepdims_wrapper(func):
9891
# TODO: 1. get rid of _helpers.result_or_out
9992
# 2. sort out function signatures: how they flow through all decorators etc
10093
@functools.wraps(func)
101-
def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds):
94+
def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
10295
from ._ndarray import ndarray, asarray
10396
tensor = asarray(a).get()
10497

torch_np/_detail/_reductions.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ def _atleast_float(dtype, other_dtype):
2020
"""
2121
if dtype is None:
2222
dtype = other_dtype
23-
if not dtype.is_floating_point:
24-
if dtype.is_complex:
25-
pass # pass through complex as is
26-
else:
27-
sctype = _scalar_types.default_float_type
28-
dtype = sctype.torch_dtype
23+
if not (dtype.is_floating_point or dtype.is_complex):
24+
sctype = _scalar_types.default_float_type
25+
dtype = sctype.torch_dtype
2926
return dtype
3027

3128

@@ -38,19 +35,19 @@ def count_nonzero(a, axis=None):
3835
return tensor
3936

4037

41-
def argmax(tensor, axis=None, out=None, *, keepdims=NoValue):
38+
def argmax(tensor, axis=None):
4239
axis = _util.allow_only_single_axis(axis)
4340
tensor = torch.argmax(tensor, axis)
4441
return tensor
4542

46-
def argmin(tensor, axis=None, out=None, *, keepdims=NoValue):
43+
def argmin(tensor, axis=None):
4744
axis = _util.allow_only_single_axis(axis)
4845
tensor = torch.argmin(tensor, axis)
4946
return tensor
5047

5148

5249
def any(tensor, axis=None, *, where=NoValue):
53-
if where is not None:
50+
if where is not NoValue:
5451
raise NotImplementedError
5552

5653
axis = _util.allow_only_single_axis(axis)
@@ -63,7 +60,7 @@ def any(tensor, axis=None, *, where=NoValue):
6360

6461

6562
def all(tensor, axis=None, *, where=NoValue):
66-
if where is not None:
63+
if where is not NoValue:
6764
raise NotImplementedError
6865

6966
axis = _util.allow_only_single_axis(axis)
@@ -76,23 +73,23 @@ def all(tensor, axis=None, *, where=NoValue):
7673

7774

7875
def max(tensor, axis=None, initial=NoValue, where=NoValue):
79-
if initial is not None or where is not None:
76+
if initial is not NoValue or where is not NoValue:
8077
raise NotImplementedError
8178

8279
result = tensor.amax(axis)
8380
return result
8481

8582

8683
def min(tensor, axis=None, initial=NoValue, where=NoValue):
87-
if initial is not None or where is not None:
84+
if initial is not NoValue or where is not NoValue:
8885
raise NotImplementedError
8986

9087
result = tensor.amin(axis)
9188
return result
9289

9390

9491
def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
95-
if initial is not None or where is not None:
92+
if initial is not NoValue or where is not NoValue:
9693
raise NotImplementedError
9794

9895
assert dtype is None or isinstance(dtype, torch.dtype)
@@ -109,7 +106,7 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
109106

110107

111108
def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
112-
if initial is not None or where is not None:
109+
if initial is not NoValue or where is not NoValue:
113110
raise NotImplementedError
114111

115112
axis = _util.allow_only_single_axis(axis)
@@ -126,7 +123,7 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
126123

127124

128125
def mean(tensor, axis=None, dtype=None, *, where=NoValue):
129-
if where is not None:
126+
if where is not NoValue:
130127
raise NotImplementedError
131128

132129
dtype = _atleast_float(dtype, tensor.dtype)
@@ -140,24 +137,26 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
140137

141138

142139
def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
143-
if where is not None:
140+
if where is not NoValue:
144141
raise NotImplementedError
145142

146143
dtype = _atleast_float(dtype, tensor.dtype)
147144

148-
tensor = tensor.to(dtype)
145+
if dtype is not None:
146+
tensor = tensor.to(dtype)
149147
result = tensor.std(dim=axis, correction=ddof)
150148

151149
return result
152150

153151

154152
def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
155-
if where is not None:
153+
if where is not NoValue:
156154
raise NotImplementedError
157155

158156
dtype = _atleast_float(dtype, tensor.dtype)
159157

160-
tensor = tensor.to(dtype)
158+
if dtype is not None:
159+
tensor = tensor.to(dtype)
161160
result = tensor.var(dim=axis, correction=ddof)
162161

163162
return result

torch_np/_detail/_util.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
5454
By default, this forbids axes from being specified multiple times.
5555
Used internally by multi-axis-checking logic.
5656
57-
.. versionadded:: 1.13.0
58-
5957
Parameters
6058
----------
6159
axis : int, iterable of int
@@ -73,17 +71,6 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
7371
-------
7472
normalized_axes : tuple of int
7573
The normalized axis index, such that `0 <= normalized_axis < ndim`
76-
77-
Raises
78-
------
79-
AxisError
80-
If any axis provided is out of range
81-
ValueError
82-
If an axis is repeated
83-
84-
See also
85-
--------
86-
normalize_axis_index : normalizing a single scalar axis
8774
"""
8875
# Optimization to speed-up the most common cases.
8976
if type(axis) not in (tuple, list):

torch_np/_helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def cast_and_broadcast(tensors, out, casting):
4040
return tuple(tensors)
4141

4242

43-
4443
def result_or_out(result_tensor, out_array=None):
4544
"""A helper for returns with out= argument."""
4645
if out_array is not None:
46+
if not isinstance(out_array, ndarray):
47+
raise TypeError("Return arrays must be of ArrayType")
4748
if result_tensor.shape != out_array.shape:
4849
raise ValueError("Bad size of the out array.")
4950
out_tensor = out_array.get()

torch_np/_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=N
643643

644644
@asarray_replacer()
645645
def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
646-
if where is not None:
646+
if where is not NoValue:
647647
raise NotImplementedError
648648
if dtype is None:
649649
dtype = a.dtype

0 commit comments

Comments (0)