Skip to content

Commit 5fea60f

Browse files
committed
MAINT: avoid synchronization/graph break in percentile
any(...) is bad for torch dynamo, so remove the check. The same check is done internally in torch.quantile, so we're only paying by a RuntimeError instead of NumPy's ValueError.
1 parent f94edc5 commit 5fea60f

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

torch_np/_detail/_reductions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,6 @@ def quantile(
390390
if interpolation is not None:
391391
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
392392

393-
if (0 > q).any() or (q > 1).any():
394-
raise ValueError("Quantiles must be in range [0, 1], got %s" % q)
395-
396393
if not a.dtype.is_floating_point:
397394
dtype = _dtypes_impl.default_float_dtype
398395
a = a.to(dtype)

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2900,10 +2900,10 @@ def test_scalar_q_2(self):
29002900
def test_exception(self):
29012901
assert_raises((RuntimeError, ValueError), np.percentile, [1, 2], 56,
29022902
method='foobar')
2903-
assert_raises(ValueError, np.percentile, [1], 101)
2904-
assert_raises(ValueError, np.percentile, [1], -1)
2905-
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [101])
2906-
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [-0.1])
2903+
assert_raises((RuntimeError, ValueError), np.percentile, [1], 101)
2904+
assert_raises((RuntimeError, ValueError), np.percentile, [1], -1)
2905+
assert_raises((RuntimeError, ValueError), np.percentile, [1], list(range(50)) + [101])
2906+
assert_raises((RuntimeError, ValueError), np.percentile, [1], list(range(50)) + [-0.1])
29072907

29082908
def test_percentile_list(self):
29092909
assert_equal(np.percentile([1, 2, 3], 0), 1)

0 commit comments

Comments
 (0)