OpFromGraph difficult to subclass due to not allowing shared variables to be positioned explicitly as an input #473

Closed
ricardoV94 opened this issue Oct 10, 2023 · 3 comments · Fixed by #676

ricardoV94 commented Oct 10, 2023

Description

The changes introduced in aesara-devs/aesara#902 make it hard to safely subclass OpFromGraph with initialization kwargs.

Here is an example:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

class OpFromGraphSub(OpFromGraph):
    
    def __init__(self, *args, kwarg, **kwargs):
        self.kwarg = kwarg
        super().__init__(*args, **kwargs)


x = pytensor.shared(2)
y = pytensor.shared(3)

out = OpFromGraphSub([], [x + 1], kwarg="test")()
out.owner.op(y) 
# TypeError: OpFromGraphSub.__init__() missing 1 required keyword-only argument: 'kwarg'

One would have to override make_node for this to work correctly, but that's highly non-trivial.
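The failure mode does not depend on PyTensor internals. A minimal pure-Python sketch, using hypothetical `BaseOp` and `SubOp` classes as stand-ins for `OpFromGraph` and its subclass, reproduces it: the base class rebuilds itself through `type(self)(...)`, forwarding only the arguments it knows about, so the subclass's required keyword-only argument is lost.

```python
class BaseOp:
    """Hypothetical stand-in for OpFromGraph (not the PyTensor class)."""

    def __init__(self, inputs, outputs, **kwargs):
        self.inputs = list(inputs)
        self.outputs = list(outputs)
        # Only kwargs that reach the base class are remembered.
        self.kwargs = kwargs

    def rebuild(self, new_inputs):
        # Mirrors a `type(self)(..., **self.kwargs)` call: the subclass
        # constructor is invoked with base-class arguments only.
        return type(self)(inputs=new_inputs, outputs=self.outputs, **self.kwargs)


class SubOp(BaseOp):
    def __init__(self, *args, kwarg, **kwargs):
        self.kwarg = kwarg  # consumed here, never stored in self.kwargs
        super().__init__(*args, **kwargs)


op = SubOp(["x"], ["x + 1"], kwarg="test")

try:
    op.rebuild(["y"])
    error_message = None
except TypeError as exc:
    # The subclass __init__ complains about the missing keyword-only argument.
    error_message = str(exc)

print(error_message)
```

In this toy model the subclass could work around the problem by overriding `rebuild` to pass `kwarg=self.kwarg`; the equivalent in PyTensor would mean re-implementing `make_node`, which is the non-trivial part.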

@ricardoV94 ricardoV94 changed the title Do not recreate Op in OpFromGraph.make_node when shared inputs changes Do not recreate Op in OpFromGraph.make_node when shared inputs changes Oct 10, 2023
@ricardoV94
Member Author

PS: Shared variables are a mistake

@ricardoV94 ricardoV94 changed the title Do not recreate Op in OpFromGraph.make_node when shared inputs changes OpFromGraph difficult to subclass with initialization variables due to make_node Oct 10, 2023
ricardoV94 commented Dec 15, 2023

The problem here is that the initialization of OpFromGraph prohibits explicit shared variables as inputs, but internally it still builds an Apply that has them as explicit inputs (much like Scan does, although Scan allows shared variables as explicit inputs).

The change in aesara-devs/aesara#902 allows one to recreate an existing OpFromGraph node in which a shared variable is replaced, sometimes by a non-shared variable (desirable). However, this requires the Op itself to be recreated, because OpFromGraph forces shared variables to be the rightmost inputs of the respective Apply nodes (undesirable).

if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
    # The shared variables are not equal to the original shared
    # variables, so we construct a new `Op` that uses the new shared
    # variables instead.
    replace = dict(
        zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
    )
    # If the new shared variables are inconsistent with the inner-graph,
    # such errors should arise in this step
    new_inner_outputs = clone_replace(
        self.inner_outputs, replace=replace, copy_inputs_over=True
    )
    # It's possible that the new shared variable inputs aren't actually
    # shared variables. When they aren't we need to add them as new
    # inputs.
    unshared_inputs = [
        inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable)
    ]
    new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
    new_op = type(self)(
        inputs=new_inner_inputs,
        outputs=new_inner_outputs,
        inline=self.is_inline,
        lop_overrides=self.lop_overrides,
        grad_overrides=self.grad_overrides,
        rop_overrides=self.rop_overrides,
        connection_pattern=self._connection_pattern,
        name=self.name,
        **self.kwargs,
    )
    new_inputs = (
        list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs
    )

Imagine we had an OpFromGraph whose node inputs were [non_shared_1, shared_1, shared_2], and we wanted to replace shared_2 with a non-shared variable. The new signature would have to be artificially reordered to [non_shared_1, non_shared_2, shared_1], which requires rebuilding the internal fgraph.
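To make the reordering concrete, here is a toy sketch that models node inputs as `(name, is_shared)` pairs and applies the "shared variables go last" rule. The helper `replace_input` is hypothetical and only mimics the ordering constraint, not the actual `make_node` logic.

```python
def replace_input(node_inputs, old_name, new_input):
    """Replace one input, then re-apply the 'shared inputs go last' rule.

    `node_inputs` is a list of (name, is_shared) pairs; this is a toy model,
    not PyTensor's data structure.
    """
    replaced = [
        new_input if name == old_name else (name, is_shared)
        for name, is_shared in node_inputs
    ]
    non_shared = [inp for inp in replaced if not inp[1]]
    shared = [inp for inp in replaced if inp[1]]
    return non_shared + shared


old_signature = [("non_shared_1", False), ("shared_1", True), ("shared_2", True)]
new_signature = replace_input(old_signature, "shared_2", ("non_shared_2", False))

old_names = [name for name, _ in old_signature]
new_names = [name for name, _ in new_signature]
print(new_names)  # ['non_shared_1', 'non_shared_2', 'shared_1']
```

Note that `shared_1` moves from the second to the third position even though it was not the variable being replaced, so the old inner graph's input order no longer matches the new signature.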

@ricardoV94 ricardoV94 changed the title OpFromGraph difficult to subclass with initialization variables due to make_node OpFromGraph difficult to subclass due to constraint that shared variables can't be positioned explicitly as an input Dec 15, 2023
@ricardoV94 ricardoV94 changed the title OpFromGraph difficult to subclass due to constraint that shared variables can't be positioned explicitly as an input OpFromGraph difficult to subclass due to now allowing shared variables to be positioned explicitly as an input Dec 15, 2023
ricardoV94 commented Dec 15, 2023

We can simplify the restriction without breaking backwards compatibility: allow shared variables to be passed explicitly, in any input position, but still collect them automatically when the user doesn't pass them.

As in Scan, we could introduce a strict=True option, but I don't think we need it now that truncated_graph_inputs gives us a smarter way of finding which other variables are needed as inputs.

This means we don't need to handle SharedVariables at all inside make_node, and therefore don't need to recreate the Op when they change.
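A rough sketch of the proposed rule, again as a pure-Python toy model rather than PyTensor code: explicit inputs keep whatever position the caller gave them, and only the shared variables the caller did not pass are appended. The function name `resolve_inputs` and the string-based "variables" are assumptions made for illustration; in PyTensor the set of required variables would presumably come from something like truncated_graph_inputs.

```python
def resolve_inputs(explicit_inputs, shared_used_by_graph):
    """Keep explicit inputs in place; append only the shared variables
    that the inner graph needs but the caller did not pass."""
    missing = [s for s in shared_used_by_graph if s not in explicit_inputs]
    return list(explicit_inputs) + missing


# Shared variables may be passed explicitly, in any position.
explicit = resolve_inputs(["shared_1", "non_shared_1"], ["shared_1", "shared_2"])

# Backwards-compatible case: no shared variable passed, all are collected.
implicit = resolve_inputs(["non_shared_1"], ["shared_1", "shared_2"])

# Swapping an input is now purely positional and needs no reordering.
swapped = ["non_shared_2" if inp == "shared_2" else inp for inp in explicit]

print(explicit)
print(implicit)
print(swapped)
```

Because the signature no longer depends on which inputs happen to be shared, replacing `shared_2` leaves every other input where it was, and the same Op instance can be reused.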

@ricardoV94 ricardoV94 changed the title OpFromGraph difficult to subclass due to now allowing shared variables to be positioned explicitly as an input OpFromGraph difficult to subclass due to not allowing shared variables to be positioned explicitly as an input Mar 21, 2024