Skip to content

Commit 68f6c6f

Browse files
committed
ENH: add np.ptp and ndarray.ptp
1 parent ecf720e commit 68f6c6f

File tree

5 files changed

+11
-4
lines changed

5 files changed

+11
-4
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,6 @@ def product(*args, **kwargs):
774774
raise NotImplementedError
775775

776776

777-
def ptp(a, axis=None, out=None, keepdims=NoValue):
778-
raise NotImplementedError
779-
780777

781778
def put(a, ind, v, mode="raise"):
782779
raise NotImplementedError

torch_np/_detail/_reductions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ def min(tensor, axis=None, initial=NoValue, where=NoValue):
8989
return result
9090

9191

92+
def ptp(tensor, axis=None):
93+
result = tensor.amax(axis) - tensor.amin(axis)
94+
return result
95+
96+
9297
def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
9398
if initial is not NoValue or where is not NoValue:
9499
raise NotImplementedError

torch_np/_ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def nonzero(self):
282282
all = emulate_out_arg(axis_keepdims_wrapper(_reductions.all))
283283
max = emulate_out_arg(axis_keepdims_wrapper(_reductions.max))
284284
min = emulate_out_arg(axis_keepdims_wrapper(_reductions.min))
285+
ptp = emulate_out_arg(axis_keepdims_wrapper(_reductions.ptp))
285286

286287
sum = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.sum)))
287288
prod = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.prod)))

torch_np/_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,11 @@ def amin(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValu
604604
min = amin
605605

606606

607+
def ptp(a, axis=None, out=None, keepdims=NoValue):
608+
arr = asarray(a)
609+
return arr.ptp(axis=axis, out=out, keepdims=keepdims)
610+
611+
607612
def all(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue):
608613
arr = asarray(a)
609614
return arr.all(axis=axis, out=out, keepdims=keepdims, where=where)

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,6 @@ def test_basic(self):
593593
assert_equal(np.amin(b, axis=1), [3.0, 4.0, 2.0])
594594

595595

596-
@pytest.mark.xfail(reason='TODO: implement')
597596
class TestPtp:
598597

599598
def test_basic(self):

0 commit comments

Comments
 (0)