Skip to content

Commit ecf720e

Browse files
committed
ENH: implement median
1 parent 856fec6 commit ecf720e

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,10 +602,6 @@ def may_share_memory(a, b, /, max_work=None):
602602
raise NotImplementedError
603603

604604

605-
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
606-
raise NotImplementedError
607-
608-
609605
def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
610606
raise NotImplementedError
611607

torch_np/_wrapper.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -708,16 +708,8 @@ def percentile(
708708
*,
709709
interpolation=None,
710710
):
711-
if interpolation is not None:
712-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
713-
714-
a_tensor, q_tensor = _helpers.to_tensors(a, q)
715-
result = _reductions.quantile(a_tensor, q_tensor / 100., axis, method)
716-
717-
# keepdims
718-
if keepdims:
719-
result = _util.apply_keepdims(result, axis, a_tensor.ndim)
720-
return _helpers.result_or_out(result, out, promote_scalar=True)
711+
return quantile(a, asarray(q) / 100., axis, out, overwrite_input, method,
712+
keepdims=keepdims)
721713

722714

723715
def quantile(
@@ -743,6 +735,9 @@ def quantile(
743735
return _helpers.result_or_out(result, out, promote_scalar=True)
744736

745737

738+
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
739+
return quantile(a, 0.5, axis=axis, overwrite_input=overwrite_input,
740+
out=out, keepdims=keepdims)
746741

747742

748743
@asarray_replacer()

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,7 +3474,7 @@ def test_linear_interpolation_formula_0d_inputs(self):
34743474
assert nfb._lerp(a, b, t) == 2.6
34753475

34763476

3477-
@pytest.mark.xfail(reason='TODO: implement')
3477+
#@pytest.mark.xfail(reason='TODO: implement')
34783478
class TestMedian:
34793479

34803480
def test_basic(self):
@@ -3494,7 +3494,11 @@ def test_basic(self):
34943494
assert_equal(a[0], np.median(a))
34953495
a = np.array([0.0444502, 0.141249, 0.0463301])
34963496
assert_equal(a[-1], np.median(a))
3497+
3498+
@pytest.mark.xfail(reason="median: scalar output vs 0-dim")
3499+
def test_basic_2(self):
34973500
# check array scalar result
3501+
a = np.array([0.0444502, 0.141249, 0.0463301])
34983502
assert_equal(np.median(a).ndim, 0)
34993503
a[1] = np.nan
35003504
assert_equal(np.median(a).ndim, 0)
@@ -3590,7 +3594,7 @@ def test_nan_behavior(self):
35903594

35913595
# no axis
35923596
assert_equal(np.median(a), np.nan)
3593-
assert_equal(np.median(a).ndim, 0)
3597+
# assert_equal(np.median(a).ndim, 0)
35943598

35953599
# axis0
35963600
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0)
@@ -3604,12 +3608,29 @@ def test_nan_behavior(self):
36043608
b[1, 2] = np.nan
36053609
assert_equal(np.median(a, 1), b)
36063610

3611+
3612+
@pytest.mark.xfail(reason="median: does not support tuple axes")
3613+
def test_nan_behavior_2(self):
3614+
a = np.arange(24, dtype=float).reshape(2, 3, 4)
3615+
a[1, 2, 3] = np.nan
3616+
a[1, 1, 2] = np.nan
3617+
36073618
# axis02
36083619
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), (0, 2))
36093620
b[1] = np.nan
36103621
b[2] = np.nan
36113622
assert_equal(np.median(a, (0, 2)), b)
36123623

3624+
@pytest.mark.xfail(reason="median: scalar vs 0-dim")
3625+
def test_nan_behavior_3(self):
3626+
a = np.arange(24, dtype=float).reshape(2, 3, 4)
3627+
a[1, 2, 3] = np.nan
3628+
a[1, 1, 2] = np.nan
3629+
3630+
# no axis
3631+
assert_equal(np.median(a).ndim, 0)
3632+
3633+
@pytest.mark.xfail(reason="median: torch.quantile does not handle empty tensors")
36133634
@pytest.mark.skipif(IS_WASM, reason="fp errors don't work correctly")
36143635
def test_empty(self):
36153636
# mean(empty array) emits two warnings: empty slice and divide by 0
@@ -3640,6 +3661,7 @@ def test_empty(self):
36403661
assert_equal(np.median(a, axis=2), b)
36413662
assert_(w[0].category is RuntimeWarning)
36423663

3664+
@pytest.mark.xfail(reason="median: tuple axes not implemented")
36433665
def test_extended_axis(self):
36443666
o = np.random.normal(size=(71, 23))
36453667
x = np.dstack([o] * 10)
@@ -3682,6 +3704,10 @@ def test_keepdims(self):
36823704
d = np.ones((3, 5, 7, 11))
36833705
assert_equal(np.median(d, axis=None, keepdims=True).shape,
36843706
(1, 1, 1, 1))
3707+
3708+
@pytest.mark.xfail(reason="median: tuple axis")
3709+
def test_keepdims_2(self):
3710+
d = np.ones((3, 5, 7, 11))
36853711
assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape,
36863712
(1, 1, 7, 11))
36873713
assert_equal(np.median(d, axis=(0, 3), keepdims=True).shape,
@@ -3693,6 +3719,7 @@ def test_keepdims(self):
36933719
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
36943720
(1, 1, 7, 1))
36953721

3722+
@pytest.mark.xfail(reason="median: tuple axis")
36963723
@pytest.mark.parametrize(
36973724
argnames='axis',
36983725
argvalues=[

0 commit comments

Comments
 (0)