diff --git a/jsonschema/tests/test_validators.py b/jsonschema/tests/test_validators.py index 02a584b4c..5e4004eb0 100644 --- a/jsonschema/tests/test_validators.py +++ b/jsonschema/tests/test_validators.py @@ -1622,6 +1622,36 @@ def handler(url): with resolver.resolving(ref): pass + def test_custom_cache_decorators(self): + response = [object(), object()] + + def handler(url): + return response.pop() + + def mock_cache_dec(f): + return f + + # We need cache_remote=False here, because there are two layers of + # caches: the resolver.store, which is disabled by cache_remote, and + # the default lru_cache on resolver.resolve_from_url, which we disable + # with mock_cache_dec. + + ref = "foo://bar" + resolver = validators.RefResolver( + "", {}, remote_cache_dec=mock_cache_dec, + urljoin_cache_dec=mock_cache_dec, + cache_remote=False, + handlers={"foo": handler}, + ) + with resolver.resolving(ref): + pass + with resolver.resolving(ref): + pass + + # Since there should be no caching, the handler must have been called + # twice, so the "response" list should be empty now. + self.assertEqual(len(response), 0) + def test_if_you_give_it_junk_you_get_a_resolution_error(self): error = ValueError("Oh no! What's this?") diff --git a/jsonschema/validators.py b/jsonschema/validators.py index 650b5d17b..051bbcac6 100644 --- a/jsonschema/validators.py +++ b/jsonschema/validators.py @@ -588,14 +588,14 @@ class RefResolver(object): A mapping from URI schemes to functions that should be used to retrieve them - urljoin_cache (functools.lru_cache): + urljoin_cache_dec (functools.lru_cache): - A cache that will be used for caching the results of joining - the resolution scope to subscopes. + A cache decorator that will be used for caching the results of + joining the resolution scope to subscopes. - remote_cache (functools.lru_cache): + remote_cache_dec (functools.lru_cache): - A cache that will be used for caching the results of + A cache decorator that will be used for caching the results of resolved remote URLs. Attributes: @@ -614,11 +614,18 @@ def __init__( handlers=(), urljoin_cache=None, remote_cache=None, + urljoin_cache_dec=None, + remote_cache_dec=None, ): + if urljoin_cache_dec is None: + urljoin_cache_dec = lru_cache(1024) + if remote_cache_dec is None: + remote_cache_dec = lru_cache(1024) + if urljoin_cache is None: - urljoin_cache = lru_cache(1024)(urljoin) + urljoin_cache = urljoin_cache_dec(urljoin) if remote_cache is None: - remote_cache = lru_cache(1024)(self.resolve_from_url) + remote_cache = remote_cache_dec(self.resolve_from_url) self.referrer = referrer self.cache_remote = cache_remote