Skip to content

Commit 37d6f4f

Browse files
authored
Merge pull request #82 from Quansight-Labs/move_to_impl_take_2
MAINT: split remaining functions into normalizations and implementations
2 parents 5875cf2 + 2432f35 commit 37d6f4f

File tree

6 files changed

+269
-173
lines changed

6 files changed

+269
-173
lines changed

torch_np/_detail/_reductions.py

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -317,55 +317,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
317317
return result, wsum
318318

319319

320-
def average_noweights(a_tensor, axis, keepdims=False):
321-
result = mean(a_tensor, axis=axis, keepdims=keepdims)
322-
scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
320+
def average_noweights(a, axis, keepdims=False):
321+
result = mean(a, axis=axis, keepdims=keepdims)
322+
scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
323323
return result, scl
324324

325325

326-
def average_weights(a_tensor, axis, w_tensor, keepdims=False):
326+
def average_weights(a, axis, w, keepdims=False):
327327

328328
# dtype
329329
# FIXME: 1. use result_type
330330
# 2. actually implement multiply w/dtype
331-
if not a_tensor.dtype.is_floating_point:
331+
if not a.dtype.is_floating_point:
332332
result_dtype = torch.float64
333-
a_tensor = a_tensor.to(result_dtype)
333+
a = a.to(result_dtype)
334334

335-
result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
335+
result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])
336336

337-
a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
338-
w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
337+
a = _util.cast_if_needed(a, result_dtype)
338+
w = _util.cast_if_needed(w, result_dtype)
339339

340340
# axis=None ravels, so store the originals to reuse with keepdims=True below
341-
ax, ndim = axis, a_tensor.ndim
341+
ax, ndim = axis, a.ndim
342342

343343
# axis
344344
if axis is None:
345-
(a_tensor, w_tensor), axis = _util.axis_none_ravel(
346-
a_tensor, w_tensor, axis=axis
347-
)
345+
(a, w), axis = _util.axis_none_ravel(a, w, axis=axis)
348346

349347
# axis & weights
350-
if a_tensor.shape != w_tensor.shape:
348+
if a.shape != w.shape:
351349
if axis is None:
352350
raise TypeError(
353351
"Axis must be specified when shapes of a and weights " "differ."
354352
)
355-
if w_tensor.ndim != 1:
353+
if w.ndim != 1:
356354
raise TypeError("1D weights expected when shapes of a and weights differ.")
357-
if w_tensor.shape[0] != a_tensor.shape[axis]:
355+
if w.shape[0] != a.shape[axis]:
358356
raise ValueError("Length of weights not compatible with specified axis.")
359357

360358
# setup weight to broadcast along axis
361-
w_tensor = torch.broadcast_to(
362-
w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
363-
)
364-
w_tensor = w_tensor.swapaxes(-1, axis)
359+
w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
360+
w = w.swapaxes(-1, axis)
365361

366362
# do the work
367-
numerator = torch.mul(a_tensor, w_tensor).sum(axis)
368-
denominator = w_tensor.sum(axis)
363+
numerator = torch.mul(a, w).sum(axis)
364+
denominator = w.sum(axis)
369365
result = numerator / denominator
370366

371367
# keepdims
@@ -375,36 +371,70 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
375371
return result, denominator
376372

377373

378-
def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
379-
380-
if (0 > q_tensor).any() or (q_tensor > 1).any():
381-
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
382-
383-
if not a_tensor.dtype.is_floating_point:
374+
def quantile(
375+
a,
376+
q,
377+
axis,
378+
overwrite_input,
379+
method,
380+
keepdims=False,
381+
interpolation=None,
382+
):
383+
if overwrite_input:
384+
# raise NotImplementedError("overwrite_input in quantile not implemented.")
385+
# NumPy documents that `overwrite_input` MAY modify inputs:
386+
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
387+
# Here we choose to work out-of-place because why not.
388+
pass
389+
390+
if interpolation is not None:
391+
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
392+
393+
if not a.dtype.is_floating_point:
384394
dtype = _dtypes_impl.default_float_dtype
385-
a_tensor = a_tensor.to(dtype)
395+
a = a.to(dtype)
386396

387397
# edge case: torch.quantile only supports float32 and float64
388-
if a_tensor.dtype == torch.float16:
389-
a_tensor = a_tensor.to(torch.float32)
398+
if a.dtype == torch.float16:
399+
a = a.to(torch.float32)
390400

