Skip to content

Commit f94edc5

Browse files
committed
MAINT: address review comments
1 parent 05cfd66 commit f94edc5

File tree

2 files changed

+40
-44
lines changed

2 files changed

+40
-44
lines changed

torch_np/_detail/_reductions.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -317,55 +317,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
317317
return result, wsum
318318

319319

320-
def average_noweights(a_tensor, axis, keepdims=False):
321-
result = mean(a_tensor, axis=axis, keepdims=keepdims)
322-
scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
320+
def average_noweights(a, axis, keepdims=False):
321+
result = mean(a, axis=axis, keepdims=keepdims)
322+
scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
323323
return result, scl
324324

325325

326-
def average_weights(a_tensor, axis, w_tensor, keepdims=False):
326+
def average_weights(a, axis, w, keepdims=False):
327327

328328
# dtype
329329
# FIXME: 1. use result_type
330330
# 2. actually implement multiply w/dtype
331-
if not a_tensor.dtype.is_floating_point:
331+
if not a.dtype.is_floating_point:
332332
result_dtype = torch.float64
333-
a_tensor = a_tensor.to(result_dtype)
333+
a = a.to(result_dtype)
334334

335-
result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
335+
result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])
336336

337-
a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
338-
w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
337+
a = _util.cast_if_needed(a, result_dtype)
338+
w = _util.cast_if_needed(w, result_dtype)
339339

340340
# axis=None ravels, so store the originals to reuse with keepdims=True below
341-
ax, ndim = axis, a_tensor.ndim
341+
ax, ndim = axis, a.ndim
342342

343343
# axis
344344
if axis is None:
345-
(a_tensor, w_tensor), axis = _util.axis_none_ravel(
346-
a_tensor, w_tensor, axis=axis
347-
)
345+
(a, w), axis = _util.axis_none_ravel(a, w, axis=axis)
348346

349347
# axis & weights
350-
if a_tensor.shape != w_tensor.shape:
348+
if a.shape != w.shape:
351349
if axis is None:
352350
raise TypeError(
353351
"Axis must be specified when shapes of a and weights " "differ."
354352
)
355-
if w_tensor.ndim != 1:
353+
if w.ndim != 1:
356354
raise TypeError("1D weights expected when shapes of a and weights differ.")
357-
if w_tensor.shape[0] != a_tensor.shape[axis]:
355+
if w.shape[0] != a.shape[axis]:
358356
raise ValueError("Length of weights not compatible with specified axis.")
359357

360358
# setup weight to broadcast along axis
361-
w_tensor = torch.broadcast_to(
362-
w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
363-
)
364-
w_tensor = w_tensor.swapaxes(-1, axis)
359+
w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
360+
w = w.swapaxes(-1, axis)
365361

366362
# do the work
367-
numerator = torch.mul(a_tensor, w_tensor).sum(axis)
368-
denominator = w_tensor.sum(axis)
363+
numerator = torch.mul(a, w).sum(axis)
364+
denominator = w.sum(axis)
369365
result = numerator / denominator
370366

371367
# keepdims
@@ -376,8 +372,8 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
376372

377373

378374
def quantile(
379-
a_tensor,
380-
q_tensor,
375+
a,
376+
q,
381377
axis,
382378
overwrite_input,
383379
method,
@@ -394,30 +390,30 @@ def quantile(
394390
if interpolation is not None:
395391
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
396392

397-
if (0 > q_tensor).any() or (q_tensor > 1).any():
398-
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
393+
if (0 > q).any() or (q > 1).any():
394+
raise ValueError("Quantiles must be in range [0, 1], got %s" % q)
399395

400-
if not a_tensor.dtype.is_floating_point:
396+
if not a.dtype.is_floating_point:
401397
dtype = _dtypes_impl.default_float_dtype
402-
a_tensor = a_tensor.to(dtype)
398+
a = a.to(dtype)
403399

404400
# edge case: torch.quantile only supports float32 and float64
405-
if a_tensor.dtype == torch.float16:
406-
a_tensor = a_tensor.to(torch.float32)
401+
if a.dtype == torch.float16:
402+
a = a.to(torch.float32)
407403

408404
# TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
409405
# axis
410406
if axis is not None:
411-
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
407+
axis = _util.normalize_axis_tuple(axis, a.ndim)
412408
axis = _util.allow_only_single_axis(axis)
413409

414-
q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
410+
q = _util.cast_if_needed(q, a.dtype)
415411

416412
# axis=None ravels, so store the originals to reuse with keepdims=True below
417-
ax, ndim = axis, a_tensor.ndim
418-
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
413+
ax, ndim = axis, a.ndim
414+
(a, q), axis = _util.axis_none_ravel(a, q, axis=axis)
419415

420-
result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
416+
result = torch.quantile(a, q, axis=axis, interpolation=method)
421417

422418
# NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
423419
# while the decorator expects (a, axis, ...)
@@ -428,17 +424,17 @@ def quantile(
428424

429425

430426
def percentile(
431-
a_tensor,
432-
q_tensor,
427+
a,
428+
q,
433429
axis,
434430
overwrite_input,
435431
method,
436432
keepdims=False,
437433
interpolation=None,
438434
):
439435
return quantile(
440-
a_tensor,
441-
q_tensor / 100.0,
436+
a,
437+
q / 100.0,
442438
axis=axis,
443439
overwrite_input=overwrite_input,
444440
method=method,

torch_np/_detail/implementations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,15 +461,15 @@ def indices(dimensions, dtype=int, sparse=False):
461461
return res
462462

463463

464-
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
465-
if x_tensor.numel() == 0:
464+
def bincount(x, /, weights=None, minlength=0):
465+
if x.numel() == 0:
466466
# edge case allowed by numpy
467-
x_tensor = torch.as_tensor([], dtype=int)
467+
x = x.new_empty(0, dtype=int)
468468

469469
int_dtype = _dtypes_impl.default_int_dtype
470-
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
470+
(x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
471471

472-
result = torch.bincount(x_tensor, weights_tensor, minlength)
472+
result = torch.bincount(x, weights, minlength)
473473
return result
474474

475475

0 commit comments

Comments (0)