Skip to content

Commit 0044bf1

Browse files
committed
Remove deprecated function rvs_to_value_vars
1 parent c5115ee commit 0044bf1

File tree

2 files changed

+24
-105
lines changed

2 files changed

+24
-105
lines changed

pymc/pytensorf.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -271,66 +271,6 @@ def expand_replace(var):
271271
return graphs, replacements
272272

273273

274-
def rvs_to_value_vars(
275-
graphs: Iterable[Variable],
276-
apply_transforms: bool = True,
277-
**kwargs,
278-
) -> List[Variable]:
279-
"""Clone and replace random variables in graphs with their value variables.
280-
281-
This will *not* recompute test values in the resulting graphs.
282-
283-
Parameters
284-
----------
285-
graphs
286-
The graphs in which to perform the replacements.
287-
apply_transforms
288-
If ``True``, apply each value variable's transform.
289-
"""
290-
warnings.warn(
291-
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
292-
FutureWarning,
293-
)
294-
295-
def populate_replacements(
296-
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
297-
) -> List[TensorVariable]:
298-
# Populate replacements dict with {rv: value} pairs indicating which graph
299-
# RVs should be replaced by what value variables.
300-
301-
value_var = getattr(
302-
random_var.tag, "observations", getattr(random_var.tag, "value_var", None)
303-
)
304-
305-
# No value variable to replace RV with
306-
if value_var is None:
307-
return []
308-
309-
transform = getattr(value_var.tag, "transform", None)
310-
if transform is not None and apply_transforms:
311-
# We want to replace uses of the RV by the back-transformation of its value
312-
value_var = transform.backward(value_var, *random_var.owner.inputs)
313-
314-
replacements[random_var] = value_var
315-
316-
# Also walk the graph of the value variable to make any additional replacements
317-
# if that is not a simple input variable
318-
return [value_var]
319-
320-
# Clone original graphs
321-
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
322-
equiv = clone_get_equiv(inputs, graphs, False, False, {})
323-
graphs = [equiv[n] for n in graphs]
324-
325-
graphs, _ = _replace_vars_in_graphs(
326-
graphs,
327-
replacement_fn=populate_replacements,
328-
**kwargs,
329-
)
330-
331-
return graphs
332-
333-
334274
def replace_rvs_by_values(
335275
graphs: Sequence[TensorVariable],
336276
*,

tests/test_pytensorf.py

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
replace_rng_nodes,
4949
replace_rvs_by_values,
5050
reseed_rngs,
51-
rvs_to_value_vars,
5251
walk_model,
5352
)
5453
from pymc.testing import assert_no_rvs
@@ -671,8 +670,7 @@ def test_constant_fold_raises():
671670
class TestReplaceRVsByValues:
672671
@pytest.mark.parametrize("symbolic_rv", (False, True))
673672
@pytest.mark.parametrize("apply_transforms", (True, False))
674-
@pytest.mark.parametrize("test_deprecated_fn", (True, False))
675-
def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
673+
def test_basic(self, symbolic_rv, apply_transforms):
676674
# Interval transform between last two arguments
677675
interval = (
678676
Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None
@@ -696,15 +694,11 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
696694
b_value_var = m.rvs_to_values[b]
697695
c_value_var = m.rvs_to_values[c]
698696

699-
if test_deprecated_fn:
700-
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
701-
(res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms)
702-
else:
703-
(res,) = replace_rvs_by_values(
704-
(d,),
705-
rvs_to_values=m.rvs_to_values,
706-
rvs_to_transforms=m.rvs_to_transforms,
707-
)
697+
(res,) = replace_rvs_by_values(
698+
(d,),
699+
rvs_to_values=m.rvs_to_values,
700+
rvs_to_transforms=m.rvs_to_transforms,
701+
)
708702

709703
assert res.owner.op == pt.add
710704
log_output = res.owner.inputs[0]
@@ -740,8 +734,7 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
740734
else:
741735
assert a_value_var not in res_ancestors
742736

743-
@pytest.mark.parametrize("test_deprecated_fn", (True, False))
744-
def test_unvalued_rv(self, test_deprecated_fn):
737+
def test_unvalued_rv(self):
745738
with pm.Model() as m:
746739
x = pm.Normal("x")
747740
y = pm.Normal.dist(x)
@@ -751,15 +744,11 @@ def test_unvalued_rv(self, test_deprecated_fn):
751744
x_value = m.rvs_to_values[x]
752745
z_value = m.rvs_to_values[z]
753746

754-
if test_deprecated_fn:
755-
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
756-
(res,) = rvs_to_value_vars((out,))
757-
else:
758-
(res,) = replace_rvs_by_values(
759-
(out,),
760-
rvs_to_values=m.rvs_to_values,
761-
rvs_to_transforms=m.rvs_to_transforms,
762-
)
747+
(res,) = replace_rvs_by_values(
748+
(out,),
749+
rvs_to_values=m.rvs_to_values,
750+
rvs_to_transforms=m.rvs_to_transforms,
751+
)
763752

764753
assert res.owner.op == pt.add
765754
assert res.owner.inputs[0] is z_value
@@ -769,8 +758,7 @@ def test_unvalued_rv(self, test_deprecated_fn):
769758
assert res_y.owner.op == pt.random.normal
770759
assert res_y.owner.inputs[3] is x_value
771760

772-
@pytest.mark.parametrize("test_deprecated_fn", (True, False))
773-
def test_no_change_inplace(self, test_deprecated_fn):
761+
def test_no_change_inplace(self):
774762
# Test that calling rvs_to_value_vars in models with nested transformations
775763
# does not change the original rvs in place. See issue #5172
776764
with pm.Model() as m:
@@ -784,22 +772,17 @@ def test_no_change_inplace(self, test_deprecated_fn):
784772
before = pytensor.clone_replace(m.free_RVs)
785773

786774
# This call would change the model free_RVs in place in #5172
787-
if test_deprecated_fn:
788-
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
789-
rvs_to_value_vars(m.potentials)
790-
else:
791-
replace_rvs_by_values(
792-
m.potentials,
793-
rvs_to_values=m.rvs_to_values,
794-
rvs_to_transforms=m.rvs_to_transforms,
795-
)
775+
replace_rvs_by_values(
776+
m.potentials,
777+
rvs_to_values=m.rvs_to_values,
778+
rvs_to_transforms=m.rvs_to_transforms,
779+
)
796780

797781
after = pytensor.clone_replace(m.free_RVs)
798782
assert equal_computations(before, after)
799783

800-
@pytest.mark.parametrize("test_deprecated_fn", (True, False))
801784
@pytest.mark.parametrize("reversed", (False, True))
802-
def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
785+
def test_interdependent_transformed_rvs(self, reversed):
803786
# Test that nested transformed variables, whose transformed values depend on other
804787
# RVs are properly replaced
805788
with pm.Model() as m:
@@ -815,15 +798,11 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
815798
if reversed:
816799
rvs = rvs[::-1]
817800

818-
if test_deprecated_fn:
819-
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
820-
transform_values = rvs_to_value_vars(rvs)
821-
else:
822-
transform_values = replace_rvs_by_values(
823-
rvs,
824-
rvs_to_values=m.rvs_to_values,
825-
rvs_to_transforms=m.rvs_to_transforms,
826-
)
801+
transform_values = replace_rvs_by_values(
802+
rvs,
803+
rvs_to_values=m.rvs_to_values,
804+
rvs_to_transforms=m.rvs_to_transforms,
805+
)
827806

828807
for transform_value in transform_values:
829808
assert_no_rvs(transform_value)

0 commit comments

Comments
 (0)