Skip to content

Use context managers instead of unrolled scopes #429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
7 changes: 5 additions & 2 deletions jsonschema/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pkgutil
import re

from jsonschema.compat import str_types, MutableMapping, urlsplit
from jsonschema.compat import str_types, MutableMapping
from uritools import urisplit, uriunsplit


class URIDict(MutableMapping):
Expand All @@ -13,7 +14,9 @@ class URIDict(MutableMapping):
"""

def normalize(self, uri):
return urlsplit(uri).geturl()
result = urisplit(uri)
return uriunsplit((result.scheme, result.authority,
result.path, None, None))

def __init__(self, *args, **kwargs):
self.store = dict()
Expand Down
17 changes: 3 additions & 14 deletions jsonschema/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,20 +301,9 @@ def enum(validator, enums, instance, schema):


def ref(validator, ref, instance, schema):
resolve = getattr(validator.resolver, "resolve", None)
if resolve is None:
with validator.resolver.resolving(ref) as resolved:
for error in validator.descend(instance, resolved):
yield error
else:
scope, resolved = validator.resolver.resolve(ref)
validator.resolver.push_scope(scope)

try:
for error in validator.descend(instance, resolved):
yield error
finally:
validator.resolver.pop_scope()
with validator.resolver.resolving(ref) as resolved:
for error in validator.descend(instance, resolved):
yield error


def type_draft3(validator, types, instance, schema):
Expand Down
23 changes: 2 additions & 21 deletions jsonschema/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from functools import lru_cache
from io import StringIO
from urllib.parse import (
unquote, urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit
unquote, urljoin, urlunsplit, SplitResult
)
from urllib.request import urlopen
str_types = str,
Expand All @@ -24,7 +24,7 @@
from itertools import izip as zip # noqa
from StringIO import StringIO
from urlparse import (
urljoin, urlunsplit, SplitResult, urlsplit as _urlsplit # noqa
urljoin, urlunsplit, SplitResult # noqa
)
from urllib import unquote # noqa
from urllib2 import urlopen # noqa
Expand All @@ -34,23 +34,4 @@

from functools32 import lru_cache


# On python < 3.3 fragments are not handled properly with unknown schemes
def urlsplit(url):
scheme, netloc, path, query, fragment = _urlsplit(url)
if "#" in path:
path, fragment = path.split("#", 1)
return SplitResult(scheme, netloc, path, query, fragment)


def urldefrag(url):
if "#" in url:
s, n, p, q, frag = urlsplit(url)
defrag = urlunsplit((s, n, p, q, ''))
else:
defrag = url
frag = ''
return defrag, frag


# flake8: noqa
51 changes: 45 additions & 6 deletions jsonschema/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,12 +870,11 @@ def test_it_delegates_to_a_ref_resolver(self):
resolver = validators.RefResolver("", {})
schema = {"$ref": mock.Mock()}

with mock.patch.object(resolver, "resolve") as resolve:
resolve.return_value = "url", {"type": "integer"}
with mock.patch.object(resolver, "resolving") as resolving:
resolving.return_value.__enter__.return_value = {"type": "integer"}
with self.assertRaises(ValidationError):
self.validator_class(schema, resolver=resolver).validate(None)

resolve.assert_called_once_with(schema["$ref"])
resolving.assert_called_once_with(schema["$ref"])

def test_it_delegates_to_a_legacy_ref_resolver(self):
"""
Expand All @@ -885,6 +884,10 @@ def test_it_delegates_to_a_legacy_ref_resolver(self):
"""

class LegacyRefResolver(object):
@contextmanager
def in_scope(self, scope):
yield

@contextmanager
def resolving(this, ref):
self.assertEqual(ref, "the ref")
Expand Down Expand Up @@ -1178,10 +1181,10 @@ def test_it_retrieves_unstored_refs_via_requests(self):
schema = {"baz": 12}

with MockImport("requests", mock.Mock()) as requests:
requests.get.return_value.json.return_value = schema
requests.Session.get.return_value.json.return_value = schema
with self.resolver.resolving(ref) as resolved:
self.assertEqual(resolved, 12)
requests.get.assert_called_once_with("http://bar")
requests.Session().get.assert_called_once_with("http://bar")

def test_it_retrieves_unstored_refs_via_urlopen(self):
ref = "http://bar#baz"
Expand All @@ -1195,6 +1198,42 @@ def test_it_retrieves_unstored_refs_via_urlopen(self):
self.assertEqual(resolved, 12)
urlopen.assert_called_once_with("http://bar")

def test_it_retrieves_unstored_file_refs_via_urlopen(self):
ref = "file://bar.json#baz"
schema = {"baz": 12}

with MockImport("requests", None):
with mock.patch("jsonschema.validators.urlopen") as urlopen:
urlopen.return_value.read.return_value = (
json.dumps(schema).encode("utf8"))
with self.resolver.resolving(ref) as resolved:
self.assertEqual(resolved, 12)
urlopen.assert_called_once_with("file://bar.json")

def test_it_retrieves_unstored_file_refs_via_requests_file(self):
ref = "file://bar.json#baz"
schema = {"baz": 12}

with MockImport("requests", mock.Mock()) as requests:
with MockImport("requests_file", mock.Mock()):
requests.Session.get.return_value.json.return_value = schema
with self.resolver.resolving(ref) as resolved:
self.assertEqual(resolved, 12)
requests.Session().get.assert_called_once_with("file://bar.json")

def test_it_retrieves_unstored_refs_via_urlopen_no_requests_file(self):
ref = "file://bar.json#baz"
schema = {"baz": 12}

with MockImport("requests", mock.Mock()) as requests:
with mock.patch("jsonschema.validators.urlopen") as urlopen:
urlopen.return_value.read.return_value = (
json.dumps(schema).encode("utf8"))
with self.resolver.resolving(ref) as resolved:
self.assertEqual(resolved, 12)
requests.Session().get.assert_not_called()
urlopen.assert_called_once_with("file://bar.json")

def test_it_can_construct_a_base_uri_from_a_schema(self):
schema = {"id": "foo"}
resolver = validators.RefResolver.from_schema(
Expand Down
109 changes: 59 additions & 50 deletions jsonschema/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import json
import numbers

from uritools import urisplit, uridefrag, urijoin
from six import add_metaclass

from jsonschema import _utils, _validators, _types
from jsonschema.compat import (
Sequence, urljoin, urlsplit, urldefrag, unquote, urlopen,
Sequence, unquote, urlopen,
str_types, int_types, iteritems, lru_cache,
)
from jsonschema.exceptions import (
Expand Down Expand Up @@ -254,9 +255,7 @@ def iter_errors(self, instance, _schema=None):
return

scope = id_of(_schema)
if scope:
self.resolver.push_scope(scope)
try:
with self.resolver.in_scope(scope):
ref = _schema.get(u"$ref")
if ref is not None:
validators = [(u"$ref", ref)]
Expand All @@ -280,9 +279,6 @@ def iter_errors(self, instance, _schema=None):
if k != u"$ref":
error.schema_path.appendleft(k)
yield error
finally:
if scope:
self.resolver.pop_scope()

def descend(self, instance, schema, path=None, schema_path=None):
for error in self.iter_errors(instance, schema):
Expand Down Expand Up @@ -522,15 +518,15 @@ class RefResolver(object):
A mapping from URI schemes to functions that should be used
to retrieve them

urljoin_cache (functools.lru_cache):
urijoin_cache (functools.lru_cache):

A cache that will be used for caching the results of joining
the resolution scope to subscopes.

remote_cache (functools.lru_cache):

A cache that will be used for caching the results of
resolved remote URLs.
resolved remote URIs.

Attributes:

Expand All @@ -547,17 +543,19 @@ def __init__(
store=(),
cache_remote=True,
handlers=(),
urljoin_cache=None,
urijoin_cache=None,
remote_cache=None,
):
if urljoin_cache is None:
urljoin_cache = lru_cache(1024)(urljoin)
if urijoin_cache is None:
urijoin_cache = lru_cache(1024)(urijoin)
if remote_cache is None:
remote_cache = lru_cache(1024)(self.resolve_from_url)
remote_cache = lru_cache(1024)(self.resolve_from_uri)

self.referrer = referrer
self.cache_remote = cache_remote
self.handlers = dict(handlers)
self.handlers = {'http': self.http_handler, 'https': self.http_handler,
'file': self.http_handler}
self.handlers.update(handlers)

self._scopes_stack = [base_uri]
self.store = _utils.URIDict(
Expand All @@ -567,7 +565,7 @@ def __init__(
self.store.update(store)
self.store[base_uri] = referrer

self._urljoin_cache = urljoin_cache
self._urijoin_cache = urijoin_cache
self._remote_cache = remote_cache

@classmethod
Expand Down Expand Up @@ -595,9 +593,38 @@ def from_schema(

return cls(base_uri=id_of(schema), referrer=schema, *args, **kwargs)

def http_handler(self, uri):
try:
import requests
except ImportError:
pass
else:
if hasattr(requests.Response, "json"):
session = requests.Session()

requests_supports_scheme = True
if urisplit(uri).scheme == "file":
try:
import requests_file
except ImportError:
requests_supports_scheme = False
else:
session.mount("file://", requests_file.FileAdapter())

if requests_supports_scheme:
# Requests has support for detecting the correct encoding of
# json over http
if callable(requests.Response.json):
return session.get(uri).json()
else:
return session.get(uri).json

# Otherwise, pass off to urllib and assume utf-8
return json.loads(urlopen(uri).read().decode("utf-8"))

def push_scope(self, scope):
self._scopes_stack.append(
self._urljoin_cache(self.resolution_scope, scope),
self._urijoin_cache(self.resolution_scope, scope),
)

def pop_scope(self):
Expand All @@ -616,16 +643,18 @@ def resolution_scope(self):

@property
def base_uri(self):
uri, _ = urldefrag(self.resolution_scope)
uri, _ = uridefrag(self.resolution_scope)
return uri

@contextlib.contextmanager
def in_scope(self, scope):
self.push_scope(scope)
if scope:
self.push_scope(scope)
try:
yield
finally:
self.pop_scope()
if scope:
self.pop_scope()

@contextlib.contextmanager
def resolving(self, ref):
Expand All @@ -641,28 +670,25 @@ def resolving(self, ref):

"""

