File tree Expand file tree Collapse file tree 3 files changed +16
-13
lines changed Expand file tree Collapse file tree 3 files changed +16
-13
lines changed Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from pymc .step_methods .compound import CompoundStep
15
+ from pymc .step_methods .compound import BlockedStep , CompoundStep
16
16
from pymc .step_methods .hmc import NUTS , HamiltonianMC
17
17
from pymc .step_methods .metropolis import (
18
18
BinaryGibbsMetropolis ,
30
30
)
31
31
from pymc .step_methods .slicer import Slice
32
32
33
- STEP_METHODS = (
33
+ # Other step methods can be added by appending to this list
34
+ STEP_METHODS : list [type [BlockedStep ]] = [
34
35
NUTS ,
35
36
HamiltonianMC ,
36
37
Metropolis ,
37
38
BinaryMetropolis ,
38
39
BinaryGibbsMetropolis ,
39
40
Slice ,
40
41
CategoricalGibbsMetropolis ,
41
- )
42
+ ]
Original file line number Diff line number Diff line change @@ -762,12 +762,18 @@ def kill_grad(x):
762
762
steps = assign_step_methods (model , [])
763
763
assert isinstance (steps , Slice )
764
764
765
- def test_modify_step_methods (self ):
765
+ @pytest .fixture
766
+ def step_methods (self ):
767
+ """Make sure we reset the STEP_METHODS after the test is done."""
768
+ methods_copy = pm .STEP_METHODS .copy ()
769
+ yield pm .STEP_METHODS
770
+ pm .STEP_METHODS .clear ()
771
+ for method in methods_copy :
772
+ pm .STEP_METHODS .append (method )
773
+
774
+ def test_modify_step_methods (self , step_methods ):
766
775
"""Test step methods can be changed"""
767
- # remove nuts from step_methods
768
- step_methods = list (pm .STEP_METHODS )
769
776
step_methods .remove (NUTS )
770
- pm .STEP_METHODS = step_methods
771
777
772
778
with pm .Model () as model :
773
779
pm .Normal ("x" , 0 , 1 )
@@ -776,7 +782,7 @@ def test_modify_step_methods(self):
776
782
assert not isinstance (steps , NUTS )
777
783
778
784
# add back nuts
779
- pm . STEP_METHODS = [ * step_methods , NUTS ]
785
+ step_methods . append ( NUTS )
780
786
781
787
with pm .Model () as model :
782
788
pm .Normal ("x" , 0 , 1 )
Original file line number Diff line number Diff line change 26
26
Slice ,
27
27
)
28
28
from pymc .step_methods .compound import (
29
- BlockedStep ,
30
29
StatsBijection ,
31
30
flatten_steps ,
32
31
get_stats_dtypes_shapes_from_steps ,
38
37
39
38
40
39
def test_all_stepmethods_emit_tune_stat ():
41
- attrs = [getattr (pm .step_methods , n ) for n in dir (pm .step_methods )]
42
- step_types = [
43
- attr for attr in attrs if isinstance (attr , type ) and issubclass (attr , BlockedStep )
44
- ]
40
+ step_types = pm .step_methods .STEP_METHODS
45
41
assert len (step_types ) > 5
46
42
for cls in step_types :
47
43
assert "tune" in cls .stats_dtypes_shapes
You can’t perform that action at this time.
0 commit comments