Skip to content

Commit bc4bcb6

Browse files
committed
Distinct cache keys when api_version=None
1 parent 057d4de commit bc4bcb6

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

hypothesis-python/src/hypothesis/extra/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ def complex_dtypes(
10761076

10771077
namespace = StrategiesNamespace(**kwargs)
10781078
try:
1079-
_args_to_xps[(xp, api_version)] = namespace
1079+
_args_to_xps[(xp, None if inferred_version else api_version)] = namespace
10801080
except TypeError:
10811081
pass
10821082

hypothesis-python/tests/array_api/test_strategies_namespace.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,32 @@
2121
)
2222
from hypothesis.strategies import SearchStrategy
2323

24+
pytestmark = pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")
2425

25-
@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")
26-
def test_caching(xp, monkeypatch):
26+
27+
class HashableArrayModuleFactory:
28+
"""
29+
mock_xp cannot be hashed and thus cannot be used in our cache. So just for
30+
the purposes of testing the cache, we wrap it with an unsafe hash method.
31+
"""
32+
33+
def __getattr__(self, name):
34+
return getattr(mock_xp, name)
35+
36+
def __hash__(self):
37+
return hash(tuple(sorted(mock_xp.__dict__)))
38+
39+
40+
@pytest.mark.parametrize("api_version", ["2021.12", None])
41+
def test_caching(api_version, monkeypatch):
2742
"""Caches namespaces respective to arguments."""
28-
try:
29-
hash(xp)
30-
except TypeError:
31-
pytest.skip("xp not hashable")
32-
assert isinstance(array_api._args_to_xps, WeakValueDictionary)
43+
xp = HashableArrayModuleFactory()
44+
assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check
3345
monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary())
3446
assert len(array_api._args_to_xps) == 0 # sanity check
35-
xps1 = array_api.make_strategies_namespace(xp, api_version="2021.12")
47+
xps1 = array_api.make_strategies_namespace(xp, api_version=api_version)
3648
assert len(array_api._args_to_xps) == 1
37-
xps2 = array_api.make_strategies_namespace(xp, api_version="2021.12")
49+
xps2 = array_api.make_strategies_namespace(xp, api_version=api_version)
3850
assert len(array_api._args_to_xps) == 1
3951
assert isinstance(xps2, SimpleNamespace)
4052
assert xps2 is xps1
@@ -43,7 +55,28 @@ def test_caching(xp, monkeypatch):
4355
assert len(array_api._args_to_xps) == 0
4456

4557

46-
@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")
58+
@pytest.mark.parametrize(
59+
"api_version1, api_version2", [(None, "2021.12"), ("2021.12", None)]
60+
)
61+
def test_inferred_namespace_is_cached_seperately(
62+
api_version1, api_version2, monkeypatch
63+
):
64+
"""Results from inferred versions do not share the same cache key as results
65+
from specified versions."""
66+
xp = HashableArrayModuleFactory()
67+
xp.__array_api_version__ = "2021.12"
68+
assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check
69+
monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary())
70+
assert len(array_api._args_to_xps) == 0 # sanity check
71+
xps1 = array_api.make_strategies_namespace(xp, api_version=api_version1)
72+
assert xps1.api_version == "2021.12" # sanity check
73+
assert len(array_api._args_to_xps) == 1
74+
xps2 = array_api.make_strategies_namespace(xp, api_version=api_version2)
75+
assert xps2.api_version == "2021.12" # sanity check
76+
assert len(array_api._args_to_xps) == 2
77+
assert xps2 is not xps1
78+
79+
4780
def test_complex_dtypes_raises_on_2021_12():
4881
"""Accessing complex_dtypes() for 2021.12 strategy namespace raises helpful
4982
error, but accessing on future versions returns expected strategy."""

0 commit comments

Comments
 (0)