Skip to content

Commit 29500af

Browse files
committed
Make rvs_to_values work with non-RandomVariables
No longer returns replacements dict from `rvs_to_values` as the keys will be the cloned `rvs` which are not useful to the caller.
1 parent d19f575 commit 29500af

File tree

8 files changed

+132
-114
lines changed

8 files changed

+132
-114
lines changed

pymc/aesaraf.py

Lines changed: 36 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
15-
1614
from typing import (
1715
Callable,
1816
Dict,
@@ -147,32 +145,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVar
147145
return at.as_tensor_variable(df.to_numpy(), *args, **kwargs)
148146

149147

150-
def extract_rv_and_value_vars(
151-
var: TensorVariable,
152-
) -> Tuple[TensorVariable, TensorVariable]:
153-
"""Return a random variable and it's observations or value variable, or ``None``.
154-
155-
Parameters
156-
==========
157-
var
158-
A variable corresponding to a ``RandomVariable``.
159-
160-
Returns
161-
=======
162-
The first value in the tuple is the ``RandomVariable``, and the second is the
163-
measure/log-likelihood value variable that corresponds with the latter.
164-
165-
"""
166-
if not var.owner:
167-
return None, None
168-
169-
if isinstance(var.owner.op, RandomVariable):
170-
rv_value = getattr(var.tag, "observations", getattr(var.tag, "value_var", None))
171-
return var, rv_value
172-
173-
return None, None
174-
175-
176148
def extract_obs_data(x: TensorVariable) -> np.ndarray:
177149
"""Extract data from observed symbolic variables.
178150
@@ -200,20 +172,15 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
200172

201173
def walk_model(
202174
graphs: Iterable[TensorVariable],
203-
walk_past_rvs: bool = False,
204175
stop_at_vars: Optional[Set[TensorVariable]] = None,
205176
expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [],
206177
) -> Generator[TensorVariable, None, None]:
207178
"""Walk model graphs and yield their nodes.
208179
209-
By default, these walks will not go past ``RandomVariable`` nodes.
210-
211180
Parameters
212181
==========
213182
graphs
214183
The graphs to walk.
215-
walk_past_rvs
216-
If ``True``, the walk will not terminate at ``RandomVariable``s.
217184
stop_at_vars
218185
A list of variables at which the walk will terminate.
219186
expand_fn
@@ -225,16 +192,12 @@ def walk_model(
225192
def expand(var):
226193
new_vars = expand_fn(var)
227194

228-
if (
229-
var.owner
230-
and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
231-
and (var not in stop_at_vars)
232-
):
195+
if var.owner and var not in stop_at_vars:
233196
new_vars.extend(reversed(var.owner.inputs))
234197

235198
return new_vars
236199

237-
yield from walk(graphs, expand, False)
200+
yield from walk(graphs, expand, bfs=False)
238201

239202

240203
def replace_rvs_in_graphs(
@@ -263,7 +226,11 @@ def replace_rvs_in_graphs(
263226

264227
def expand_replace(var):
265228
new_nodes = []
266-
if var.owner and isinstance(var.owner.op, RandomVariable):
229+
if var.owner:
230+
# Call replacement_fn to update replacements dict inplace and, optionally,
231+
# specify new nodes that should also be walked for replacements. This
232+
# includes `value` variables that are not simple input variables, and may
233+
# contain other `random` variables in their graphs (e.g., IntervalTransform)
267234
new_nodes.extend(replacement_fn(var, replacements))
268235
return new_nodes
269236

@@ -290,10 +257,10 @@ def expand_replace(var):
290257

291258
def rvs_to_value_vars(
292259
graphs: Iterable[TensorVariable],
293-
apply_transforms: bool = False,
260+
apply_transforms: bool = True,
294261
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
295262
**kwargs,
296-
) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
263+
) -> TensorVariable:
297264
"""Clone and replace random variables in graphs with their value variables.
298265
299266
This will *not* recompute test values in the resulting graphs.
@@ -309,38 +276,30 @@ def rvs_to_value_vars(
309276
310277
"""
311278

312-
# Avoid circular dependency
313-
from pymc.distributions import NoDistribution
314-
315-
def transform_replacements(var, replacements):
316-
rv_var, rv_value_var = extract_rv_and_value_vars(var)
317-
318-
if rv_value_var is None:
319-
# If RandomVariable does not have a value_var and corresponds to
320-
# a NoDistribution, we allow further replacements in upstream graph
321-
if isinstance(rv_var.owner.op, NoDistribution):
322-
return rv_var.owner.inputs
279+
def populate_replacements(
280+
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
281+
) -> List[TensorVariable]:
282+
# Populate replacements dict with {rv: value} pairs indicating which graph
283+
# RVs should be replaced by what value variables.
323284

324-
else:
325-
warnings.warn(
326-
f"No value variable found for {rv_var}; "
327-
"the random variable will not be replaced."
328-
)
329-
return []
285+
value_var = getattr(
286+
random_var.tag, "observations", getattr(random_var.tag, "value_var", None)
287+
)
330288

331-
transform = getattr(rv_value_var.tag, "transform", None)
289+
# No value variable to replace RV with
290+
if value_var is None:
291+
return []
332292

333-
if transform is None or not apply_transforms:
334-
replacements[var] = rv_value_var
335-
# In case the value variable is itself a graph, we walk it for
336-
# potential replacements
337-
return [rv_value_var]
293+
transform = getattr(value_var.tag, "transform", None)
294+
if transform is not None and apply_transforms:
295+
# We want to replace uses of the RV by the back-transformation of its value
296+
value_var = transform.backward(value_var, *random_var.owner.inputs)
338297

339-
trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
340-
replacements[var] = trans_rv_value
298+
replacements[random_var] = value_var
341299

342-
# Walk the transformed variable and make replacements
343-
return [trans_rv_value]
300+
# Also walk the graph of the value variable to make any additional replacements
301+
# if that is not a simple input variable
302+
return [value_var]
344303

345304
# Clone original graphs
346305
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
@@ -352,7 +311,14 @@ def transform_replacements(var, replacements):
352311
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
353312
}
354313

355-
return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
314+
graphs, _ = replace_rvs_in_graphs(
315+
graphs,
316+
replacement_fn=populate_replacements,
317+
initial_replacements=initial_replacements,
318+
**kwargs,
319+
)
320+
321+
return graphs
356322

357323

358324
def inputvars(a):

pymc/gp/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
5757
model = modelcontext(model)
5858

5959
inputs, input_names = [], []
60-
for rv in walk_model(vars_needed, walk_past_rvs=True):
60+
for rv in walk_model(vars_needed):
6161
if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
6262
inputs.append(rv)
6363
input_names.append(rv.name)

pymc/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def logp(
761761
# Replace random variables by their value variables in potential terms
762762
potential_logps = []
763763
if potentials:
764-
potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True)
764+
potential_logps = rvs_to_value_vars(potentials)
765765

766766
logp_factors = [None] * len(varlist)
767767
for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)):
@@ -935,7 +935,7 @@ def potentiallogp(self) -> Variable:
935935
"""Aesara scalar of log-probability of the Potential terms"""
936936
# Convert random variables in Potential expression into their log-likelihood
937937
# inputs and apply their transforms, if any
938-
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
938+
potentials = rvs_to_value_vars(self.potentials)
939939
if potentials:
940940
return at.sum([at.sum(factor) for factor in potentials])
941941
else:
@@ -976,10 +976,10 @@ def unobserved_value_vars(self):
976976
vars.append(value_var)
977977

978978
# Remove rvs from untransformed values graph
979-
untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True)
979+
untransformed_vars = rvs_to_value_vars(untransformed_vars)
980980

981981
# Remove rvs from deterministics graph
982-
deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True)
982+
deterministics = rvs_to_value_vars(self.deterministics)
983983

984984
return vars + untransformed_vars + deterministics
985985

pymc/step_methods/metropolis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
583583

584584
if isinstance(distr, CategoricalRV):
585585
k_graph = rv_var.owner.inputs[3].shape[-1]
586-
(k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
586+
(k_graph,) = rvs_to_value_vars((k_graph,))
587587
k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")(
588588
initial_point
589589
)

pymc/tests/distributions/test_logprob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_joint_logp_basic():
129129
with pytest.warns(FutureWarning):
130130
b_logpt = joint_logpt(b, b_value_var, sum=False)
131131

132-
res_ancestors = list(walk_model(b_logp, walk_past_rvs=True))
132+
res_ancestors = list(walk_model(b_logp))
133133
res_rv_ancestors = [
134134
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
135135
]

0 commit comments

Comments
 (0)