
Implement OpFromGraph __eq__ and __hash__ #1114


Open
ricardoV94 opened this issue Dec 9, 2024 · 3 comments
Labels: help wanted (Extra attention is needed), OpFromGraph

Comments

@ricardoV94
Member

ricardoV94 commented Dec 9, 2024

Description

Implementing these would allow duplicated OpFromGraph nodes to be merged, and graphs containing OpFromGraph nodes to be compared for equality.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations

x = pt.scalar("x")
out1 = OpFromGraph([x], [x + 1])(x)
out2 = OpFromGraph([x], [x + 1])(x)

assert equal_computations([out1], [out2])

The assert should pass, but it currently fails because out1.owner.op == out2.owner.op evaluates to False. We can probably do something very similar to Scan:

From pytensor/pytensor/scan/op.py, lines 1254 to 1320 at commit 4b41e09:

def __eq__(self, other):
    if type(self) is not type(other):
        return False
    if self.info != other.info:
        return False
    if self.profile != other.profile:
        return False
    if self.truncate_gradient != other.truncate_gradient:
        return False
    if self.name != other.name:
        return False
    if self.allow_gc != other.allow_gc:
        return False
    # Compare inner graphs
    # TODO: Use `self.inner_fgraph == other.inner_fgraph`
    if len(self.inner_inputs) != len(other.inner_inputs):
        return False
    if len(self.inner_outputs) != len(other.inner_outputs):
        return False
    # strict=False because length already compared above
    for self_in, other_in in zip(
        self.inner_inputs, other.inner_inputs, strict=False
    ):
        if self_in.type != other_in.type:
            return False
    return equal_computations(
        self.inner_outputs,
        other.inner_outputs,
        self.inner_inputs,
        other.inner_inputs,
    )

def __str__(self):
    inplace = "none"
    if self.destroy_map:
        # Check if all outputs are inplace
        if sorted(self.destroy_map) == sorted(
            range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
        ):
            inplace = "all"
        else:
            inplace = str(list(self.destroy_map))
    return (
        f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
    )

def __hash__(self):
    return hash(
        (
            type(self),
            self._hash_inner_graph,
            self.info,
            self.profile,
            self.truncate_gradient,
            self.name,
            self.allow_gc,
        )
    )
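A minimal sketch of how the Scan-style checks could carry over, using a toy, stdlib-only graph representation rather than PyTensor itself. `Var`, `Slot`, `canonical` and `InnerGraphOp` are hypothetical names for illustration, not PyTensor API. Instead of calling `equal_computations` pairwise as Scan's `__eq__` does, this sketch precomputes a canonical form of the inner graph in which each dummy input is replaced by its position (similar in spirit to Scan's precomputed `_hash_inner_graph`), so `__hash__` is built from exactly the same data as `__eq__`.

```python
from dataclasses import dataclass


class Var:
    """Hypothetical graph variable: compared by identity, the name is a label."""

    def __init__(self, name, type_="float64"):
        self.name = name
        self.type = type_


@dataclass(frozen=True)
class Slot:
    """Positional stand-in for the i-th inner (dummy) input."""

    index: int


def canonical(expr, positions):
    """Replace inner inputs by their position so dummy identities drop out.

    Expressions are Vars, numeric constants, or tuples (op_name, *args).
    """
    if isinstance(expr, Var):
        # Variables that are not inner inputs keep their identity.
        return Slot(positions[expr]) if expr in positions else expr
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(canonical(arg, positions) for arg in expr[1:])
    return expr  # numeric constant


class InnerGraphOp:
    """Hypothetical stand-in for OpFromGraph, mirroring the Scan checks."""

    def __init__(self, inputs, outputs, name=None):
        self.inner_inputs = list(inputs)
        self.inner_outputs = list(outputs)
        self.name = name
        positions = {var: i for i, var in enumerate(self.inner_inputs)}
        # Computed once: input types plus position-based output graphs.
        self._inner_key = (
            tuple(var.type for var in self.inner_inputs),
            tuple(canonical(out, positions) for out in self.inner_outputs),
        )

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        if self.name != other.name:
            return False
        # Covers input count/types and structural equality of the outputs.
        return self._inner_key == other._inner_key

    def __hash__(self):
        # Built from the same fields as __eq__, so equal ops hash equally.
        return hash((type(self), self.name, self._inner_key))


x, y = Var("x"), Var("y")
op1 = InnerGraphOp([x], [("add", x, 1)])
op2 = InnerGraphOp([y], [("add", y, 1)])
assert op1 == op2 and hash(op1) == hash(op2)
assert op1 != InnerGraphOp([x], [("mul", x, 2)])
```

Because the key depends only on input positions and types, wrapping `x + 1` and `y + 1` gives ops that compare and hash equal, which is the property a merge rewrite needs in order to deduplicate them.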

ricardoV94 added the labels help wanted (Extra attention is needed) and OpFromGraph on Dec 9, 2024
ricardoV94 changed the title from "Implement OpFromGraph.__eq__" to "Implement OpFromGraph __eq__ and __hash__" on Dec 9, 2024
@zaxtax
Contributor

zaxtax commented Jan 3, 2025

Do we want to support alpha equivalence, where the following two are considered equal?

x = pt.scalar("x")
out1 = OpFromGraph([x], [x + 1])(x)
y = pt.scalar("y")
out2 = OpFromGraph([y], [y + 1])(y)

assert equal_computations([out1], [out2])

@ricardoV94
Member Author

ricardoV94 commented Jan 3, 2025

Those Ops are equivalent, but the nodes are not, since one takes x as input and the other takes y. So out2.owner.op == out1.owner.op should be True, but equal_computations should return False. However, the following should pass the assert:

x = pt.scalar("x")
out1 = OpFromGraph([x], [x + 1])(x)
y = pt.scalar("y")
out2 = OpFromGraph([y], [y + 1])(x)

assert equal_computations([out1], [out2])

The dummy variables used to define the inner graph are irrelevant.

equal_computations will do the correct thing, since it checks equivalence of whole nodes: the same root inputs plus equivalent intermediate nodes (which includes both their inputs and their Ops).
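The distinction between Op equality and node equality can be illustrated with a toy, stdlib-only model. All names here (`Var`, `Slot`, `canonical`, `InnerGraphOp`, `Node`, `merge`) are hypothetical and not PyTensor API, and the node check is a single-level comparison rather than the recursive traversal `equal_computations` performs. Op equality ignores which dummy variables built the inner graph; node equality additionally requires the same outer inputs. A dictionary keyed on nodes then acts as a very small merge pass.

```python
from dataclasses import dataclass


class Var:
    """Hypothetical graph variable: compared by identity, the name is a label."""

    def __init__(self, name):
        self.name = name


@dataclass(frozen=True)
class Slot:
    """Positional stand-in for the i-th inner (dummy) input."""

    index: int


def canonical(expr, positions):
    """Expressions are Vars, numeric constants, or tuples (op_name, *args)."""
    if isinstance(expr, Var):
        return Slot(positions[expr]) if expr in positions else expr
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(canonical(a, positions) for a in expr[1:])
    return expr


class InnerGraphOp:
    """Toy OpFromGraph stand-in: equality ignores the dummy inner inputs."""

    def __init__(self, inputs, outputs):
        positions = {var: i for i, var in enumerate(inputs)}
        self.key = (len(inputs), tuple(canonical(o, positions) for o in outputs))

    def __eq__(self, other):
        return type(self) is type(other) and self.key == other.key

    def __hash__(self):
        return hash((type(self), self.key))

    def __call__(self, *outer_inputs):
        return Node(self, outer_inputs)


@dataclass(frozen=True)
class Node:
    """An op applied to outer inputs; equal iff op and outer inputs are equal."""

    op: InnerGraphOp
    inputs: tuple


def merge(nodes):
    """Toy merge pass: replace each node by the first equal node seen."""
    seen = {}
    return [seen.setdefault(node, node) for node in nodes]


x, y = Var("x"), Var("y")
out1 = InnerGraphOp([x], [("add", x, 1)])(x)
out2 = InnerGraphOp([y], [("add", y, 1)])(x)  # same op, same outer input
out3 = InnerGraphOp([y], [("add", y, 1)])(y)  # same op, different outer input
assert out1.op == out3.op  # the Ops are equivalent
assert out1 == out2  # so are these nodes
assert out1 != out3  # different outer inputs, different nodes
```

With these definitions, `merge([out1, out2, out3])` folds `out2` into `out1` and leaves `out3` separate, returning `[out1, out1, out3]`.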

@zaxtax
Contributor

zaxtax commented Jan 3, 2025 via email

Projects
None yet
Development

No branches or pull requests

2 participants