4
4
Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
5
5
"""
6
6
7
+ import functools
7
8
import typing
8
9
9
10
import torch
10
11
11
12
from . import _dtypes_impl , _util
12
13
13
- NoValue = _util .NoValue
14
-
15
-
16
- import functools
17
-
18
14
############# XXX
19
15
### From _util.axis_expand_func
20
16
@@ -51,7 +47,7 @@ def wrapped(tensor, axis, *args, **kwds):
51
47
52
48
def emulate_keepdims (func ):
53
49
@functools .wraps (func )
54
- def wrapped (tensor , axis = None , keepdims = NoValue , * args , ** kwds ):
50
+ def wrapped (tensor , axis = None , keepdims = None , * args , ** kwds ):
55
51
result = func (tensor , axis = axis , * args , ** kwds )
56
52
if keepdims :
57
53
result = _util .apply_keepdims (result , axis , tensor .ndim )
@@ -133,10 +129,7 @@ def argmin(tensor, axis=None):
133
129
134
130
@emulate_keepdims
135
131
@deco_axis_expand
136
- def any (tensor , axis = None , * , where = NoValue ):
137
- if where is not NoValue :
138
- raise NotImplementedError
139
-
132
+ def any (tensor , axis = None , * , where = None ):
140
133
axis = _util .allow_only_single_axis (axis )
141
134
142
135
if axis is None :
@@ -148,10 +141,7 @@ def any(tensor, axis=None, *, where=NoValue):
148
141
149
142
@emulate_keepdims
150
143
@deco_axis_expand
151
- def all (tensor , axis = None , * , where = NoValue ):
152
- if where is not NoValue :
153
- raise NotImplementedError
154
-
144
+ def all (tensor , axis = None , * , where = None ):
155
145
axis = _util .allow_only_single_axis (axis )
156
146
157
147
if axis is None :
@@ -163,37 +153,25 @@ def all(tensor, axis=None, *, where=NoValue):
163
153
164
154
@emulate_keepdims
165
155
@deco_axis_expand
166
- def max (tensor , axis = None , initial = NoValue , where = NoValue ):
167
- if initial is not NoValue or where is not NoValue :
168
- raise NotImplementedError
169
-
170
- result = tensor .amax (axis )
171
- return result
156
+ def max (tensor , axis = None , initial = None , where = None ):
157
+ return tensor .amax (axis )
172
158
173
159
174
160
@emulate_keepdims
175
161
@deco_axis_expand
176
- def min (tensor , axis = None , initial = NoValue , where = NoValue ):
177
- if initial is not NoValue or where is not NoValue :
178
- raise NotImplementedError
179
-
180
- result = tensor .amin (axis )
181
- return result
162
+ def min (tensor , axis = None , initial = None , where = None ):
163
+ return tensor .amin (axis )
182
164
183
165
184
166
@emulate_keepdims
185
167
@deco_axis_expand
186
168
def ptp (tensor , axis = None ):
187
- result = tensor .amax (axis ) - tensor .amin (axis )
188
- return result
169
+ return tensor .amax (axis ) - tensor .amin (axis )
189
170
190
171
191
172
@emulate_keepdims
192
173
@deco_axis_expand
193
- def sum (tensor , axis = None , dtype = None , initial = NoValue , where = NoValue ):
194
- if initial is not NoValue or where is not NoValue :
195
- raise NotImplementedError
196
-
174
+ def sum (tensor , axis = None , dtype = None , initial = None , where = None ):
197
175
assert dtype is None or isinstance (dtype , torch .dtype )
198
176
199
177
if dtype == torch .bool :
@@ -209,10 +187,7 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
209
187
210
188
@emulate_keepdims
211
189
@deco_axis_expand
212
- def prod (tensor , axis = None , dtype = None , initial = NoValue , where = NoValue ):
213
- if initial is not NoValue or where is not NoValue :
214
- raise NotImplementedError
215
-
190
+ def prod (tensor , axis = None , dtype = None , initial = None , where = None ):
216
191
axis = _util .allow_only_single_axis (axis )
217
192
218
193
if dtype == torch .bool :
@@ -228,10 +203,7 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
228
203
229
204
@emulate_keepdims
230
205
@deco_axis_expand
231
- def mean (tensor , axis = None , dtype = None , * , where = NoValue ):
232
- if where is not NoValue :
233
- raise NotImplementedError
234
-
206
+ def mean (tensor , axis = None , dtype = None , * , where = None ):
235
207
dtype = _atleast_float (dtype , tensor .dtype )
236
208
237
209
is_half = dtype == torch .float16
@@ -252,10 +224,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
252
224
253
225
@emulate_keepdims
254
226
@deco_axis_expand
255
- def std (tensor , axis = None , dtype = None , ddof = 0 , * , where = NoValue ):
256
- if where is not NoValue :
257
- raise NotImplementedError
258
-
227
+ def std (tensor , axis = None , dtype = None , ddof = 0 , * , where = None ):
259
228
dtype = _atleast_float (dtype , tensor .dtype )
260
229
tensor = _util .cast_if_needed (tensor , dtype )
261
230
result = tensor .std (dim = axis , correction = ddof )
@@ -265,10 +234,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
265
234
266
235
@emulate_keepdims
267
236
@deco_axis_expand
268
- def var (tensor , axis = None , dtype = None , ddof = 0 , * , where = NoValue ):
269
- if where is not NoValue :
270
- raise NotImplementedError
271
-
237
+ def var (tensor , axis = None , dtype = None , ddof = 0 , * , where = None ):
272
238
dtype = _atleast_float (dtype , tensor .dtype )
273
239
tensor = _util .cast_if_needed (tensor , dtype )
274
240
result = tensor .var (dim = axis , correction = ddof )
@@ -387,9 +353,6 @@ def quantile(
387
353
# Here we choose to work out-of-place because why not.
388
354
pass
389
355
390
- if interpolation is not None :
391
- raise ValueError ("'interpolation' argument is deprecated; use 'method' instead" )
392
-
393
356
if not a .dtype .is_floating_point :
394
357
dtype = _dtypes_impl .default_float_dtype
395
358
a = a .to (dtype )
0 commit comments