@@ -313,55 +313,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
     return result, wsum


-def average_noweights(a_tensor, axis, keepdims=False):
-    result = mean(a_tensor, axis=axis, keepdims=keepdims)
-    scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
+def average_noweights(a, axis, keepdims=False):
+    result = mean(a, axis=axis, keepdims=keepdims)
+    scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
     return result, scl


-def average_weights(a_tensor, axis, w_tensor, keepdims=False):
+def average_weights(a, axis, w, keepdims=False):

     # dtype
     # FIXME: 1. use result_type
     #        2. actually implement multiply w/dtype
-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         result_dtype = torch.float64
-        a_tensor = a_tensor.to(result_dtype)
+        a = a.to(result_dtype)

-    result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
+    result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])

-    a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
-    w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
+    a = _util.cast_if_needed(a, result_dtype)
+    w = _util.cast_if_needed(w, result_dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
+    ax, ndim = axis, a.ndim

     # axis
     if axis is None:
-        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
-            a_tensor, w_tensor, axis=axis
-        )
+        (a, w), axis = _util.axis_none_ravel(a, w, axis=axis)

     # axis & weights
-    if a_tensor.shape != w_tensor.shape:
+    if a.shape != w.shape:
         if axis is None:
             raise TypeError(
                 "Axis must be specified when shapes of a and weights " "differ."
             )
-        if w_tensor.ndim != 1:
+        if w.ndim != 1:
             raise TypeError("1D weights expected when shapes of a and weights differ.")
-        if w_tensor.shape[0] != a_tensor.shape[axis]:
+        if w.shape[0] != a.shape[axis]:
             raise ValueError("Length of weights not compatible with specified axis.")

         # setup weight to broadcast along axis
-        w_tensor = torch.broadcast_to(
-            w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
-        )
-        w_tensor = w_tensor.swapaxes(-1, axis)
+        w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
+        w = w.swapaxes(-1, axis)

     # do the work
-    numerator = torch.mul(a_tensor, w_tensor).sum(axis)
-    denominator = w_tensor.sum(axis)
+    numerator = torch.mul(a, w).sum(axis)
+    denominator = w.sum(axis)
     result = numerator / denominator

     # keepdims
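For reference, the broadcast-and-swapaxes step above is what lets a 1D weights vector apply along an arbitrary axis. A minimal standalone sketch of the same sum(a * w) / sum(w) computation, using only plain torch calls (no `_util` helpers assumed):

```python
import torch

# Sketch of the weighted-average core: broadcast the 1D weight vector to a
# trailing dimension, swap it onto the reduction axis, then divide the
# weighted sum by the weight sum.
a = torch.arange(6, dtype=torch.float64).reshape(2, 3)
w = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
axis = 1

wb = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)  # shape (1, 3)
wb = wb.swapaxes(-1, axis)  # a no-op here, since axis is already the last dim

result = torch.mul(a, wb).sum(axis) / wb.sum(axis)
print(result)  # tensor([1.3333, 4.3333], dtype=torch.float64)
```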
@@ -372,8 +368,8 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):


 def quantile(
-    a_tensor,
-    q_tensor,
+    a,
+    q,
     axis,
     overwrite_input,
     method,
@@ -390,30 +386,30 @@ def quantile(
     if interpolation is not None:
         raise ValueError("'interpolation' argument is deprecated; use 'method' instead")

-    if (0 > q_tensor).any() or (q_tensor > 1).any():
-        raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
+    if (0 > q).any() or (q > 1).any():
+        raise ValueError("Quantiles must be in range [0, 1], got %s" % q)

-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         dtype = _dtypes_impl.default_float_dtype
-        a_tensor = a_tensor.to(dtype)
+        a = a.to(dtype)

     # edge case: torch.quantile only supports float32 and float64
-    if a_tensor.dtype == torch.float16:
-        a_tensor = a_tensor.to(torch.float32)
+    if a.dtype == torch.float16:
+        a = a.to(torch.float32)

     # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
     # axis
     if axis is not None:
-        axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
+        axis = _util.normalize_axis_tuple(axis, a.ndim)
         axis = _util.allow_only_single_axis(axis)

-    q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
+    q = _util.cast_if_needed(q, a.dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
-    (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
+    ax, ndim = axis, a.ndim
+    (a, q), axis = _util.axis_none_ravel(a, q, axis=axis)

-    result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
+    result = torch.quantile(a, q, axis=axis, interpolation=method)

     # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
     # while the decorator expects (a, axis, ...)
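For reference, a small example of the underlying `torch.quantile` behavior this dispatches to (it accepts only float32/float64 input, which is why float16 is upcast above); note that with a 1D `q` the quantile values land in the leading dimension of the result:

```python
import torch

# torch.quantile wants float32/float64 input and q in [0, 1]; with a
# 1D q, the quantile axis becomes the leading dimension of the output.
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
q = torch.tensor([0.25, 0.5, 0.75])

print(torch.quantile(a, q, dim=1, interpolation="linear"))
# tensor([[1.5000, 4.5000],
#         [2.0000, 5.0000],
#         [2.5000, 5.5000]])
```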
@@ -424,17 +420,17 @@ def quantile(


 def percentile(
-    a_tensor,
-    q_tensor,
+    a,
+    q,
     axis,
     overwrite_input,
     method,
     keepdims=False,
     interpolation=None,
 ):
     return quantile(
-        a_tensor,
-        q_tensor / 100.0,
+        a,
+        q / 100.0,
         axis=axis,
         overwrite_input=overwrite_input,
         method=method,
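percentile thus stays a thin wrapper: rescale q from [0, 100] into [0, 1] and defer to quantile. A one-line sanity check of the rescaling:

```python
import torch

# The 50th percentile and the 0.5 quantile are the same number.
a = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.quantile(a, torch.tensor(50.0) / 100.0))  # tensor(2.5000)
print(torch.quantile(a, 0.5))                         # tensor(2.5000)
```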