Skip to content

Commit 397cc98

Browse files
committed
Add in argument to validate/evaluate for whether to sanitize or not; also improves blacklist to only match proper dunders and not '__' and also better matching against unicode
1 parent bfb900f commit 397cc98

File tree

2 files changed

+58
-26
lines changed

2 files changed

+58
-26
lines changed

numexpr/necompiler.py

+34-18
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,22 @@ def __str__(self):
261261
return 'Immediate(%d)' % (self.node.value,)
262262

263263

264-
_forbidden_re = re.compile('[\;[\:]|__|\.[abcdefghjklmnopqstuvwxyzA-Z_]')
265-
def stringToExpression(s, types, context):
264+
_flow_pat = r'[\;\[\:]'
265+
_dunder_pat = r'__[\w]+__'
266+
_attr_pat = r'\.\b(?!(real|imag|\d+)\b)'
267+
_blacklist_re = re.compile(f'{_flow_pat}|{_dunder_pat}|{_attr_pat}')
268+
269+
def stringToExpression(s, types, context, sanitize: bool):
266270
"""Given a string, convert it to a tree of ExpressionNode's.
267271
"""
268272
# sanitize the string for obvious attack vectors that NumExpr cannot
269273
# parse into its homebrew AST. This is to protect the call to `eval` below.
270274
# We forbid `;`, `:`, `[` and `__`, and attribute access via '.'.
271275
# We cannot ban `.real` or `.imag` however...
272-
no_whitespace = re.sub(r'\s+', '', s)
273-
if _forbidden_re.search(no_whitespace) is not None:
274-
raise ValueError(f'Expression {s} has forbidden control characters.')
276+
if sanitize:
277+
no_whitespace = re.sub(r'\s+', '', s)
278+
if _blacklist_re.search(no_whitespace) is not None:
279+
raise ValueError(f'Expression {s} has forbidden control characters.')
275280

276281
old_ctx = expressions._context.get_current_context()
277282
try:
@@ -558,15 +563,15 @@ def getContext(kwargs, _frame_depth=1):
558563
return context
559564

560565

561-
def precompile(ex, signature=(), context={}):
566+
def precompile(ex, signature=(), context={}, sanitize: bool=True):
562567
"""
563568
Compile the expression to an intermediate form.
564569
"""
565570
types = dict(signature)
566571
input_order = [name for (name, type_) in signature]
567572

568573
if isinstance(ex, str):
569-
ex = stringToExpression(ex, types, context)
574+
ex = stringToExpression(ex, types, context, sanitize)
570575

571576
# the AST is like the expression, but the node objects don't have
572577
# any odd interpretations
@@ -612,7 +617,7 @@ def precompile(ex, signature=(), context={}):
612617
return threeAddrProgram, signature, tempsig, constants, input_names
613618

614619

