@@ -317,55 +317,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
     return result, wsum


-def average_noweights(a_tensor, axis, keepdims=False):
-    result = mean(a_tensor, axis=axis, keepdims=keepdims)
-    scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
+def average_noweights(a, axis, keepdims=False):
+    result = mean(a, axis=axis, keepdims=keepdims)
+    scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
     return result, scl


-def average_weights(a_tensor, axis, w_tensor, keepdims=False):
+def average_weights(a, axis, w, keepdims=False):

     # dtype
     # FIXME: 1. use result_type
     #        2. actually implement multiply w/dtype
-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         result_dtype = torch.float64
-        a_tensor = a_tensor.to(result_dtype)
+        a = a.to(result_dtype)

-    result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
+    result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])

-    a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
-    w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
+    a = _util.cast_if_needed(a, result_dtype)
+    w = _util.cast_if_needed(w, result_dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
+    ax, ndim = axis, a.ndim

     # axis
     if axis is None:
-        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
-            a_tensor, w_tensor, axis=axis
-        )
+        (a, w), axis = _util.axis_none_ravel(a, w, axis=axis)

     # axis & weights
-    if a_tensor.shape != w_tensor.shape:
+    if a.shape != w.shape:
         if axis is None:
             raise TypeError(
                 "Axis must be specified when shapes of a and weights " "differ."
             )
-        if w_tensor.ndim != 1:
+        if w.ndim != 1:
             raise TypeError("1D weights expected when shapes of a and weights differ.")
-        if w_tensor.shape[0] != a_tensor.shape[axis]:
+        if w.shape[0] != a.shape[axis]:
             raise ValueError("Length of weights not compatible with specified axis.")

         # setup weight to broadcast along axis
-        w_tensor = torch.broadcast_to(
-            w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
-        )
-        w_tensor = w_tensor.swapaxes(-1, axis)
+        w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
+        w = w.swapaxes(-1, axis)

     # do the work
-    numerator = torch.mul(a_tensor, w_tensor).sum(axis)
-    denominator = w_tensor.sum(axis)
+    numerator = torch.mul(a, w).sum(axis)
+    denominator = w.sum(axis)
     result = numerator / denominator

     # keepdims
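Note for reviewers: the weighted branch reduces to a broadcast-multiply-sum. A minimal sketch of that semantics against plain torch (the `_util`/`_dtypes_impl` helpers are internal to this module, so the snippet inlines the broadcast step; the sample values are illustrative):

```python
import torch

# a: (2, 3) data; w: 1D weights along axis=1, mirroring the shapes
# average_weights accepts when a.shape != w.shape.
a = torch.arange(6, dtype=torch.float64).reshape(2, 3)
w = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)

# Broadcast w to a trailing axis, then swap it into position, as above.
wb = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape).swapaxes(-1, 1)

result = torch.mul(a, wb).sum(1) / wb.sum(1)
assert torch.allclose(result, torch.tensor([8 / 6, 26 / 6], dtype=torch.float64))
```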
@@ -376,8 +372,8 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):


 def quantile(
-    a_tensor,
-    q_tensor,
+    a,
+    q,
     axis,
     overwrite_input,
     method,
@@ -394,30 +390,30 @@ def quantile(
     if interpolation is not None:
         raise ValueError("'interpolation' argument is deprecated; use 'method' instead")

-    if (0 > q_tensor).any() or (q_tensor > 1).any():
-        raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
+    if (0 > q).any() or (q > 1).any():
+        raise ValueError("Quantiles must be in range [0, 1], got %s" % q)

-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         dtype = _dtypes_impl.default_float_dtype
-        a_tensor = a_tensor.to(dtype)
+        a = a.to(dtype)

     # edge case: torch.quantile only supports float32 and float64
-    if a_tensor.dtype == torch.float16:
-        a_tensor = a_tensor.to(torch.float32)
+    if a.dtype == torch.float16:
+        a = a.to(torch.float32)

     # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
     # axis
     if axis is not None:
-        axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
+        axis = _util.normalize_axis_tuple(axis, a.ndim)
         axis = _util.allow_only_single_axis(axis)

-    q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
+    q = _util.cast_if_needed(q, a.dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
-    (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
+    ax, ndim = axis, a.ndim
+    (a, q), axis = _util.axis_none_ravel(a, q, axis=axis)

-    result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
+    result = torch.quantile(a, q, axis=axis, interpolation=method)

     # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
     #     while the decorator expects (a, axis, ...)
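A quick sketch of the float16 edge case the cast above works around, assuming current `torch.quantile` behavior (it accepts only float32/float64 inputs; values are illustrative):

```python
import torch

a = torch.arange(5, dtype=torch.float16)

try:
    torch.quantile(a, 0.5)  # float16 input is rejected by torch.quantile
except RuntimeError as e:
    print(f"rejected: {e}")

print(torch.quantile(a.to(torch.float32), 0.5))  # tensor(2.)
```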
@@ -428,17 +424,17 @@ def quantile(


 def percentile(
-    a_tensor,
-    q_tensor,
+    a,
+    q,
     axis,
     overwrite_input,
     method,
     keepdims=False,
     interpolation=None,
 ):
     return quantile(
-        a_tensor,
-        q_tensor / 100.0,
+        a,
+        q / 100.0,
         axis=axis,
         overwrite_input=overwrite_input,
         method=method,
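The wrapper relies only on the identity percentile(a, q) == quantile(a, q / 100). A quick check of that rescaling with plain torch (sample values are illustrative):

```python
import torch

a = torch.linspace(0, 10, steps=11)
q = 25.0  # a percentile; dividing by 100 yields the quantile 0.25
assert torch.equal(torch.quantile(a, q / 100.0), torch.quantile(a, 0.25))
```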