Skip to content

Commit 92e80b4

Browse files
committed
MAINT: split remaining functions into normalizations and implementations
1 parent 77b52f0 commit 92e80b4

File tree

4 files changed

+178
-93
lines changed

4 files changed

+178
-93
lines changed

torch_np/_detail/_reductions.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,24 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
371371
return result, denominator
372372

373373

374-
def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
374+
def quantile(
375+
a_tensor,
376+
q_tensor,
377+
axis,
378+
overwrite_input,
379+
method,
380+
keepdims=False,
381+
interpolation=None,
382+
):
383+
if overwrite_input:
384+
# raise NotImplementedError("overwrite_input in quantile not implemented.")
385+
# NumPy documents that `overwrite_input` MAY modify inputs:
386+
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
387+
# Here we choose to work out-of-place because why not.
388+
pass
389+
390+
if interpolation is not None:
391+
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
375392

376393
if (0 > q_tensor).any() or (q_tensor > 1).any():
377394
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
@@ -404,3 +421,23 @@ def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
404421
if keepdims:
405422
result = _util.apply_keepdims(result, ax, ndim)
406423
return result
424+
425+
426+
def percentile(
427+
a_tensor,
428+
q_tensor,
429+
axis,
430+
overwrite_input,
431+
method,
432+
keepdims=False,
433+
interpolation=None,
434+
):
435+
return quantile(
436+
a_tensor,
437+
q_tensor / 100.0,
438+
axis=axis,
439+
overwrite_input=overwrite_input,
440+
method=method,
441+
keepdims=keepdims,
442+
interpolation=interpolation,
443+
)

torch_np/_detail/implementations.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
274274
if n < 0:
275275
raise ValueError(f"order must be non-negative but got {n}")
276276

277+
if n == 0:
278+
# match numpy and return the input immediately
279+
return a_tensor
280+
277281
if prepend_tensor is not None:
278282
shape = list(a_tensor.shape)
279283
shape[axis] = prepend_tensor.shape[axis] if prepend_tensor.ndim > 0 else 1
@@ -357,6 +361,14 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
357361
return result
358362

359363

364+
def tile(tensor, reps):
365+
if isinstance(reps, int):
366+
reps = (reps,)
367+
368+
result = torch.tile(tensor, reps)
369+
return result
370+
371+
360372
# #### cov & corrcoef
361373

362374

@@ -450,6 +462,10 @@ def indices(dimensions, dtype=int, sparse=False):
450462

451463

452464
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
465+
if x_tensor.numel() == 0:
466+
# edge case allowed by numpy
467+
x_tensor = torch.as_tensor([], dtype=int)
468+
453469
int_dtype = _dtypes_impl.default_int_dtype
454470
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
455471

@@ -460,6 +476,14 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
460476
# ### linspace, geomspace, logspace and arange ###
461477

462478

479+
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
480+
if axis != 0 or retstep or not endpoint:
481+
raise NotImplementedError
482+
# XXX: raises TypeError if start or stop are not scalars
483+
result = torch.linspace(start, stop, num, dtype=dtype)
484+
return result
485+
486+
463487
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
464488
if axis != 0 or not endpoint:
465489
raise NotImplementedError
@@ -474,6 +498,13 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
474498
return result
475499

476500

501+
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
502+
if axis != 0 or not endpoint:
503+
raise NotImplementedError
504+
result = torch.logspace(start, stop, num, base=base, dtype=dtype)
505+
return result
506+
507+
477508
def arange(start=None, stop=None, step=1, dtype=None):
478509
if step == 0:
479510
raise ZeroDivisionError
@@ -523,36 +554,75 @@ def eye(N, M=None, k=0, dtype=float):
523554
return z
524555

525556

