Skip to content

Commit 856fec6

Browse files
committed
ENH: implement quantile
1 parent 4071246 commit 856fec6

File tree

4 files changed

+36
-23
lines changed

4 files changed

+36
-23
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -794,20 +794,6 @@ def putmask(a, mask, values):
794794
raise NotImplementedError
795795

796796

797-
def quantile(
798-
a,
799-
q,
800-
axis=None,
801-
out=None,
802-
overwrite_input=False,
803-
method="linear",
804-
keepdims=False,
805-
*,
806-
interpolation=None,
807-
):
808-
raise NotImplementedError
809-
810-
811797
def ravel(a, order="C"):
812798
raise NotImplementedError
813799

torch_np/_detail/_reductions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,10 @@ def average(a_tensor, axis, w_tensor):
228228

229229

230230

231-
def percentile(a_tensor, q_tensor, axis, method):
231+
def quantile(a_tensor, q_tensor, axis, method):
232232

233-
if (0 > q_tensor).any() or (q_tensor > 100).any():
234-
raise ValueError("Percentiles must be in range [0, 100], got %s" % q_tensor)
233+
if (0 > q_tensor).any() or (q_tensor > 1).any():
234+
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
235235

236236
if not a_tensor.dtype.is_floating_point:
237237
dtype = _scalar_types.default_float_type.torch_dtype
@@ -246,7 +246,7 @@ def percentile(a_tensor, q_tensor, axis, method):
246246
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
247247
axis = _util.allow_only_single_axis(axis)
248248

249-
q_tensor = (q_tensor / 100.0).to(a_tensor.dtype)
249+
q_tensor = q_tensor.to(a_tensor.dtype)
250250

251251
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
252252

torch_np/_wrapper.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,16 +712,39 @@ def percentile(
712712
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
713713

714714
a_tensor, q_tensor = _helpers.to_tensors(a, q)
715-
716-
result = _reductions.percentile(a_tensor, q_tensor, axis, method)
715+
result = _reductions.quantile(a_tensor, q_tensor / 100., axis, method)
717716

718717
# keepdims
719718
if keepdims:
720719
result = _util.apply_keepdims(result, axis, a_tensor.ndim)
720+
return _helpers.result_or_out(result, out, promote_scalar=True)
721+
721722

723+
def quantile(
724+
a,
725+
q,
726+
axis=None,
727+
out=None,
728+
overwrite_input=False,
729+
method="linear",
730+
keepdims=False,
731+
*,
732+
interpolation=None,
733+
):
734+
if interpolation is not None:
735+
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
736+
737+
a_tensor, q_tensor = _helpers.to_tensors(a, q)
738+
result = _reductions.quantile(a_tensor, q_tensor, axis, method)
739+
740+
# keepdims
741+
if keepdims:
742+
result = _util.apply_keepdims(result, axis, a_tensor.ndim)
722743
return _helpers.result_or_out(result, out, promote_scalar=True)
723744

724745

746+
747+
725748
@asarray_replacer()
726749
def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
727750
if where is not NoValue:

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,10 +3300,11 @@ def test_nan_q(self):
33003300
np.percentile([1, 2, 3, 4.0], q)
33013301

33023302

3303-
@pytest.mark.xfail(reason='TODO: implement')
3303+
33043304
class TestQuantile:
33053305
# most of this is already tested by TestPercentile
33063306

3307+
@pytest.mark.skip(reason="do not chase 1ulp")
33073308
def test_max_ulp(self):
33083309
x = [0.0, 0.2, 0.4]
33093310
a = np.quantile(x, 0.45)
@@ -3318,6 +3319,7 @@ def test_basic(self):
33183319
assert_equal(np.quantile(x, 1), 3.5)
33193320
assert_equal(np.quantile(x, 0.5), 1.75)
33203321

3322+
@pytest.mark.xfail(reason="quantile w/integers or bools")
33213323
def test_correct_quantile_value(self):
33223324
a = np.array([True])
33233325
tf_quant = np.quantile(True, False)
@@ -3328,6 +3330,7 @@ def test_correct_quantile_value(self):
33283330
assert_array_equal(quant_res, a)
33293331
assert_equal(quant_res.dtype, a.dtype)
33303332

3333+
@pytest.mark.skip(reason="support arrays of Fractions?")
33313334
def test_fraction(self):
33323335
# fractional input, integral quantile
33333336
x = [Fraction(i, 2) for i in range(8)]
@@ -3355,10 +3358,9 @@ def test_fraction(self):
33553358
x = np.arange(8)
33563359
assert_equal(np.quantile(x, Fraction(1, 2)), Fraction(7, 2))
33573360

3361+
@pytest.mark.skip(reason="does not raise in numpy?")
33583362
def test_complex(self):
33593363
#See gh-22652
3360-
arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='G')
3361-
assert_raises(TypeError, np.quantile, arr_c, 0.5)
33623364
arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='D')
33633365
assert_raises(TypeError, np.quantile, arr_c, 0.5)
33643366
arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='F')
@@ -3376,12 +3378,14 @@ def test_no_p_overwrite(self):
33763378
np.quantile(np.arange(100.), p, method="midpoint")
33773379
assert_array_equal(p, p0)
33783380

3381+
@pytest.mark.xfail(reason="TODO: make quantile preserve integers")
33793382
@pytest.mark.parametrize("dtype", np.typecodes["AllInteger"])
33803383
def test_quantile_preserve_int_type(self, dtype):
33813384
res = np.quantile(np.array([1, 2], dtype=dtype), [0.5],
33823385
method="nearest")
33833386
assert res.dtype == dtype
33843387

3388+
@pytest.mark.xfail(reason="1) np.sort not implemented; 2) methods")
33853389
@pytest.mark.parametrize("method",
33863390
['inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
33873391
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear',

0 commit comments

Comments
 (0)