Skip to content

Commit 5b82a40

Browse files
authored
Allow passing trust_input to function (pymc-devs#1206)
1 parent 69efc68 commit 5b82a40

File tree

5 files changed

+67
-4
lines changed

5 files changed

+67
-4
lines changed

pytensor/compile/debugmode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,6 +1966,12 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
19661966
If the outputs argument for pytensor.function was a list, then
19671967
output_keys is None. If the outputs argument was a dict, then
19681968
output_keys is a sorted list of the keys from that dict.
1969+
trust_input : bool, default False
1970+
If True, no input validation checks are performed when the function is
1971+
called. This includes checking the number of inputs, their types and
1972+
that multiple inputs are not aliased to each other. Failure to meet any
1973+
of these conditions can lead to computational errors or to the
1974+
interpreter crashing.
19691975
19701976
Notes
19711977
-----
@@ -1993,6 +1999,7 @@ def __init__(
19931999
output_keys=None,
19942000
name=None,
19952001
no_fgraph_prep=False,
2002+
trust_input=False,
19962003
):
19972004
self.mode = mode
19982005
self.profile = profile
@@ -2146,6 +2153,7 @@ def __init__(
21462153
self.on_unused_input = on_unused_input # Used for the pickling/copy
21472154
self.output_keys = output_keys
21482155
self.name = name
2156+
self.trust_input = trust_input
21492157

21502158
self.required = [(i.value is None) for i in self.inputs]
21512159
self.refeed = [

pytensor/compile/function/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def function_dump(
3737
profile: bool | ProfileStats | None = None,
3838
on_unused_input: str | None = None,
3939
extra_tag_to_remove: str | None = None,
40+
trust_input: bool = False,
4041
):
4142
"""
4243
This is helpful to make a reproducible case for problems during PyTensor
@@ -82,6 +83,7 @@ def function_dump(
8283
"allow_input_downcast": allow_input_downcast,
8384
"profile": profile,
8485
"on_unused_input": on_unused_input,
86+
"trust_input": trust_input,
8587
}
8688
with Path(filename).open("wb") as f:
8789
pickler = pytensor.misc.pkl_utils.StripPickler(
@@ -107,6 +109,7 @@ def function(
107109
allow_input_downcast: bool | None = None,
108110
profile: bool | ProfileStats | None = None,
109111
on_unused_input: str | None = None,
112+
trust_input: bool = False,
110113
):
111114
"""
112115
Return a :class:`callable object <pytensor.compile.function.types.Function>`
@@ -164,6 +167,12 @@ def function(
164167
on_unused_input
165168
What to do if a variable in the 'inputs' list is not used in the graph.
166169
Possible values are 'raise', 'warn', 'ignore' and None.
170+
trust_input: bool, default False
171+
If True, no input validation checks are performed when the function is
172+
called. This includes checking the number of inputs, their types and
173+
that multiple inputs are not aliased to each other. Failure to meet any
174+
of these conditions can lead to computational errors or to the
175+
interpreter crashing.
167176
168177
Returns
169178
-------
@@ -310,7 +319,12 @@ def opt_log1p(node):
310319
"semantics, which disallow using updates and givens"
311320
)
312321
fn = orig_function(
313-
inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
322+
inputs,
323+
outputs,
324+
mode=mode,
325+
accept_inplace=accept_inplace,
326+
name=name,
327+
trust_input=trust_input,
314328
)
315329
else:
316330
# note: pfunc will also call orig_function -- orig_function is
@@ -329,5 +343,6 @@ def opt_log1p(node):
329343
on_unused_input=on_unused_input,
330344
profile=profile,
331345
output_keys=output_keys,
346+
trust_input=trust_input,
332347
)
333348
return fn

pytensor/compile/function/pfunc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def pfunc(
377377
on_unused_input=None,
378378
output_keys=None,
379379
fgraph: FunctionGraph | None = None,
380+
trust_input: bool = False,
380381
) -> Function:
381382
"""
382383
Function-constructor for graphs with shared variables.
@@ -425,6 +426,12 @@ def pfunc(
425426
fgraph
426427
An existing `FunctionGraph` from which to construct the returned
427428
`Function`. When this is non-``None``, nothing is cloned.
429+
trust_input : bool, default False
430+
If True, no input validation checks are performed when the function is
431+
called. This includes checking the number of inputs, their types and
432+
that multiple inputs are not aliased to each other. Failure to meet any
433+
of these conditions can lead to computational errors or to the
434+
interpreter crashing.
428435
429436
Returns
430437
-------
@@ -472,6 +479,7 @@ def pfunc(
472479
on_unused_input=on_unused_input,
473480
output_keys=output_keys,
474481
fgraph=fgraph,
482+
trust_input=trust_input,
475483
)
476484

477485

pytensor/compile/function/types.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def __init__(
373373
return_none: bool,
374374
output_keys,
375375
maker: "FunctionMaker",
376+
trust_input: bool = False,
376377
name: str | None = None,
377378
):
378379
"""
@@ -407,6 +408,12 @@ def __init__(
407408
TODO
408409
maker
409410
The `FunctionMaker` that created this instance.
411+
trust_input : bool, default False
412+
If True, no input validation checks are performed when the function is
413+
called. This includes checking the number of inputs, their types and
414+
that multiple inputs are not aliased to each other. Failure to meet any
415+
of these conditions can lead to computational errors or to the
416+
interpreter crashing.
410417
name
411418
A string name.
412419
"""
@@ -420,7 +427,7 @@ def __init__(
420427
self.return_none = return_none
421428
self.maker = maker
422429
self.profile = None # reassigned in FunctionMaker.create
423-
self.trust_input = False # If True, we don't check the input parameter
430+
self.trust_input = trust_input # If True, we don't check the input parameter
424431
self.name = name
425432
self.nodes_with_inner_function = []
426433
self.output_keys = output_keys
@@ -1341,7 +1348,12 @@ class FunctionMaker:
13411348
name : str
13421349
An optional name for this function. If used, the profile mode will
13431350
print the time spent in this function.
1344-
1351+
trust_input : bool, default False
1352+
If True, no input validation checks are performed when the function is
1353+
called. This includes checking the number of inputs, their types and
1354+
that multiple inputs are not aliased to each other. Failure to meet any
1355+
of these conditions can lead to computational errors or to the
1356+
interpreter crashing.
13451357
"""
13461358

13471359
@staticmethod
@@ -1507,6 +1519,7 @@ def __init__(
15071519
output_keys=None,
15081520
name=None,
15091521
no_fgraph_prep=False,
1522+
trust_input=False,
15101523
):
15111524
# Save the provided mode, not the instantiated mode.
15121525
# The instantiated mode don't pickle and if we unpickle an PyTensor
@@ -1609,6 +1622,7 @@ def __init__(
16091622
self.on_unused_input = on_unused_input # Used for the pickling/copy
16101623
self.output_keys = output_keys
16111624
self.name = name
1625+
self.trust_input = trust_input
16121626

16131627
self.required = [(i.value is None) for i in self.inputs]
16141628
self.refeed = [
@@ -1726,6 +1740,7 @@ def create(self, input_storage=None, storage_map=None):
17261740
self.return_none,
17271741
self.output_keys,
17281742
self,
1743+
trust_input=self.trust_input,
17291744
name=self.name,
17301745
)
17311746

@@ -1743,6 +1758,7 @@ def orig_function(
17431758
on_unused_input=None,
17441759
output_keys=None,
17451760
fgraph: FunctionGraph | None = None,
1761+
trust_input: bool = False,
17461762
) -> Function:
17471763
"""
17481764
Return a Function that will calculate the outputs from the inputs.
@@ -1773,7 +1789,12 @@ def orig_function(
17731789
fgraph
17741790
An existing `FunctionGraph` to use instead of constructing a new one
17751791
from cloned `outputs`.
1776-
1792+
trust_input : bool, default False
1793+
If True, no input validation checks are performed when the function is
1794+
called. This includes checking the number of inputs, their types and
1795+
that multiple inputs are not aliased to each other. Failure to meet any
1796+
of these conditions can lead to computational errors or to the
1797+
interpreter crashing.
17771798
"""
17781799

17791800
if profile:
@@ -1806,6 +1827,7 @@ def orig_function(
18061827
output_keys=output_keys,
18071828
name=name,
18081829
fgraph=fgraph,
1830+
trust_input=trust_input,
18091831
)
18101832
with config.change_flags(compute_test_value="off"):
18111833
fn = m.create(defaults)

tests/compile/function/test_function.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ def test_function_name():
5454
assert regex.match(func.name) is not None
5555

5656

57+
def test_trust_input():
58+
x = dvector()
59+
y = shared(1)
60+
z = x + y
61+
f = function([x], z)
62+
assert f.trust_input is False
63+
f = function([x], z, trust_input=True)
64+
assert f.trust_input is True
65+
66+
5767
class TestFunctionIn:
5868
def test_in_strict(self):
5969
a = dvector()

0 commit comments

Comments
 (0)