391401
# TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
392402
# axis
393403
if axis is not None:
394-
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
404+
axis = _util.normalize_axis_tuple(axis, a.ndim)
395405
axis = _util.allow_only_single_axis(axis)
396406

397-
q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
407+
q = _util.cast_if_needed(q, a.dtype)
398408

399409
# axis=None ravels, so store the originals to reuse with keepdims=True below
400-
ax, ndim = axis, a_tensor.ndim
401-
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
410+
ax, ndim = axis, a.ndim
411+
(a, q), axis = _util.axis_none_ravel(a, q, axis=axis)
402412

403-
result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
413+
result = torch.quantile(a, q, axis=axis, interpolation=method)
404414

405415
# NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
406416
# while the decorator expects (a, axis, ...)
407417
# this can be fixed, of course, but the cure seems worse then the desease
408418
if keepdims:
409419
result = _util.apply_keepdims(result, ax, ndim)
410420
return result
421+
422+
423+
def percentile(
424+
a,
425+
q,
426+
axis,
427+
overwrite_input,
428+
method,
429+
keepdims=False,
430+
interpolation=None,
431+
):
432+
return quantile(
433+
a,
434+
q / 100.0,
435+
axis=axis,
436+
overwrite_input=overwrite_input,
437+
method=method,
438+
keepdims=keepdims,
439+
interpolation=interpolation,
440+
)

torch_np/_detail/implementations.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
274274
if n < 0:
275275
raise ValueError(f"order must be non-negative but got {n}")
276276

277+
if n == 0:
278+
# match numpy and return the input immediately
279+
return a_tensor
280+
277281
if prepend_tensor is not None:
278282
shape = list(a_tensor.shape)
279283
shape[axis] = prepend_tensor.shape[axis] if prepend_tensor.ndim > 0 else 1
@@ -357,6 +361,14 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
357361
return result
358362

359363

364+
def tile(tensor, reps):
365+
if isinstance(reps, int):
366+
reps = (reps,)
367+
368+
result = torch.tile(tensor, reps)
369+
return result
370+
371+
360372
# #### cov & corrcoef
361373

362374

@@ -449,17 +461,29 @@ def indices(dimensions, dtype=int, sparse=False):
449461
return res
450462

451463