url, resolved = self.resolve(ref)
self.push_scope(url)
try:
uri, resolved = self.resolve(ref)
with self.in_scope(uri):
yield resolved
finally:
self.pop_scope()

def resolve(self, ref):
url = self._urljoin_cache(self.resolution_scope, ref)
return url, self._remote_cache(url)
uri = self._urijoin_cache(self.resolution_scope, ref)
return uri, self._remote_cache(uri)

def resolve_from_url(self, url):
url, fragment = urldefrag(url)
def resolve_from_uri(self, uri):
uri, fragment = uridefrag(uri)
try:
document = self.store[url]
document = self.store[uri]
except KeyError:
try:
document = self.resolve_remote(url)
document = self.resolve_remote(uri)
except Exception as exc:
raise RefResolutionError(exc)

return self.resolve_fragment(document, fragment)
return self.resolve_fragment(document, fragment or '')

def resolve_fragment(self, document, fragment):
"""
Expand Down Expand Up @@ -731,29 +757,12 @@ def resolve_remote(self, uri):
.. _requests: http://pypi.python.org/pypi/requests/

"""
try:
import requests
except ImportError:
requests = None

scheme = urlsplit(uri).scheme
scheme = urisplit(uri).scheme

if scheme in self.handlers:
result = self.handlers[scheme](uri)
elif (
scheme in [u"http", u"https"] and
requests and
getattr(requests.Response, "json", None) is not None
):
# Requests has support for detecting the correct encoding of
# json over http
if callable(requests.Response.json):
result = requests.get(uri).json()
else:
result = requests.get(uri).json
else:
# Otherwise, pass off to urllib and assume utf-8
result = json.loads(urlopen(uri).read().decode("utf-8"))
raise ValueError(scheme)

if self.cache_remote:
self.store[uri] = result
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"pyrsistent>=0.14.0",
"six>=1.11.0",
"functools32;python_version<'3'",
"uritools>=2.2.0",
],
extras_require={
"format": [
Expand Down