Skip to content

Commit cb02b26

Browse files
committed
Merge functionality of pytensorf and logprob/utils
Also fixes circular imports
1 parent 55cc5aa commit cb02b26

File tree

15 files changed

+385
-686
lines changed

15 files changed

+385
-686
lines changed

pymc/gp/cov.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"Kron",
4949
]
5050

51+
from pymc.pytensorf import constant_fold
52+
5153
TensorLike = Union[np.ndarray, TensorVariable]
5254
IntSequence = Union[np.ndarray, Sequence[int]]
5355

@@ -183,9 +185,6 @@ def n_dims(self) -> int:
183185
def _slice(self, X, Xs=None):
184186
xdims = X.shape[-1]
185187
if isinstance(xdims, Variable):
186-
# Circular dependency
187-
from pymc.pytensorf import constant_fold
188-
189188
[xdims] = constant_fold([xdims])
190189
if self.input_dim != xdims:
191190
warnings.warn(

pymc/gp/util.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
import pytensor.tensor as pt
1919

2020
from pytensor.compile import SharedVariable
21+
from pytensor.graph import ancestors
2122
from pytensor.tensor.variable import TensorConstant
2223
from scipy.cluster.vq import kmeans
2324

2425
# Avoid circular dependency when importing modelcontext
2526
from pymc.distributions.distribution import Distribution
2627
from pymc.model import modelcontext
27-
from pymc.pytensorf import compile_pymc, walk_model
28+
from pymc.pytensorf import compile_pymc
2829

2930
_ = Distribution # keep both pylint and black happy
3031

@@ -48,7 +49,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
4849
model = modelcontext(model)
4950

5051
inputs, input_names = [], []
51-
for rv in walk_model(vars_needed):
52+
for rv in ancestors(vars_needed):
5253
if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
5354
inputs.append(rv)
5455
input_names.append(rv.name)

pymc/logprob/basic.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@
6565
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
6666
from pymc.logprob.transform_value import TransformValuesRewrite
6767
from pymc.logprob.transforms import Transform
68-
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
68+
from pymc.logprob.utils import rvs_in_graph
69+
from pymc.pytensorf import replace_vars_in_graphs
6970

7071
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
7172

@@ -76,7 +77,7 @@ def _find_unallowed_rvs_in_graph(graph):
7677

7778
return {
7879
rv
79-
for rv in find_rvs_in_graph(graph)
80+
for rv in rvs_in_graph(graph)
8081
if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV))
8182
}
8283

@@ -530,11 +531,9 @@ def conditional_logp(
530531
continue
531532

532533
# Replace `RandomVariable`s in the inputs with value variables.
533-
# Also, store the results in the `replacements` map for the nodes
534-
# that follow.
535-
remapped_vars, _ = rvs_to_value_vars(
536-
q_values + list(node.inputs),
537-
initial_replacements=replacements,
534+
remapped_vars = replace_vars_in_graphs(
535+
graphs=q_values + list(node.inputs),
536+
replacements=replacements,
538537
)
539538
q_values = remapped_vars[: len(q_values)]
540539
q_rv_inputs = remapped_vars[len(q_values) :]
@@ -562,8 +561,7 @@ def conditional_logp(
562561

563562
logprob_vars[q_value_var] = q_logprob_var
564563

565-
# Recompute test values for the changes introduced by the
566-
# replacements above.
564+
# Recompute test values for the changes introduced by the replacements above.
567565
if config.compute_test_value != "off":
568566
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
569567
compute_test_value(node)

pymc/logprob/checks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
4646
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
47+
from pymc.logprob.utils import replace_rvs_by_values
4748

4849

4950
class MeasurableSpecifyShape(SpecifyShape):
@@ -107,8 +108,6 @@ class MeasurableCheckAndRaise(CheckAndRaise):
107108

108109
@_logprob.register(MeasurableCheckAndRaise)
109110
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
110-
from pymc.pytensorf import replace_rvs_by_values
111-
112111
(value,) = values
113112
# transfer assertion from rv to value
114113
assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value})

pymc/logprob/mixture.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@
7878
measurable_ir_rewrites_db,
7979
subtensor_ops,
8080
)
81-
from pymc.logprob.utils import check_potential_measurability
81+
from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
82+
from pymc.pytensorf import constant_fold
8283

8384

8485
def is_newaxis(x):
@@ -255,9 +256,6 @@ def get_stack_mixture_vars(
255256
mixture_rvs = joined_rvs.owner.inputs
256257

257258
elif isinstance(joined_rvs.owner.op, Join):
258-
# TODO: Find better solution to avoid this circular dependency
259-
from pymc.pytensorf import constant_fold
260-
261259
join_axis = joined_rvs.owner.inputs[0]
262260
# TODO: Support symbolic join axes. This will raise ValueError if it's not a constant
263261
(join_axis,) = constant_fold((join_axis,), raise_not_constant=False)
@@ -351,9 +349,6 @@ def logprob_MixtureRV(
351349
comp_rvs = [comp[None] for comp in comp_rvs]
352350
original_shape = (len(comp_rvs),)
353351
else:
354-
# TODO: Find better solution to avoid this circular dependency
355-
from pymc.pytensorf import constant_fold
356-
357352
join_axis_val = constant_fold((join_axis,))[0].item()
358353
original_shape = shape_tuple(comp_rvs[0])
359354

@@ -544,7 +539,6 @@ def find_measurable_ifelse_mixture(fgraph, node):
544539
@_logprob.register(MeasurableIfElse)
545540
def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
546541
"""Compute the log-likelihood graph for an `IfElse`."""
547-
from pymc.pytensorf import replace_rvs_by_values
548542

549543
assert len(values) * 2 == len(base_rvs)
550544

pymc/logprob/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
logprob_rewrites_db,
6363
measurable_ir_rewrites_db,
6464
)
65-
from pymc.pytensorf import replace_rvs_by_values
65+
from pymc.logprob.utils import replace_rvs_by_values
6666

6767

6868
class MeasurableScan(Scan):

pymc/logprob/tensor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@
5555
assume_measured_ir_outputs,
5656
measurable_ir_rewrites_db,
5757
)
58-
from pymc.logprob.utils import check_potential_measurability
58+
from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
59+
from pymc.pytensorf import constant_fold
5960

6061

6162
@node_rewriter([Alloc])
@@ -131,7 +132,6 @@ class MeasurableMakeVector(MakeVector):
131132
def logprob_make_vector(op, values, *base_rvs, **kwargs):
132133
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
133134
# TODO: Sort out this circular dependency issue
134-
from pymc.pytensorf import replace_rvs_by_values
135135

136136
(value,) = values
137137

@@ -158,9 +158,6 @@ class MeasurableJoin(Join):
158158
@_logprob.register(MeasurableJoin)
159159
def logprob_join(op, values, axis, *base_rvs, **kwargs):
160160
"""Compute the log-likelihood graph for a `Join`."""
161-
# TODO: Find better way to avoid circular dependency
162-
from pymc.pytensorf import constant_fold, replace_rvs_by_values
163-
164161
(value,) = values
165162

166163
base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]

0 commit comments

Comments
 (0)