Skip to content

Commit b663377

Browse files
committed
Speed up test for hard cases
Fixes #3010.
1 parent 2f44209 commit b663377

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

hypothesis-python/tests/numpy/test_gen_data.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@
1515
import numpy as np
1616
import pytest
1717

18-
from hypothesis import HealthCheck, assume, given, note, settings, strategies as st
18+
from hypothesis import (
19+
HealthCheck,
20+
assume,
21+
given,
22+
note,
23+
settings,
24+
strategies as st,
25+
target,
26+
)
1927
from hypothesis.errors import InvalidArgument, UnsatisfiedAssumption
2028
from hypothesis.extra import numpy as nps
2129

@@ -1050,7 +1058,7 @@ def test_advanced_integer_index_minimizes_as_documented(
10501058
np.testing.assert_array_equal(s, d)
10511059

10521060

1053-
@settings(deadline=None, max_examples=10)
1061+
@settings(deadline=None, max_examples=25)
10541062
@given(
10551063
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=3),
10561064
data=st.data(),
@@ -1059,19 +1067,24 @@ def test_advanced_integer_index_can_generate_any_pattern(shape, data):
10591067
# ensures that generated index-arrays can be used to yield any pattern of elements from an array
10601068
x = np.arange(np.product(shape)).reshape(shape)
10611069

1062-
target = data.draw(
1070+
target_array = data.draw(
10631071
nps.arrays(
10641072
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=2),
10651073
elements=st.sampled_from(x.flatten()),
10661074
dtype=x.dtype,
10671075
),
10681076
label="target",
10691077
)
1078+
1079+
def index_selects_values_in_order(index):
1080+
selected = x[index]
1081+
target(len(set(selected.flatten())), label="unique indices")
1082+
target(float(np.sum(target_array == selected)), label="elements correct")
1083+
return np.all(target_array == selected)
1084+
10701085
find_any(
1071-
nps.integer_array_indices(
1072-
shape, result_shape=st.just(target.shape), dtype=np.dtype("int8")
1073-
),
1074-
lambda index: np.all(target == x[index]),
1086+
nps.integer_array_indices(shape, result_shape=st.just(target_array.shape)),
1087+
index_selects_values_in_order,
10751088
settings(max_examples=10**6),
10761089
)
10771090

0 commit comments

Comments
 (0)