diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 929cd3887f..1e3d161e59 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
     dict[Variable, Variable],
 ]:
     """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
-    dummy_inputs = []
-    for n, inp in enumerate(inputs):
-        if (
-            not isinstance(inp, Variable)
-            or isinstance(inp, Constant)
-            or isinstance(inp, SharedVariable)
-        ):
-            raise TypeError(
-                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
-            )
-
-        dummy_inputs.append(inp.type())
+    implicit_shared_inputs = []
 
-    dummy_shared_inputs = []
-    shared_inputs = []
+    dummy_inputs = [inp.type() for inp in inputs]
+    dummy_implicit_shared_inputs = []
     for var in graph_inputs(outputs, inputs):
+        if var in inputs:
+            continue
         if isinstance(var, SharedVariable):
-            # To correctly support shared variables the inner-graph should
-            # not see them; otherwise, there will be problems with
-            # gradients.
-            # That's why we collect the shared variables and replace them
-            # with dummies.
-            shared_inputs.append(var)
-            dummy_shared_inputs.append(var.type())
-        elif var not in inputs and not isinstance(var, Constant):
-            raise MissingInputError(f"OpFromGraph is missing an input: {var}")
-
-    replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
+            # We allow shared inputs to be added automatically to the graph
+            implicit_shared_inputs.append(var)
+            dummy_implicit_shared_inputs.append(var.type())
+        elif not isinstance(var, Constant):
+            raise MissingInputError(f"NominalGraph is missing an input: {var}")
+
+    replacements = dict(
+        zip(
+            inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+        )
+    )
 
     new = rebuild_collect_shared(
         cast(Sequence[Variable], outputs),
-        inputs=inputs + shared_inputs,
+        inputs=inputs + implicit_shared_inputs,
         replace=replacements,
         copy_inputs_over=False,
     )
@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
         (clone_d, update_d, update_expr, new_shared_inputs),
     ) = new
 
-    assert len(local_inputs) == len(inputs) + len(shared_inputs)
+    assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
     assert len(local_outputs) == len(outputs)
     assert not update_d
     assert not update_expr
@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
         fgraph.clients.pop(inp, None)
         fgraph.add_input(nom_inp)
 
-    return fgraph, shared_inputs, update_d, update_expr
+    return fgraph, implicit_shared_inputs, update_d, update_expr
 
 
 class OpFromGraph(Op, HasInnerGraph):
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
     - grad() make it support DisconnectedType and the new interface
    - add support for NullType and DisconnectedType when R_op supports them
    - check how it works with updates.
-    - add test with constant as input or inside the inner graph.
-    - Add support for the GPU? Probably just need an opt to remove transfer
    - Add support to pickle this Op.
    - Add support/test with random generator
    - Add optimization to removing unused inputs/outputs
@@ -310,11 +299,13 @@ def __init__(
         self,
         inputs: list[Variable],
         outputs: list[Variable],
+        *,
         inline: bool = False,
         lop_overrides: str = "default",
         grad_overrides: str = "default",
         rop_overrides: str = "default",
         connection_pattern: Optional[list[list[bool]]] = None,
+        strict: bool = False,
         name: Optional[str] = None,
         **kwargs,
     ):
@@ -399,6 +390,10 @@ def __init__(
             must be equal to number of outputs.
        connection_pattern
            If not ``None``, this will be used as the connection_pattern for this :class:`Op`.
+        strict: bool, default False
+            If ``True``, an error is raised when any variables needed to compute the
+            inner graph are not provided as explicit inputs. This can only happen
+            for graphs with shared variables.
        name
            A name for debugging purposes.
        kwargs
@@ -424,6 +419,12 @@ def __init__(
            inputs, outputs
        )
 
+        if strict and self.shared_inputs:
+            raise ValueError(
+                "All variables needed to compute the inner-graph must be provided as inputs under strict=True. "
+                f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
+            )
+
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py
index abc85fe524..8b15ad4dda 100644
--- a/tests/compile/test_builders.py
+++ b/tests/compile/test_builders.py
@@ -15,7 +15,7 @@
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.graph.utils import MissingInputError
 from pytensor.printing import debugprint
-from pytensor.tensor.basic import as_tensor
+from pytensor.tensor.basic import constant
 from pytensor.tensor.math import dot, exp, sigmoid
 from pytensor.tensor.math import round as pt_round
 from pytensor.tensor.math import sum as pt_sum
@@ -43,12 +43,6 @@ def test_valid_input(self):
         with pytest.raises(TypeError):
             OpFromGraph([1], [1])
 
-        with pytest.raises(TypeError):
-            OpFromGraph([x, as_tensor(1)], [x])
-
-        with pytest.raises(TypeError):
-            OpFromGraph([shared(1)], [1])
-
         with pytest.raises(NotImplementedError):
             OpFromGraph([x], [x], updates={})
 
@@ -559,6 +553,31 @@ def test_outputs_consistency(self):
         # The original `op.fgraph` outputs should stay the same, though
         assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
 
+    def test_explicit_input_from_constant(self):
+        x = pt.dscalar("x")
+        y = constant(1.0, name="y")
+        test_ofg = OpFromGraph([x, y], [x + y])
+
+        out = test_ofg(x, y)
+        assert out.eval({x: 5}) == 6
+
+    def test_explicit_input_from_shared(self):
+        x = pt.dscalar("x")
+        y = shared(1.0, name="y")
+
+        with pytest.raises(
+            ValueError,
+            match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
+        ):
+            OpFromGraph([x], [x + y], strict=True)
+
+        test_ofg = OpFromGraph([x, y], [x + y], strict=True)
+
+        out = test_ofg(x, y)
+        assert out.eval({x: 5}) == 6
+        y.set_value(2.0)
+        assert out.eval({x: 6}) == 8
+
 
 @config.change_flags(floatX="float64")
 def test_debugprint():
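
For reference, a minimal usage sketch of the behaviour this patch introduces. This assumes a pytensor build that includes the change; the variable names below are illustrative and not part of the patch itself:

```python
import pytensor.tensor as pt
from pytensor import shared
from pytensor.compile.builders import OpFromGraph

x = pt.dscalar("x")
y = shared(1.0, name="y")

# Without strict=True, the shared variable `y` is captured as an implicit input.
implicit_ofg = OpFromGraph([x], [x + y])
assert implicit_ofg(x).eval({x: 5.0}) == 6.0

# With strict=True, every variable the inner graph needs must be passed
# explicitly, so the implicit dependency on `y` raises a ValueError.
try:
    OpFromGraph([x], [x + y], strict=True)
except ValueError as exc:
    print(exc)

# Passing the shared variable as an explicit input satisfies strict mode,
# and updates to its value are still reflected in the output.
explicit_ofg = OpFromGraph([x, y], [x + y], strict=True)
out = explicit_ofg(x, y)
assert out.eval({x: 5.0}) == 6.0
y.set_value(2.0)
assert out.eval({x: 6.0}) == 8.0
```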