452-
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
464+
def bincount(x, /, weights=None, minlength=0):
465+
if x.numel() == 0:
466+
# edge case allowed by numpy
467+
x = x.new_empty(0, dtype=int)
468+
453469
int_dtype = _dtypes_impl.default_int_dtype
454-
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
470+
(x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
455471

456-
result = torch.bincount(x_tensor, weights_tensor, minlength)
472+
result = torch.bincount(x, weights, minlength)
457473
return result
458474

459475

460476
# ### linspace, geomspace, logspace and arange ###
461477

462478

479+
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
480+
if axis != 0 or retstep or not endpoint:
481+
raise NotImplementedError
482+
# XXX: raises TypeError if start or stop are not scalars
483+
result = torch.linspace(start, stop, num, dtype=dtype)
484+
return result
485+
486+
463487
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
464488
if axis != 0 or not endpoint:
465489
raise NotImplementedError
@@ -474,6 +498,13 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
474498
return result
475499

476500

501+
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
502+
if axis != 0 or not endpoint:
503+
raise NotImplementedError
504+
result = torch.logspace(start, stop, num, base=base, dtype=dtype)
505+
return result
506+
507+
477508
def arange(start=None, stop=None, step=1, dtype=None):
478509
if step == 0:
479510
raise ZeroDivisionError
@@ -523,36 +554,75 @@ def eye(N, M=None, k=0, dtype=float):
523554
return z
524555

525556

526-
def zeros_like(a, dtype=None, shape=None):
557+
def zeros(shape, dtype=None, order="C"):
558+
if order != "C":
559+
raise NotImplementedError
560+
if dtype is None:
561+
dtype = _dtypes_impl.default_float_dtype
562+
result = torch.zeros(shape, dtype=dtype)
563+
return result
564+
565+
566+
def zeros_like(a, dtype=None, shape=None, order="K"):
567+
if order != "K":
568+
raise NotImplementedError
527569
result = torch.zeros_like(a, dtype=dtype)
528570
if shape is not None:
529571
result = result.reshape(shape)
530572
return result
531573

532574

533-
def ones_like(a, dtype=None, shape=None):
575+
def ones(shape, dtype=None, order="C"):
576+
if order != "C":
577+
raise NotImplementedError
578+
if dtype is None:
579+
dtype = _dtypes_impl.default_float_dtype
580+
result = torch.ones(shape, dtype=dtype)
581+
return result
582+
583+
584+
def ones_like(a, dtype=None, shape=None, order="K"):
585+
if order != "K":
586+
raise NotImplementedError
534587
result = torch.ones_like(a, dtype=dtype)
535588
if shape is not None:
536589
result = result.reshape(shape)
537590
return result
538591

539592

540-
def full_like(a, fill_value, dtype=None, shape=None):
593+
def full_like(a, fill_value, dtype=None, shape=None, order="K"):
594+
if order != "K":
595+
raise NotImplementedError
541596
# XXX: fill_value broadcasts
542597
result = torch.full_like(a, fill_value, dtype=dtype)
543598
if shape is not None:
544599
result = result.reshape(shape)
545600
return result
546601

547602

548-
def empty_like(prototype, dtype=None, shape=None):
603+
def empty(shape, dtype=None, order="C"):
604+
if order != "C":
605+
raise NotImplementedError
606+
if dtype is None:
607+
dtype = _dtypes_impl.default_float_dtype
608+
result = torch.empty(shape, dtype=dtype)
609+
return result
610+
611+
612+
def empty_like(prototype, dtype=None, shape=None, order="K"):
613+
if order != "K":
614+
raise NotImplementedError
549615
result = torch.empty_like(prototype, dtype=dtype)
550616
if shape is not None:
551617
result = result.reshape(shape)
552618
return result
553619

554620

555-
def full(shape, fill_value, dtype=None):
621+
def full(shape, fill_value, dtype=None, order="C"):
622+
if isinstance(shape, int):
623+
shape = (shape,)
624+
if order != "C":
625+
raise NotImplementedError
556626
if dtype is None:
557627
dtype = fill_value.dtype
558628
if not isinstance(shape, (tuple, list)):

torch_np/_funcs.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ def nonzero(a: ArrayLike):
2323
return _helpers.tuple_arrays_from(result)
2424

2525

26-
def argwhere(a):
27-
(tensor,) = _helpers.to_tensors(a)
28-
result = torch.argwhere(tensor)
26+
@normalizer
27+
def argwhere(a: ArrayLike):
28+
result = torch.argwhere(a)
29+
return _helpers.array_from(result)
30+
31+
32+
@normalizer
33+
def flatnonzero(a: ArrayLike):
34+
result = a.ravel().nonzero(as_tuple=True)[0]
2935
return _helpers.array_from(result)
3036

3137

@@ -49,6 +55,12 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
4955
return _helpers.array_from(result)
5056

5157

58+
@normalizer
59+
def tile(A: ArrayLike, reps):
60+
result = _impl.tile(A, reps)
61+
return _helpers.array_from(result)
62+
63+
5264
# ### diag et al ###
5365

5466

@@ -448,8 +460,45 @@ def quantile(
448460
*,
449461
interpolation=None,
450462
):
451-
if interpolation is not None:
452-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
463+
result = _impl.quantile(
464+
a,
465+
q,
466+
axis,
467+
overwrite_input=overwrite_input,
468+
method=method,
469+
keepdims=keepdims,
470+
interpolation=interpolation,
471+
)
472+
return _helpers.result_or_out(result, out, promote_scalar=True)
453473

454-
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
474+
475+
@normalizer
476+
def percentile(
477+
a: ArrayLike,
478+
q: ArrayLike,
479+
axis: AxisLike = None,
480+
out: Optional[NDArray] = None,
481+
overwrite_input=False,
482+
method="linear",
483+
keepdims=False,
484+
*,
485+
interpolation=None,
486+
):
487+
result = _impl.percentile(
488+
a,
489+
q,
490+
axis,
491+
overwrite_input=overwrite_input,
492+
method=method,
493+
keepdims=keepdims,
494+
interpolation=interpolation,
495+
)
455496
return _helpers.result_or_out(result, out, promote_scalar=True)
497+
498+
499+
def median(
500+
a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False
501+
):
502+
return quantile(
503+
a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
504+
)

0 commit comments

Comments
 (0)