526-
def zeros_like(a, dtype=None, shape=None):
557+
def zeros(shape, dtype=None, order="C"):
558+
if order != "C":
559+
raise NotImplementedError
560+
if dtype is None:
561+
dtype = _dtypes_impl.default_float_dtype
562+
result = torch.zeros(shape, dtype=dtype)
563+
return result
564+
565+
566+
def zeros_like(a, dtype=None, shape=None, order="K"):
567+
if order != "K":
568+
raise NotImplementedError
527569
result = torch.zeros_like(a, dtype=dtype)
528570
if shape is not None:
529571
result = result.reshape(shape)
530572
return result
531573

532574

533-
def ones_like(a, dtype=None, shape=None):
575+
def ones(shape, dtype=None, order="C"):
576+
if order != "C":
577+
raise NotImplementedError
578+
if dtype is None:
579+
dtype = _dtypes_impl.default_float_dtype
580+
result = torch.ones(shape, dtype=dtype)
581+
return result
582+
583+
584+
def ones_like(a, dtype=None, shape=None, order="K"):
585+
if order != "K":
586+
raise NotImplementedError
534587
result = torch.ones_like(a, dtype=dtype)
535588
if shape is not None:
536589
result = result.reshape(shape)
537590
return result
538591

539592

540-
def full_like(a, fill_value, dtype=None, shape=None):
593+
def full_like(a, fill_value, dtype=None, shape=None, order="K"):
594+
if order != "K":
595+
raise NotImplementedError
541596
# XXX: fill_value broadcasts
542597
result = torch.full_like(a, fill_value, dtype=dtype)
543598
if shape is not None:
544599
result = result.reshape(shape)
545600
return result
546601

547602

548-
def empty_like(prototype, dtype=None, shape=None):
603+
def empty(shape, dtype=None, order="C"):
604+
if order != "C":
605+
raise NotImplementedError
606+
if dtype is None:
607+
dtype = _dtypes_impl.default_float_dtype
608+
result = torch.empty(shape, dtype=dtype)
609+
return result
610+
611+
612+
def empty_like(prototype, dtype=None, shape=None, order="K"):
613+
if order != "K":
614+
raise NotImplementedError
549615
result = torch.empty_like(prototype, dtype=dtype)
550616
if shape is not None:
551617
result = result.reshape(shape)
552618
return result
553619

554620

555-
def full(shape, fill_value, dtype=None):
621+
def full(shape, fill_value, dtype=None, order="C"):
622+
if isinstance(shape, int):
623+
shape = (shape,)
624+
if order != "C":
625+
raise NotImplementedError
556626
if dtype is None:
557627
dtype = fill_value.dtype
558628
if not isinstance(shape, (tuple, list)):

torch_np/_funcs.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
5050
return _helpers.array_from(result)
5151

5252

53+
@normalizer
54+
def tile(A: ArrayLike, reps):
55+
result = _impl.tile(A, reps)
56+
return _helpers.array_from(result)
57+
58+
5359
# ### diag et al ###
5460

5561

@@ -449,8 +455,45 @@ def quantile(
449455
*,
450456
interpolation=None,
451457
):
452-
if interpolation is not None:
453-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
458+
result = _impl.quantile(
459+
a,
460+
q,
461+
axis,
462+
overwrite_input=overwrite_input,
463+
method=method,
464+
keepdims=keepdims,
465+
interpolation=interpolation,
466+
)
467+
return _helpers.result_or_out(result, out, promote_scalar=True)
454468

455-
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
469+
470+
@normalizer
471+
def percentile(
472+
a: ArrayLike,
473+
q: ArrayLike,
474+
axis: AxisLike = None,
475+
out: Optional[NDArray] = None,
476+
overwrite_input=False,
477+
method="linear",
478+
keepdims=False,
479+
*,
480+
interpolation=None,
481+
):
482+
result = _impl.percentile(
483+
a,
484+
q,
485+
axis,
486+
overwrite_input=overwrite_input,
487+
method=method,
488+
keepdims=keepdims,
489+
interpolation=interpolation,
490+
)
456491
return _helpers.result_or_out(result, out, promote_scalar=True)
492+
493+
494+
def median(
495+
a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False
496+
):
497+
return quantile(
498+
a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
499+
)

0 commit comments

Comments
 (0)