48
48
replace_rng_nodes ,
49
49
replace_rvs_by_values ,
50
50
reseed_rngs ,
51
- rvs_to_value_vars ,
52
51
walk_model ,
53
52
)
54
53
from pymc .testing import assert_no_rvs
@@ -671,8 +670,7 @@ def test_constant_fold_raises():
671
670
class TestReplaceRVsByValues :
672
671
@pytest .mark .parametrize ("symbolic_rv" , (False , True ))
673
672
@pytest .mark .parametrize ("apply_transforms" , (True , False ))
674
- @pytest .mark .parametrize ("test_deprecated_fn" , (True , False ))
675
- def test_basic (self , symbolic_rv , apply_transforms , test_deprecated_fn ):
673
+ def test_basic (self , symbolic_rv , apply_transforms ):
676
674
# Interval transform between last two arguments
677
675
interval = (
678
676
Interval (bounds_fn = lambda * args : (args [- 2 ], args [- 1 ])) if apply_transforms else None
@@ -696,15 +694,11 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
696
694
b_value_var = m .rvs_to_values [b ]
697
695
c_value_var = m .rvs_to_values [c ]
698
696
699
- if test_deprecated_fn :
700
- with pytest .warns (FutureWarning , match = "Use model.replace_rvs_by_values instead" ):
701
- (res ,) = rvs_to_value_vars ((d ,), apply_transforms = apply_transforms )
702
- else :
703
- (res ,) = replace_rvs_by_values (
704
- (d ,),
705
- rvs_to_values = m .rvs_to_values ,
706
- rvs_to_transforms = m .rvs_to_transforms ,
707
- )
697
+ (res ,) = replace_rvs_by_values (
698
+ (d ,),
699
+ rvs_to_values = m .rvs_to_values ,
700
+ rvs_to_transforms = m .rvs_to_transforms ,
701
+ )
708
702
709
703
assert res .owner .op == pt .add
710
704
log_output = res .owner .inputs [0 ]
@@ -740,8 +734,7 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
740
734
else :
741
735
assert a_value_var not in res_ancestors
742
736
743
- @pytest .mark .parametrize ("test_deprecated_fn" , (True , False ))
744
- def test_unvalued_rv (self , test_deprecated_fn ):
737
+ def test_unvalued_rv (self ):
745
738
with pm .Model () as m :
746
739
x = pm .Normal ("x" )
747
740
y = pm .Normal .dist (x )
@@ -751,15 +744,11 @@ def test_unvalued_rv(self, test_deprecated_fn):
751
744
x_value = m .rvs_to_values [x ]
752
745
z_value = m .rvs_to_values [z ]
753
746
754
- if test_deprecated_fn :
755
- with pytest .warns (FutureWarning , match = "Use model.replace_rvs_by_values instead" ):
756
- (res ,) = rvs_to_value_vars ((out ,))
757
- else :
758
- (res ,) = replace_rvs_by_values (
759
- (out ,),
760
- rvs_to_values = m .rvs_to_values ,
761
- rvs_to_transforms = m .rvs_to_transforms ,
762
- )
747
+ (res ,) = replace_rvs_by_values (
748
+ (out ,),
749
+ rvs_to_values = m .rvs_to_values ,
750
+ rvs_to_transforms = m .rvs_to_transforms ,
751
+ )
763
752
764
753
assert res .owner .op == pt .add
765
754
assert res .owner .inputs [0 ] is z_value
@@ -769,8 +758,7 @@ def test_unvalued_rv(self, test_deprecated_fn):
769
758
assert res_y .owner .op == pt .random .normal
770
759
assert res_y .owner .inputs [3 ] is x_value
771
760
772
- @pytest .mark .parametrize ("test_deprecated_fn" , (True , False ))
773
- def test_no_change_inplace (self , test_deprecated_fn ):
761
+ def test_no_change_inplace (self ):
774
762
# Test that calling rvs_to_value_vars in models with nested transformations
775
763
# does not change the original rvs in place. See issue #5172
776
764
with pm .Model () as m :
@@ -784,22 +772,17 @@ def test_no_change_inplace(self, test_deprecated_fn):
784
772
before = pytensor .clone_replace (m .free_RVs )
785
773
786
774
# This call would change the model free_RVs in place in #5172
787
- if test_deprecated_fn :
788
- with pytest .warns (FutureWarning , match = "Use model.replace_rvs_by_values instead" ):
789
- rvs_to_value_vars (m .potentials )
790
- else :
791
- replace_rvs_by_values (
792
- m .potentials ,
793
- rvs_to_values = m .rvs_to_values ,
794
- rvs_to_transforms = m .rvs_to_transforms ,
795
- )
775
+ replace_rvs_by_values (
776
+ m .potentials ,
777
+ rvs_to_values = m .rvs_to_values ,
778
+ rvs_to_transforms = m .rvs_to_transforms ,
779
+ )
796
780
797
781
after = pytensor .clone_replace (m .free_RVs )
798
782
assert equal_computations (before , after )
799
783
800
- @pytest .mark .parametrize ("test_deprecated_fn" , (True , False ))
801
784
@pytest .mark .parametrize ("reversed" , (False , True ))
802
- def test_interdependent_transformed_rvs (self , reversed , test_deprecated_fn ):
785
+ def test_interdependent_transformed_rvs (self , reversed ):
803
786
# Test that nested transformed variables, whose transformed values depend on other
804
787
# RVs are properly replaced
805
788
with pm .Model () as m :
@@ -815,15 +798,11 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
815
798
if reversed :
816
799
rvs = rvs [::- 1 ]
817
800
818
- if test_deprecated_fn :
819
- with pytest .warns (FutureWarning , match = "Use model.replace_rvs_by_values instead" ):
820
- transform_values = rvs_to_value_vars (rvs )
821
- else :
822
- transform_values = replace_rvs_by_values (
823
- rvs ,
824
- rvs_to_values = m .rvs_to_values ,
825
- rvs_to_transforms = m .rvs_to_transforms ,
826
- )
801
+ transform_values = replace_rvs_by_values (
802
+ rvs ,
803
+ rvs_to_values = m .rvs_to_values ,
804
+ rvs_to_transforms = m .rvs_to_transforms ,
805
+ )
827
806
828
807
for transform_value in transform_values :
829
808
assert_no_rvs (transform_value )
0 commit comments