diff --git a/jsonschema/_utils.py b/jsonschema/_utils.py index 5b245a34a..4c114c99e 100644 --- a/jsonschema/_utils.py +++ b/jsonschema/_utils.py @@ -3,7 +3,8 @@ import pkgutil import re -from jsonschema.compat import str_types, MutableMapping, urlsplit +from jsonschema.compat import str_types, MutableMapping +from uritools import urisplit, uriunsplit class URIDict(MutableMapping): @@ -13,7 +14,9 @@ class URIDict(MutableMapping): """ def normalize(self, uri): - return urlsplit(uri).geturl() + result = urisplit(uri) + return uriunsplit((result.scheme, result.authority, + result.path, None, None)) def __init__(self, *args, **kwargs): self.store = dict() diff --git a/jsonschema/_validators.py b/jsonschema/_validators.py index 0d5b2b009..e222da69a 100644 --- a/jsonschema/_validators.py +++ b/jsonschema/_validators.py @@ -301,20 +301,9 @@ def enum(validator, enums, instance, schema): def ref(validator, ref, instance, schema): - resolve = getattr(validator.resolver, "resolve", None) - if resolve is None: - with validator.resolver.resolving(ref) as resolved: - for error in validator.descend(instance, resolved): - yield error - else: - scope, resolved = validator.resolver.resolve(ref) - validator.resolver.push_scope(scope) - - try: - for error in validator.descend(instance, resolved): - yield error - finally: - validator.resolver.pop_scope() + with validator.resolver.resolving(ref) as resolved: + for error in validator.descend(instance, resolved): + yield error def type_draft3(validator, types, instance, schema): diff --git a/jsonschema/compat.py b/jsonschema/compat.py index ff91fe620..6bf6b7d90 100644 --- a/jsonschema/compat.py +++ b/jsonschema/compat.py @@ -14,7 +14,7 @@ from functools import lru_cache from io import StringIO from urllib.parse import ( - unquote, urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit + unquote, urljoin, urlunsplit, SplitResult ) from urllib.request import urlopen str_types = str, @@ -24,7 +24,7 @@ from itertools import izip as zip # noqa from StringIO import StringIO from urlparse import ( - urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit # noqa + urljoin, urlunsplit, SplitResult # noqa ) from urllib import unquote # noqa from urllib2 import urlopen # noqa @@ -34,23 +34,4 @@ from functools32 import lru_cache - -# On python < 3.3 fragments are not handled properly with unknown schemes -def urlsplit(url): - scheme, netloc, path, query, fragment = _urlsplit(url) - if "#" in path: - path, fragment = path.split("#", 1) - return SplitResult(scheme, netloc, path, query, fragment) - - -def urldefrag(url): - if "#" in url: - s, n, p, q, frag = urlsplit(url) - defrag = urlunsplit((s, n, p, q, '')) - else: - defrag = url - frag = '' - return defrag, frag - - # flake8: noqa diff --git a/jsonschema/tests/test_validators.py b/jsonschema/tests/test_validators.py index f2f5a1f8b..a1e3a7bcf 100644 --- a/jsonschema/tests/test_validators.py +++ b/jsonschema/tests/test_validators.py @@ -870,12 +870,11 @@ def test_it_delegates_to_a_ref_resolver(self): resolver = validators.RefResolver("", {}) schema = {"$ref": mock.Mock()} - with mock.patch.object(resolver, "resolve") as resolve: - resolve.return_value = "url", {"type": "integer"} + with mock.patch.object(resolver, "resolving") as resolving: + resolving.return_value.__enter__.return_value = {"type": "integer"} with self.assertRaises(ValidationError): self.validator_class(schema, resolver=resolver).validate(None) - - resolve.assert_called_once_with(schema["$ref"]) + resolving.assert_called_once_with(schema["$ref"]) def test_it_delegates_to_a_legacy_ref_resolver(self): """ @@ -885,6 +884,10 @@ def test_it_delegates_to_a_legacy_ref_resolver(self): """ class LegacyRefResolver(object): + @contextmanager + def in_scope(self, scope): + yield + @contextmanager def resolving(this, ref): self.assertEqual(ref, "the ref") @@ -1178,10 +1181,10 @@ def test_it_retrieves_unstored_refs_via_requests(self): schema = {"baz": 12} with MockImport("requests", mock.Mock()) as requests: - requests.get.return_value.json.return_value = schema + requests.Session.get.return_value.json.return_value = schema with self.resolver.resolving(ref) as resolved: self.assertEqual(resolved, 12) - requests.get.assert_called_once_with("http://bar") + requests.Session().get.assert_called_once_with("http://bar") def test_it_retrieves_unstored_refs_via_urlopen(self): ref = "http://bar#baz" @@ -1195,6 +1198,42 @@ def test_it_retrieves_unstored_refs_via_urlopen(self): self.assertEqual(resolved, 12) urlopen.assert_called_once_with("http://bar") + def test_it_retrieves_unstored_file_refs_via_urlopen(self): + ref = "file://bar.json#baz" + schema = {"baz": 12} + + with MockImport("requests", None): + with mock.patch("jsonschema.validators.urlopen") as urlopen: + urlopen.return_value.read.return_value = ( + json.dumps(schema).encode("utf8")) + with self.resolver.resolving(ref) as resolved: + self.assertEqual(resolved, 12) + urlopen.assert_called_once_with("file://bar.json") + + def test_it_retrieves_unstored_file_refs_via_requests_file(self): + ref = "file://bar.json#baz" + schema = {"baz": 12} + + with MockImport("requests", mock.Mock()) as requests: + with MockImport("requests_file", mock.Mock()): + requests.Session.get.return_value.json.return_value = schema + with self.resolver.resolving(ref) as resolved: + self.assertEqual(resolved, 12) + requests.Session().get.assert_called_once_with("file://bar.json") + + def test_it_retrieves_unstored_refs_via_urlopen_no_requests_file(self): + ref = "file://bar.json#baz" + schema = {"baz": 12} + + with MockImport("requests", mock.Mock()) as requests: + with mock.patch("jsonschema.validators.urlopen") as urlopen: + urlopen.return_value.read.return_value = ( + json.dumps(schema).encode("utf8")) + with self.resolver.resolving(ref) as resolved: + self.assertEqual(resolved, 12) + requests.Session().get.assert_not_called() + urlopen.assert_called_once_with("file://bar.json") + def test_it_can_construct_a_base_uri_from_a_schema(self): schema = {"id": "foo"} resolver = validators.RefResolver.from_schema( diff --git a/jsonschema/validators.py b/jsonschema/validators.py index a47c3aefa..57b73271c 100644 --- a/jsonschema/validators.py +++ b/jsonschema/validators.py @@ -5,11 +5,12 @@ import json import numbers +from uritools import urisplit, uridefrag, urijoin from six import add_metaclass from jsonschema import _utils, _validators, _types from jsonschema.compat import ( - Sequence, urljoin, urlsplit, urldefrag, unquote, urlopen, + Sequence, unquote, urlopen, str_types, int_types, iteritems, lru_cache, ) from jsonschema.exceptions import ( @@ -254,9 +255,7 @@ def iter_errors(self, instance, _schema=None): return scope = id_of(_schema) - if scope: - self.resolver.push_scope(scope) - try: + with self.resolver.in_scope(scope): ref = _schema.get(u"$ref") if ref is not None: validators = [(u"$ref", ref)] @@ -280,9 +279,6 @@ def iter_errors(self, instance, _schema=None): if k != u"$ref": error.schema_path.appendleft(k) yield error - finally: - if scope: - self.resolver.pop_scope() def descend(self, instance, schema, path=None, schema_path=None): for error in self.iter_errors(instance, schema): @@ -522,7 +518,7 @@ class RefResolver(object): A mapping from URI schemes to functions that should be used to retrieve them - urljoin_cache (functools.lru_cache): + urijoin_cache (functools.lru_cache): A cache that will be used for caching the results of joining the resolution scope to subscopes. @@ -530,7 +526,7 @@ class RefResolver(object): remote_cache (functools.lru_cache): A cache that will be used for caching the results of - resolved remote URLs. + resolved remote URIs. Attributes: @@ -547,17 +543,19 @@ def __init__( store=(), cache_remote=True, handlers=(), - urljoin_cache=None, + urijoin_cache=None, remote_cache=None, ): - if urljoin_cache is None: - urljoin_cache = lru_cache(1024)(urljoin) + if urijoin_cache is None: + urijoin_cache = lru_cache(1024)(urijoin) if remote_cache is None: - remote_cache = lru_cache(1024)(self.resolve_from_url) + remote_cache = lru_cache(1024)(self.resolve_from_uri) self.referrer = referrer self.cache_remote = cache_remote - self.handlers = dict(handlers) + self.handlers = {'http': self.http_handler, 'https': self.http_handler, + 'file': self.http_handler} + self.handlers.update(handlers) self._scopes_stack = [base_uri] self.store = _utils.URIDict( @@ -567,7 +565,7 @@ def __init__( self.store.update(store) self.store[base_uri] = referrer - self._urljoin_cache = urljoin_cache + self._urijoin_cache = urijoin_cache self._remote_cache = remote_cache @classmethod @@ -595,9 +593,38 @@ def from_schema( return cls(base_uri=id_of(schema), referrer=schema, *args, **kwargs) + def http_handler(self, uri): + try: + import requests + except ImportError: + pass + else: + if hasattr(requests.Response, "json"): + session = requests.Session() + + requests_supports_scheme = True + if urisplit(uri).scheme == "file": + try: + import requests_file + except ImportError: + requests_supports_scheme = False + else: + session.mount("file://", requests_file.FileAdapter()) + + if requests_supports_scheme: + # Requests has support for detecting the correct encoding of + # json over http + if callable(requests.Response.json): + return session.get(uri).json() + else: + return session.get(uri).json + + # Otherwise, pass off to urllib and assume utf-8 + return json.loads(urlopen(uri).read().decode("utf-8")) + def push_scope(self, scope): self._scopes_stack.append( - self._urljoin_cache(self.resolution_scope, scope), + self._urijoin_cache(self.resolution_scope, scope), ) def pop_scope(self): @@ -616,16 +643,18 @@ def resolution_scope(self): @property def base_uri(self): - uri, _ = urldefrag(self.resolution_scope) + uri, _ = uridefrag(self.resolution_scope) return uri @contextlib.contextmanager def in_scope(self, scope): - self.push_scope(scope) + if scope: + self.push_scope(scope) try: yield finally: - self.pop_scope() + if scope: + self.pop_scope() @contextlib.contextmanager def resolving(self, ref): @@ -641,28 +670,25 @@ def resolving(self, ref): """ - url, resolved = self.resolve(ref) - self.push_scope(url) - try: + uri, resolved = self.resolve(ref) + with self.in_scope(uri): yield resolved - finally: - self.pop_scope() def resolve(self, ref): - url = self._urljoin_cache(self.resolution_scope, ref) - return url, self._remote_cache(url) + uri = self._urijoin_cache(self.resolution_scope, ref) + return uri, self._remote_cache(uri) - def resolve_from_url(self, url): - url, fragment = urldefrag(url) + def resolve_from_uri(self, uri): + uri, fragment = uridefrag(uri) try: - document = self.store[url] + document = self.store[uri] except KeyError: try: - document = self.resolve_remote(url) + document = self.resolve_remote(uri) except Exception as exc: raise RefResolutionError(exc) - return self.resolve_fragment(document, fragment) + return self.resolve_fragment(document, fragment or '') def resolve_fragment(self, document, fragment): """ @@ -731,29 +757,12 @@ def resolve_remote(self, uri): .. _requests: http://pypi.python.org/pypi/requests/ """ - try: - import requests - except ImportError: - requests = None - - scheme = urlsplit(uri).scheme + scheme = urisplit(uri).scheme if scheme in self.handlers: result = self.handlers[scheme](uri) - elif ( - scheme in [u"http", u"https"] and - requests and - getattr(requests.Response, "json", None) is not None - ): - # Requests has support for detecting the correct encoding of - # json over http - if callable(requests.Response.json): - result = requests.get(uri).json() - else: - result = requests.get(uri).json else: - # Otherwise, pass off to urllib and assume utf-8 - result = json.loads(urlopen(uri).read().decode("utf-8")) + raise ValueError(scheme) if self.cache_remote: self.store[uri] = result diff --git a/setup.py b/setup.py index 37cadea62..3127d3409 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "pyrsistent>=0.14.0", "six>=1.11.0", "functools32;python_version<'3'", + "uritools>=2.2.0", ], extras_require={ "format": [