21
21
)
22
22
from hypothesis .strategies import SearchStrategy
23
23
24
+ pytestmark = pytest .mark .filterwarnings ("ignore::hypothesis.errors.HypothesisWarning" )
24
25
25
- @pytest .mark .filterwarnings ("ignore::hypothesis.errors.HypothesisWarning" )
26
- def test_caching (xp , monkeypatch ):
26
+
27
+ class HashableArrayModuleFactory :
28
+ """
29
+ mock_xp cannot be hashed and thus cannot be used in our cache. So just for
30
+ the purposes of testing the cache, we wrap it with an unsafe hash method.
31
+ """
32
+
33
+ def __getattr__ (self , name ):
34
+ return getattr (mock_xp , name )
35
+
36
+ def __hash__ (self ):
37
+ return hash (tuple (sorted (mock_xp .__dict__ )))
38
+
39
+
40
+ @pytest .mark .parametrize ("api_version" , ["2021.12" , None ])
41
+ def test_caching (api_version , monkeypatch ):
27
42
"""Caches namespaces respective to arguments."""
28
- try :
29
- hash (xp )
30
- except TypeError :
31
- pytest .skip ("xp not hashable" )
32
- assert isinstance (array_api ._args_to_xps , WeakValueDictionary )
43
+ xp = HashableArrayModuleFactory ()
44
+ assert isinstance (array_api ._args_to_xps , WeakValueDictionary ) # sanity check
33
45
monkeypatch .setattr (array_api , "_args_to_xps" , WeakValueDictionary ())
34
46
assert len (array_api ._args_to_xps ) == 0 # sanity check
35
- xps1 = array_api .make_strategies_namespace (xp , api_version = "2021.12" )
47
+ xps1 = array_api .make_strategies_namespace (xp , api_version = api_version )
36
48
assert len (array_api ._args_to_xps ) == 1
37
- xps2 = array_api .make_strategies_namespace (xp , api_version = "2021.12" )
49
+ xps2 = array_api .make_strategies_namespace (xp , api_version = api_version )
38
50
assert len (array_api ._args_to_xps ) == 1
39
51
assert isinstance (xps2 , SimpleNamespace )
40
52
assert xps2 is xps1
@@ -43,7 +55,28 @@ def test_caching(xp, monkeypatch):
43
55
assert len (array_api ._args_to_xps ) == 0
44
56
45
57
46
- @pytest .mark .filterwarnings ("ignore::hypothesis.errors.HypothesisWarning" )
58
+ @pytest .mark .parametrize (
59
+ "api_version1, api_version2" , [(None , "2021.12" ), ("2021.12" , None )]
60
+ )
61
+ def test_inferred_namespace_is_cached_seperately (
62
+ api_version1 , api_version2 , monkeypatch
63
+ ):
64
+ """Results from inferred versions do not share the same cache key as results
65
+ from specified versions."""
66
+ xp = HashableArrayModuleFactory ()
67
+ xp .__array_api_version__ = "2021.12"
68
+ assert isinstance (array_api ._args_to_xps , WeakValueDictionary ) # sanity check
69
+ monkeypatch .setattr (array_api , "_args_to_xps" , WeakValueDictionary ())
70
+ assert len (array_api ._args_to_xps ) == 0 # sanity check
71
+ xps1 = array_api .make_strategies_namespace (xp , api_version = api_version1 )
72
+ assert xps1 .api_version == "2021.12" # sanity check
73
+ assert len (array_api ._args_to_xps ) == 1
74
+ xps2 = array_api .make_strategies_namespace (xp , api_version = api_version2 )
75
+ assert xps2 .api_version == "2021.12" # sanity check
76
+ assert len (array_api ._args_to_xps ) == 2
77
+ assert xps2 is not xps1
78
+
79
+
47
80
def test_complex_dtypes_raises_on_2021_12 ():
48
81
"""Accessing complex_dtypes() for 2021.12 strategy namespace raises helpful
49
82
error, but accessing on future versions returns expected strategy."""
0 commit comments