Skip to content

Make pytensorf.constant_fold unconditional #7568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.25.1,<2.26
- pytensor>=2.26.2,<2.27
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
6 changes: 3 additions & 3 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:

base_var = node.inputs[0]

measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
base_var
)
measurable_dimshuffle = MeasurableDimShuffle(
input_ndim=node.op.input_ndim, new_order=node.op.new_order
)(base_var)
assert isinstance(measurable_dimshuffle, TensorVariable)

return [measurable_dimshuffle]
Expand Down
11 changes: 8 additions & 3 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand Down Expand Up @@ -1057,7 +1058,7 @@ def compile_pymc(

def constant_fold(
xs: Sequence[TensorVariable], raise_not_constant: bool = True
) -> tuple[np.ndarray, ...]:
) -> tuple[np.ndarray | Variable, ...]:
"""Use constant folding to get constant values of a graph.
Parameters
Expand All @@ -1072,8 +1073,12 @@ def constant_fold(
"""
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)

# By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
folded_xs = rewrite_graph(fg).outputs
# The default rewrite_graph includes a constand_folding that is not always applied.
# We use an unconditional constant_folding as the last pass to ensure a thorough constant folding.
rewrite_graph(fg)
topo_unconditional_constant_folding.apply(fg)

folded_xs = fg.outputs

if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
raise NotConstantValueError
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.25.1,<2.26
pytensor>=2.26.2,<2.27
pytest-cov>=2.5
pytest>=3.0
rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cachetools>=4.2.1
cloudpickle
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.25.1,<2.26
pytensor>=2.26.1,<2.27
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
Expand Down
5 changes: 5 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,11 @@ def test_inputs_preserved(self):
(out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
assert out_shape is a

def test_constant_fold_alloc(self):
# By default, Alloc outputs cannot be constant folded
x = pt.alloc(pt.arange(5), 2, 5)
np.testing.assert_allclose(constant_fold([x])[0], np.broadcast_to(np.arange(5), (2, 5)))


def test_replace_vars_in_graphs():
inp = shared(0.0, name="inp")
Expand Down
Loading