615-
def NumExpr(ex, signature=(), **kwargs):
620+
def NumExpr(ex, signature=(), sanitize: bool=True, **kwargs):
616621
"""
617622
Compile an expression built using E.<variable> variables to a function.
618623
@@ -629,7 +634,7 @@ def NumExpr(ex, signature=(), **kwargs):
629634
# translated to either True or False).
630635
_frame_depth = 1
631636
context = getContext(kwargs, _frame_depth=_frame_depth)
632-
threeAddrProgram, inputsig, tempsig, constants, input_names = precompile(ex, signature, context)
637+
threeAddrProgram, inputsig, tempsig, constants, input_names = precompile(ex, signature, context, sanitize=sanitize)
633638
program = compileThreeAddrForm(threeAddrProgram)
634639
return interpreter.NumExpr(inputsig.encode('ascii'),
635640
tempsig.encode('ascii'),
@@ -710,8 +715,8 @@ def getType(a):
710715
raise ValueError("unknown type %s" % a.dtype.name)
711716

712717

713-
def getExprNames(text, context):
714-
ex = stringToExpression(text, {}, context)
718+
def getExprNames(text, context, sanitize: bool=True):
719+
ex = stringToExpression(text, {}, context, sanitize)
715720
ast = expressionToAST(ex)
716721
input_order = getInputOrder(ast, None)
717722
#try to figure out if vml operations are used by expression
@@ -779,6 +784,7 @@ def validate(ex: str,
779784
order: str = 'K',
780785
casting: str = 'safe',
781786
_frame_depth: int = 2,
787+
sanitize: bool = True,
782788
**kwargs) -> Optional[Exception]:
783789
"""
784790
Validate a NumExpr expression with the given `local_dict` or `locals()`.
@@ -826,16 +832,19 @@ def validate(ex: str,
826832
like float64 to float32, are allowed.
827833
* 'unsafe' means any data conversions may be done.
828834
835+
sanitize: bool
836+
Both `validate` and by extension `evaluate` call `eval(ex)`, which is
837+
potentially dangerous on unsanitized inputs. As such, NumExpr by default
838+
performs simple sanitization, banning the characters ':;[', the
839+
dunder pattern '__[\w]+__', and attribute access to all but '.real' and '.imag'.
840+
829841
_frame_depth: int
830842
The calling frame depth. Unless you are a NumExpr developer you should
831843
not set this value.
832844
833845
Note
834846
----
835-
Both `validate` and by extension `evaluate` call `eval(ex)`, which is
836-
potentially dangerous on unsanitized inputs. As such, NumExpr does some
837-
sanitization, banning the character ':;[', the dunder '__', and attribute
838-
access to all but '.r' for real and '.i' for imag access to complex numbers.
847+
839848
"""
840849
global _numexpr_last
841850

@@ -848,7 +857,7 @@ def validate(ex: str,
848857
context = getContext(kwargs)
849858
expr_key = (ex, tuple(sorted(context.items())))
850859
if expr_key not in _names_cache:
851-
_names_cache[expr_key] = getExprNames(ex, context)
860+
_names_cache[expr_key] = getExprNames(ex, context, sanitize=sanitize)
852861
names, ex_uses_vml = _names_cache[expr_key]
853862
arguments = getArguments(names, local_dict, global_dict, _frame_depth=_frame_depth)
854863

@@ -861,7 +870,7 @@ def validate(ex: str,
861870
try:
862871
compiled_ex = _numexpr_cache[numexpr_key]
863872
except KeyError:
864-
compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, **context)
873+
compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
865874
kwargs = {'out': out, 'order': order, 'casting': casting,
866875
'ex_uses_vml': ex_uses_vml}
867876
_numexpr_last = dict(ex=compiled_ex, argnames=names, kwargs=kwargs)
@@ -875,6 +884,7 @@ def evaluate(ex: str,
875884
out: numpy.ndarray = None,
876885
order: str = 'K',
877886
casting: str = 'safe',
887+
sanitize: bool = True,
878888
_frame_depth: int = 3,
879889
**kwargs) -> numpy.ndarray:
880890
"""
@@ -920,6 +930,12 @@ def evaluate(ex: str,
920930
like float64 to float32, are allowed.
921931
* 'unsafe' means any data conversions may be done.
922932
933+
sanitize: bool
934+
Both `validate` and by extension `evaluate` call `eval(ex)`, which is
935+
potentially dangerous on unsanitized inputs. As such, NumExpr by default
936+
performs simple sanitization, banning the characters ':;[', the
937+
dunder pattern '__[\w]+__', and attribute access to all but '.real' and '.imag'.
938+
923939
_frame_depth: int
924940
The calling frame depth. Unless you are a NumExpr developer you should
925941
not set this value.
@@ -936,7 +952,7 @@ def evaluate(ex: str,
936952
# `getArguments`
937953
e = validate(ex, local_dict=local_dict, global_dict=global_dict,
938954
out=out, order=order, casting=casting,
939-
_frame_depth=_frame_depth, **kwargs)
955+
_frame_depth=_frame_depth, sanitize=sanitize, **kwargs)
940956
if e is None:
941957
return re_evaluate(local_dict=local_dict, _frame_depth=_frame_depth)
942958
else:

numexpr/tests/test_numexpr.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def test_forbidden_tokens(self):
528528

529529
# Forbid indexing
530530
try:
531-
evaluate('locals()[]')
531+
evaluate('locals()["evaluate"]')
532532
except ValueError:
533533
pass
534534
else:
@@ -542,24 +542,40 @@ def test_forbidden_tokens(self):
542542
else:
543543
self.fail()
544544

545-
# Attribute access
545+
# Attribute access with spaces
546546
try:
547-
evaluate('os.cpucount()')
547+
evaluate('os. cpu_count()')
548548
except ValueError:
549549
pass
550550
else:
551551
self.fail()
552552

553-
# But decimal point must pass
553+
# Attribute access with funny unicode characters that eval translates
554+
# into ASCII.
555+
try:
556+
evaluate("(3+1).ᵇit_length()")
557+
except ValueError:
558+
pass
559+
else:
560+
self.fail()
561+
562+
# Pass decimal points
554563
a = 3.0
555564
evaluate('a*2.')
556565
evaluate('2.+a')
557-
558566

567+
# pass .real and .imag
568+
c = 2.5 + 1.5j
569+
evaluate('c.real')
570+
evaluate('c.imag')
559571

560-
561-
562-
572+
def test_no_sanitize(self):
573+
try: # Errors on compile() after eval()
574+
evaluate('import os;', sanitize=False)
575+
except SyntaxError:
576+
pass
577+
else:
578+
self.fail()
563579

564580
def test_disassemble(self):
565581
assert_equal(disassemble(NumExpr(

0 commit comments

Comments
 (0)