Skip to content

Commit a0cb202

Browse files
committed
MAINT: address review comments
1 parent 248c02b commit a0cb202

File tree

2 files changed: +40 additions, −44 deletions

torch_np/_detail/_reductions.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -313,55 +313,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
313313
return result, wsum
314314

315315

316-
def average_noweights(a_tensor, axis, keepdims=False):
317-
result = mean(a_tensor, axis=axis, keepdims=keepdims)
318-
scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
316+
def average_noweights(a, axis, keepdims=False):
317+
result = mean(a, axis=axis, keepdims=keepdims)
318+
scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
319319
return result, scl
320320

321321

322-
def average_weights(a_tensor, axis, w_tensor, keepdims=False):
322+
def average_weights(a, axis, w, keepdims=False):
323323

324324
# dtype
325325
# FIXME: 1. use result_type
326326
# 2. actually implement multiply w/dtype
327-
if not a_tensor.dtype.is_floating_point:
327+
if not a.dtype.is_floating_point:
328328
result_dtype = torch.float64
329-
a_tensor = a_tensor.to(result_dtype)
329+
a = a.to(result_dtype)
330330

331-
result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
331+
result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])
332332

333-
a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
334-
w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
333+
a = _util.cast_if_needed(a, result_dtype)
334+
w = _util.cast_if_needed(w, result_dtype)
335335

336336
# axis=None ravels, so store the originals to reuse with keepdims=True below
337-
ax, ndim = axis, a_tensor.ndim
337+
ax, ndim = axis, a.ndim
338338

339339
# axis
340340
if axis is None:
341-
(a_tensor, w_tensor), axis = _util.axis_none_ravel(
342-
a_tensor, w_tensor, axis=axis
343-
)
341+
(a, w), axis = _util.axis_none_ravel(a, w, axis=axis)
344342

345343
# axis & weights
346-
if a_tensor.shape != w_tensor.shape:
344+
if a.shape != w.shape:
347345
if axis is None:
348346
raise TypeError(
349347
"Axis must be specified when shapes of a and weights " "differ."
350348
)
351-
if w_tensor.ndim != 1:
349+
if w.ndim != 1:
352350
raise TypeError("1D weights expected when shapes of a and weights differ.")
353-
if w_tensor.shape[0] != a_tensor.shape[axis]:
351+
if w.shape[0] != a.shape[axis]:
354352
raise ValueError("Length of weights not compatible with specified axis.")
355353

356354
# setup weight to broadcast along axis
357-
w_tensor = torch.broadcast_to(
358-
w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
359-
)
360-
w_tensor = w_tensor.swapaxes(-1, axis)
355+
w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
356+
w = w.swapaxes(-1, axis)
361357

362358
# do the work
363-
numerator = torch.mul(a_tensor, w_tensor).sum(axis)
364-
denominator = w_tensor.sum(axis)
359+
numerator = torch.mul(a, w).sum(axis)
360+
denominator = w.sum(axis)
365361
result = numerator / denominator
366362

367363
# keepdims
@@ -372,8 +368,8 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
372368

373369

374370
def quantile(
375-
a_tensor,
376-
q_tensor,
371+
a,
372+
q,
377373
axis,
378374
overwrite_input,
379375
method,
@@ -390,30 +386,30 @@ def quantile(
390386
if interpolation is not None:
391387
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
392388

393-
if (0 > q_tensor).any() or (q_tensor > 1).any():
394-
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
389+
if (0 > q).any() or (q > 1).any():
390+
raise ValueError("Quantiles must be in range [0, 1], got %s" % q)
395391

396-
if not a_tensor.dtype.is_floating_point:
392+
if not a.dtype.is_floating_point:
397393
dtype = _dtypes_impl.default_float_dtype
398-
a_tensor = a_tensor.to(dtype)
394+
a = a.to(dtype)
399395

400396
# edge case: torch.quantile only supports float32 and float64
401-
if a_tensor.dtype == torch.float16:
402-
a_tensor = a_tensor.to(torch.float32)
397+
if a.dtype == torch.float16:
398+
a = a.to(torch.float32)
403399

404400
# TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
405401
# axis
406402
if axis is not None:
407-
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
403+
axis = _util.normalize_axis_tuple(axis, a.ndim)
408404
axis = _util.allow_only_single_axis(axis)
409405

410-
q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
406+
q = _util.cast_if_needed(q, a.dtype)
411407

412408
# axis=None ravels, so store the originals to reuse with keepdims=True below
413-
ax, ndim = axis, a_tensor.ndim
414-
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
409+
ax, ndim = axis, a.ndim
410+
(a, q), axis = _util.axis_none_ravel(a, q, axis=axis)
415411

416-
result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
412+
result = torch.quantile(a, q, axis=axis, interpolation=method)
417413

418414
# NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
419415
# while the decorator expects (a, axis, ...)
@@ -424,17 +420,17 @@ def quantile(
424420

425421

426422
def percentile(
427-
a_tensor,
428-
q_tensor,
423+
a,
424+
q,
429425
axis,
430426
overwrite_input,
431427
method,
432428
keepdims=False,
433429
interpolation=None,
434430
):
435431
return quantile(
436-
a_tensor,
437-
q_tensor / 100.0,
432+
a,
433+
q / 100.0,
438434
axis=axis,
439435
overwrite_input=overwrite_input,
440436
method=method,

torch_np/_detail/implementations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,15 +461,15 @@ def indices(dimensions, dtype=int, sparse=False):
461461
return res
462462

463463

464-
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
465-
if x_tensor.numel() == 0:
464+
def bincount(x, /, weights=None, minlength=0):
465+
if x.numel() == 0:
466466
# edge case allowed by numpy
467-
x_tensor = torch.as_tensor([], dtype=int)
467+
x = x.new_empty(0, dtype=int)
468468

469469
int_dtype = _dtypes_impl.default_int_dtype
470-
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
470+
(x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
471471

472-
result = torch.bincount(x_tensor, weights_tensor, minlength)
472+
result = torch.bincount(x, weights, minlength)
473473
return result
474474

475475

0 commit comments

Comments
 (0)