
Commit eff1cf2

ricardoV94 authored and lucianopaz committed
Add constant_fold helper
1 parent ce7b81a commit eff1cf2

6 files changed (+68, -21 lines)


docs/source/api/aesaraf.rst

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ Aesara utils
     floatX
     intX
     smartfloatX
+    constant_fold
     CallableTensor
     join_nonshared_inputs
     make_shared_replacements

pymc/aesaraf.py

Lines changed: 32 additions & 1 deletion
@@ -34,7 +34,7 @@
 from aesara import scalar
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
-from aesara.graph import node_rewriter
+from aesara.graph import node_rewriter, rewrite_graph
 from aesara.graph.basic import (
     Apply,
     Constant,
@@ -55,10 +55,13 @@
     RandomGeneratorSharedVariable,
     RandomStateSharedVariable,
 )
+from aesara.tensor.rewriting.basic import topo_constant_folding
+from aesara.tensor.rewriting.shape import ShapeFeature
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
+from pymc.exceptions import NotConstantValueError
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
 PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -82,6 +85,7 @@
     "at_rng",
     "convert_observed_data",
     "compile_pymc",
+    "constant_fold",
 ]
 
 
@@ -971,3 +975,30 @@ def compile_pymc(
         **kwargs,
     )
     return aesara_function
+
+
+def constant_fold(
+    xs: Sequence[TensorVariable], raise_not_constant: bool = True
+) -> Tuple[np.ndarray, ...]:
+    """Use constant folding to get constant values of a graph.
+
+    Parameters
+    ----------
+    xs: Sequence of TensorVariable
+        The variables that are to be constant folded
+    raise_not_constant: bool, default True
+        Raises NotConstantValueError if any of the variables cannot be constant folded.
+        This should only be disabled with care, as the graphs are cloned before
+        attempting constant folding, and any old non-shared inputs will not work with
+        the returned outputs
+    """
+    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
+
+    folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
+
+    if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
+        raise NotConstantValueError
+
+    return tuple(
+        folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
+    )
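
A quick usage sketch of the new helper, mirroring the test added further below (variable names are illustrative): graphs whose inputs are all compile-time constants fold down to plain NumPy values.

    import numpy as np
    import aesara.tensor as at
    from pymc.aesaraf import constant_fold

    # The draw size is a compile-time constant, so everything derived from
    # x's shape can be folded to concrete values.
    x = at.random.normal(size=(5,))
    y = at.arange(x.size)

    folded_y, folded_shape = constant_fold((y, y.shape))
    assert np.array_equal(folded_y, np.arange(5))
    assert tuple(folded_shape) == (5,)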

pymc/distributions/logprob.py

Lines changed: 3 additions & 9 deletions
@@ -27,10 +27,8 @@
 from aeppl.tensor import MeasurableJoin
 from aeppl.transforms import TransformValuesRewrite
 from aesara import tensor as at
-from aesara.graph import FunctionGraph, rewrite_graph
 from aesara.graph.basic import graph_inputs, io_toposort
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -41,7 +39,7 @@
 )
 from aesara.tensor.var import TensorVariable
 
-from pymc.aesaraf import floatX
+from pymc.aesaraf import constant_fold, floatX
 
 
 def _get_scaling(
@@ -338,12 +336,8 @@ def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
 
     base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
 
-    shape_fg = FunctionGraph(
-        outputs=base_var_shapes,
-        features=[ShapeFeature()],
-        clone=True,
-    )
-    base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs
+    # We don't need the graph to be constant, just to have RandomVariables removed
+    base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)
 
     split_values = at.split(
         value,
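
The new comment is the key point of this hunk: with raise_not_constant=False, outputs that cannot be reduced to constants are returned as simplified symbolic variables rather than raising, so they can still be fed to at.split. A minimal sketch of that behavior, assuming a shared (hence non-constant) size:

    import aesara
    import aesara.tensor as at
    from pymc.aesaraf import constant_fold

    size = aesara.shared(5)          # shared variable, not a compile-time constant
    x = at.random.normal(size=(size,))
    y = at.arange(x.size)

    # No NotConstantValueError here; the shape comes back as a symbolic
    # variable that can still be evaluated or used downstream.
    (folded_shape,) = constant_fold((y.shape,), raise_not_constant=False)
    print(tuple(folded_shape.eval()))  # (5,)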

pymc/distributions/timeseries.py

Lines changed: 5 additions & 11 deletions
@@ -21,14 +21,12 @@
 
 from aeppl.abstract import _get_measurable_outputs
 from aeppl.logprob import _logprob
-from aesara.graph import FunctionGraph, rewrite_graph
 from aesara.graph.basic import Node, clone_replace
 from aesara.raise_op import Assert
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
 
-from pymc.aesaraf import convert_observed_data, floatX, intX
+from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX
 from pymc.distributions import distribution, multivariate
 from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
 from pymc.distributions.distribution import (
@@ -46,6 +44,7 @@
     convert_dims,
     to_tuple,
 )
+from pymc.exceptions import NotConstantValueError
 from pymc.model import modelcontext
 from pymc.util import check_dist_not_registered
 
@@ -472,14 +471,9 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
         If inferred ar_order cannot be inferred from rhos or if it is less than 1
         """
         if ar_order is None:
-            shape_fg = FunctionGraph(
-                outputs=[rhos.shape[-1]],
-                features=[ShapeFeature()],
-                clone=True,
-            )
-            (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
-            folded_shape = getattr(folded_shape, "data", None)
-            if folded_shape is None:
+            try:
+                (folded_shape,) = constant_fold((rhos.shape[-1],))
+            except NotConstantValueError:
                 raise ValueError(
                     "Could not infer ar_order from last dimension of rho. Pass it "
                     "explictily or make sure rho have a static shape"

pymc/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -78,3 +78,7 @@ def __init__(self, message, actual=None, expected=None):
 
 class TruncationError(RuntimeError):
     """Exception for errors generated from truncated graphs"""
+
+
+class NotConstantValueError(ValueError):
+    pass

pymc/tests/test_aesaraf.py

Lines changed: 23 additions & 0 deletions
@@ -35,6 +35,7 @@
 
 from pymc.aesaraf import (
     compile_pymc,
+    constant_fold,
     convert_observed_data,
     extract_obs_data,
     replace_rng_nodes,
@@ -45,6 +46,7 @@
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicRandomVariable
 from pymc.distributions.transforms import Interval
+from pymc.exceptions import NotConstantValueError
 from pymc.vartypes import int_types
 
 
@@ -610,3 +612,24 @@ def test_reseed_rngs():
            assert rng.get_value()._bit_generator.state == bit_generator.state
        else:
            assert rng.get_value().bit_generator.state == bit_generator.state
+
+
+def test_constant_fold():
+    x = at.random.normal(size=(5,))
+    y = at.arange(x.size)
+
+    res = constant_fold((y, y.shape))
+    assert np.array_equal(res[0], np.arange(5))
+    assert tuple(res[1]) == (5,)
+
+
+def test_constant_fold_raises():
+    size = aesara.shared(5)
+    x = at.random.normal(size=(size,))
+    y = at.arange(x.size)
+
+    with pytest.raises(NotConstantValueError):
+        constant_fold((y, y.shape))
+
+    res = constant_fold((y, y.shape), raise_not_constant=False)
+    assert tuple(res[1].eval()) == (5,)
