diff --git a/jsonschema/_utils.py b/jsonschema/_utils.py index 5b245a34a..d9e5d311c 100644 --- a/jsonschema/_utils.py +++ b/jsonschema/_utils.py @@ -3,7 +3,7 @@ import pkgutil import re -from jsonschema.compat import str_types, MutableMapping, urlsplit +from jsonschema.compat import str_types, MutableMapping class URIDict(MutableMapping): @@ -13,7 +13,9 @@ class URIDict(MutableMapping): """ def normalize(self, uri): - return urlsplit(uri).geturl() + normalized = uri.normalize() + assert not normalized.fragment, "URI had unexpected non-empty fragment" + return normalized.copy_with(fragment=None) def __init__(self, *args, **kwargs): self.store = dict() diff --git a/jsonschema/compat.py b/jsonschema/compat.py index ff91fe620..87f8a63c2 100644 --- a/jsonschema/compat.py +++ b/jsonschema/compat.py @@ -13,19 +13,17 @@ zip = zip from functools import lru_cache from io import StringIO - from urllib.parse import ( - unquote, urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit - ) + from urllib.parse import unquote from urllib.request import urlopen str_types = str, int_types = int, iteritems = operator.methodcaller("items") + + import rfc3986 + rfc3986.URIReference.__hash__ = tuple.__hash__ else: from itertools import izip as zip # noqa from StringIO import StringIO - from urlparse import ( - urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit # noqa - ) from urllib import unquote # noqa from urllib2 import urlopen # noqa str_types = basestring @@ -35,22 +33,4 @@ from functools32 import lru_cache -# On python < 3.3 fragments are not handled properly with unknown schemes -def urlsplit(url): - scheme, netloc, path, query, fragment = _urlsplit(url) - if "#" in path: - path, fragment = path.split("#", 1) - return SplitResult(scheme, netloc, path, query, fragment) - - -def urldefrag(url): - if "#" in url: - s, n, p, q, frag = urlsplit(url) - defrag = urlunsplit((s, n, p, q, '')) - else: - defrag = url - frag = '' - return defrag, frag - - # flake8: noqa diff --git a/jsonschema/tests/test_validators.py b/jsonschema/tests/test_validators.py index 7e773776b..824efa02c 100644 --- a/jsonschema/tests/test_validators.py +++ b/jsonschema/tests/test_validators.py @@ -6,6 +6,7 @@ import sys import unittest +from rfc3986 import uri_reference from twisted.trial.unittest import SynchronousTestCase from jsonschema import ( @@ -78,16 +79,18 @@ def test_if_a_version_is_not_provided_it_is_not_registered(self): self.assertFalse(validates.called) def test_validates_registers_meta_schema_id(self): - meta_schema_key = "meta schema id" - my_meta_schema = {u"id": meta_schema_key} + my_meta_schema = {u"id": "meta schema id"} + + def id_of(schema): + return uri_reference(schema.get("id", "")) validators.create( meta_schema=my_meta_schema, version="my version", - id_of=lambda s: s.get("id", ""), + id_of=id_of, ) - self.assertIn(meta_schema_key, validators.meta_schemas) + self.assertIn(id_of(my_meta_schema), validators.meta_schemas) def test_validates_registers_meta_schema_draft6_id(self): meta_schema_key = "meta schema $id" @@ -98,7 +101,7 @@ def test_validates_registers_meta_schema_draft6_id(self): version="my version", ) - self.assertIn(meta_schema_key, validators.meta_schemas) + self.assertIn(uri_reference(meta_schema_key), validators.meta_schemas) def test_extend(self): original_validators = dict(self.Validator.VALIDATORS) @@ -1054,7 +1057,7 @@ def test_custom_validator(self): Validator = validators.create( meta_schema={"id": "meta schema id"}, version="12", - id_of=lambda s: s.get("id", ""), + id_of=lambda s: uri_reference(s.get("id", "")), ) schema = {"$schema": "meta schema id"} self.assertIs( @@ -1195,7 +1198,9 @@ def test_it_retrieves_stored_refs(self): with self.resolver.resolving(self.stored_uri) as resolved: self.assertIs(resolved, self.stored_schema) - self.resolver.store["cached_ref"] = {"foo": 12} + cached_uri = uri_reference("cached_ref").resolve_with( + self.resolver.base_uri) + self.resolver.store[cached_uri] = {"foo": 12} with self.resolver.resolving("cached_ref#/foo") as resolved: self.assertEqual(resolved, 12) @@ -1225,27 +1230,29 @@ def fake_urlopen(url): self.assertEqual(resolved, 12) def test_it_can_construct_a_base_uri_from_a_schema(self): - schema = {"id": "foo"} + schema = {"id": "http://foo.json#"} resolver = validators.RefResolver.from_schema( schema, id_of=lambda schema: schema.get(u"id", u""), ) - self.assertEqual(resolver.base_uri, "foo") - self.assertEqual(resolver.resolution_scope, "foo") + self.assertEqual(resolver.base_uri, "http://foo.json") + self.assertEqual(resolver.resolution_scope, "http://foo.json") with resolver.resolving("") as resolved: self.assertEqual(resolved, schema) with resolver.resolving("#") as resolved: self.assertEqual(resolved, schema) - with resolver.resolving("foo") as resolved: + with resolver.resolving("http://foo.json") as resolved: self.assertEqual(resolved, schema) - with resolver.resolving("foo#") as resolved: + with resolver.resolving("http://foo.json#") as resolved: self.assertEqual(resolved, schema) def test_it_can_construct_a_base_uri_from_a_schema_without_id(self): schema = {} resolver = validators.RefResolver.from_schema(schema) - self.assertEqual(resolver.base_uri, "") - self.assertEqual(resolver.resolution_scope, "") + self.assertEqual(resolver.base_uri, + validators.RefResolver.DEFAULT_BASE_URI) + self.assertEqual(resolver.resolution_scope, + validators.RefResolver.DEFAULT_BASE_URI) with resolver.resolving("") as resolved: self.assertEqual(resolved, schema) with resolver.resolving("#") as resolved: diff --git a/jsonschema/validators.py b/jsonschema/validators.py index 28f2f6f8f..2375aeea0 100644 --- a/jsonschema/validators.py +++ b/jsonschema/validators.py @@ -6,10 +6,11 @@ import numbers from six import add_metaclass +from rfc3986 import uri_reference, URIReference from jsonschema import _utils, _validators, _types from jsonschema.compat import ( - Sequence, urljoin, urlsplit, urldefrag, unquote, urlopen, + Sequence, unquote, urlopen, str_types, int_types, iteritems, lru_cache, ) from jsonschema.exceptions import ( @@ -108,10 +109,33 @@ def DEFAULT_TYPES(self): return self._DEFAULT_TYPES +def _as_uri(uri_or_str): + """ + Return URIReference parse result of input string, + or pass through URIReference argument + """ + if isinstance(uri_or_str, str_types): + return uri_reference(uri_or_str) + return uri_or_str + + +def _join_uri(base, ref): + """Join absolute base URI with relative URI reference""" + return _as_uri(ref).resolve_with(base, strict=True) + + +def _load_uri_from_schema(schema, key): + """ + Return URIReference object from URI given by key in schema. + Return URIReference.fromstring('') if key not found + """ + return uri_reference(schema.get(key, "")) + + def _id_of(schema): if schema is True or schema is False: - return u"" - return schema.get(u"$id", u"") + return uri_reference("") + return _load_uri_from_schema(schema, "$id") def create( @@ -256,7 +280,7 @@ def iter_errors(self, instance, _schema=None): return scope = id_of(_schema) - if scope: + if scope.unsplit(): self.resolver.push_scope(scope) try: ref = _schema.get(u"$ref") @@ -283,7 +307,7 @@ def iter_errors(self, instance, _schema=None): error.schema_path.appendleft(k) yield error finally: - if scope: + if scope.unsplit(): self.resolver.pop_scope() def descend(self, instance, schema, path=None, schema_path=None): @@ -416,7 +440,7 @@ def extend(validator, validators=(), version=None, type_checker=None): }, type_checker=_types.draft3_type_checker, version="draft3", - id_of=lambda schema: schema.get(u"id", ""), + id_of=lambda schema: _load_uri_from_schema(schema, u"id"), ) Draft4Validator = create( @@ -451,7 +475,7 @@ def extend(validator, validators=(), version=None, type_checker=None): }, type_checker=_types.draft4_type_checker, version="draft4", - id_of=lambda schema: schema.get(u"id", ""), + id_of=lambda schema: _load_uri_from_schema(schema, u"id"), ) @@ -542,6 +566,10 @@ class RefResolver(object): """ + DEFAULT_BASE_URI = uri_reference( + "urn:uuid:00000000-0000-0000-0000-000000000000" + ) + def __init__( self, base_uri, @@ -553,10 +581,21 @@ def __init__( remote_cache=None, ): if urljoin_cache is None: - urljoin_cache = lru_cache(1024)(urljoin) + urljoin_cache = lru_cache(1024)(_join_uri) if remote_cache is None: remote_cache = lru_cache(1024)(self.resolve_from_url) + if isinstance(base_uri, str_types): + base_uri = uri_reference(base_uri) + + if not base_uri.unsplit(): + base_uri = self.DEFAULT_BASE_URI + + if not base_uri.is_absolute(): + if base_uri.fragment: + raise ValueError("Base URI must not have non-empty fragment") + base_uri = base_uri.copy_with(fragment=None) + self.referrer = referrer self.cache_remote = cache_remote self.handlers = dict(handlers) @@ -566,7 +605,8 @@ def __init__( (id, validator.META_SCHEMA) for id, validator in iteritems(meta_schemas) ) - self.store.update(store) + + self.store.update({_as_uri(k): v for k, v in dict(store).items()}) self.store[base_uri] = referrer self._urljoin_cache = urljoin_cache @@ -599,7 +639,7 @@ def from_schema( def push_scope(self, scope): self._scopes_stack.append( - self._urljoin_cache(self.resolution_scope, scope), + self._urljoin_cache(self.base_uri, scope) ) def pop_scope(self): @@ -618,8 +658,7 @@ def resolution_scope(self): @property def base_uri(self): - uri, _ = urldefrag(self.resolution_scope) - return uri + return self.resolution_scope.copy_with(fragment=None) @contextlib.contextmanager def in_scope(self, scope): @@ -651,11 +690,17 @@ def resolving(self, ref): self.pop_scope() def resolve(self, ref): - url = self._urljoin_cache(self.resolution_scope, ref) + assert self.base_uri + url = self._urljoin_cache(self.base_uri, ref) return url, self._remote_cache(url) def resolve_from_url(self, url): - url, fragment = urldefrag(url) + if url.fragment: + fragment = url.fragment + url = url.copy_with(fragment=None) + else: + fragment = '' + try: document = self.store[url] except KeyError: @@ -722,7 +767,7 @@ def resolve_remote(self, uri): Arguments: - uri (str): + uri (URIReference): The URI to resolve @@ -738,8 +783,7 @@ def resolve_remote(self, uri): except ImportError: requests = None - scheme = urlsplit(uri).scheme - + scheme = uri.scheme if scheme in self.handlers: result = self.handlers[scheme](uri) elif ( @@ -750,12 +794,12 @@ def resolve_remote(self, uri): # Requests has support for detecting the correct encoding of # json over http if callable(requests.Response.json): - result = requests.get(uri).json() + result = requests.get(uri.unsplit()).json() else: - result = requests.get(uri).json + result = requests.get(uri.unsplit()).json else: # Otherwise, pass off to urllib and assume utf-8 - with urlopen(uri) as url: + with urlopen(uri.unsplit()) as url: result = json.loads(url.read().decode("utf-8")) if self.cache_remote: @@ -846,4 +890,4 @@ def validator_for(schema, default=_LATEST_VERSION): """ if schema is True or schema is False: return default - return meta_schemas.get(schema.get(u"$schema", u""), default) + return meta_schemas.get(_load_uri_from_schema(schema, u"$schema"), default) diff --git a/setup.py b/setup.py index 37cadea62..aaa742af4 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "pyrsistent>=0.14.0", "six>=1.11.0", "functools32;python_version<'3'", + "rfc3986>=1.1.0", ], extras_require={ "format": [