Skip to content

Commit 1a54972

Browse files
committed
Add 'custom_formats' argument to generated code
Previously `custom_formats` was assumed to be global, which works fine for evaluation, but results in errors when generating code. The approach used to fix this problem was to add a second argument to the `validate` functions when custom formats are provided. For execution, the function is wrapped with a `partial` that hides the extra argument (and therefore avoids changes in the API or documentation).
1 parent 1e21491 commit 1a54972

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

Diff for: fastjsonschema/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
API
7676
***
7777
"""
78+
from functools import partial, update_wrapper
7879

7980
from .draft04 import CodeGeneratorDraft04
8081
from .draft06 import CodeGeneratorDraft06
@@ -177,7 +178,10 @@ def compile(definition, handlers={}, formats={}, use_default=True):
177178
global_state = code_generator.global_state
178179
# Do not pass local state so it can recursively call itself.
179180
exec(code_generator.func_code, global_state)
180-
return global_state[resolver.get_scope_name()]
181+
func = global_state[resolver.get_scope_name()]
182+
if formats:
183+
return update_wrapper(partial(func, custom_formats=formats), func)
184+
return func
181185

182186

183187
# pylint: disable=dangerous-default-value

Diff for: fastjsonschema/generator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class CodeGenerator:
3232
def __init__(self, definition, resolver=None):
3333
self._code = []
3434
self._compile_regexps = {}
35+
self._custom_formats = {}
3536

3637
# Any extra library should be here to be imported only once.
3738
# Lines are imports to be printed in the file and objects
@@ -136,7 +137,8 @@ def generate_validation_function(self, uri, name):
136137
self._validation_functions_done.add(uri)
137138
self.l('')
138139
with self._resolver.resolving(uri) as definition:
139-
with self.l('def {}(data):', name):
140+
args = "data, custom_formats={}" if self._custom_formats else "data"
141+
with self.l('def {}({}):', name, args):
140142
self.generate_func_code_block(definition, 'data', 'data', clear_variables=True)
141143
self.l('return data')
142144

Diff for: tests/test_compile_to_code.py

+12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import shutil
44

5+
from fastjsonschema import JsonSchemaValueException
56
from fastjsonschema import compile_to_code, compile as compile_spec
67

78
@pytest.yield_fixture(autouse=True)
@@ -84,3 +85,14 @@ def test_compile_complex_one_of_all_of():
8485
}
8586
]
8687
})
88+
89+
90+
def test_compile_to_code_custom_format():
91+
formats = {'identifier': str.isidentifier}
92+
code = compile_to_code({'type': 'string', 'format': 'identifier'}, formats=formats)
93+
with open('temp/schema_3.py', 'w') as f:
94+
f.write(code)
95+
from temp.schema_3 import validate
96+
assert validate("identifier", formats) == "identifier"
97+
with pytest.raises(JsonSchemaValueException):
98+
validate("not-identifier", formats)

0 commit comments

Comments
 (0)