Skip to content

Commit 034b9a4

Browse files
committed
Make default STEP_METHODS a list that can be modified
1 parent 22e8f0b commit 034b9a4

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

pymc/step_methods/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from pymc.step_methods.compound import CompoundStep
15+
from pymc.step_methods.compound import BlockedStep, CompoundStep
1616
from pymc.step_methods.hmc import NUTS, HamiltonianMC
1717
from pymc.step_methods.metropolis import (
1818
BinaryGibbsMetropolis,
@@ -30,12 +30,13 @@
3030
)
3131
from pymc.step_methods.slicer import Slice
3232

33-
STEP_METHODS = (
33+
# Other step methods can be added by appending to this list
34+
STEP_METHODS: list[type[BlockedStep]] = [
3435
NUTS,
3536
HamiltonianMC,
3637
Metropolis,
3738
BinaryMetropolis,
3839
BinaryGibbsMetropolis,
3940
Slice,
4041
CategoricalGibbsMetropolis,
41-
)
42+
]

tests/sampling/test_mcmc.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -762,12 +762,18 @@ def kill_grad(x):
762762
steps = assign_step_methods(model, [])
763763
assert isinstance(steps, Slice)
764764

765-
def test_modify_step_methods(self):
765+
@pytest.fixture
766+
def step_methods(self):
767+
"""Make sure we reset the STEP_METHODS after the test is done."""
768+
methods_copy = pm.STEP_METHODS.copy()
769+
yield pm.STEP_METHODS
770+
pm.STEP_METHODS.clear()
771+
for method in methods_copy:
772+
pm.STEP_METHODS.append(method)
773+
774+
def test_modify_step_methods(self, step_methods):
766775
"""Test step methods can be changed"""
767-
# remove nuts from step_methods
768-
step_methods = list(pm.STEP_METHODS)
769776
step_methods.remove(NUTS)
770-
pm.STEP_METHODS = step_methods
771777

772778
with pm.Model() as model:
773779
pm.Normal("x", 0, 1)
@@ -776,7 +782,7 @@ def test_modify_step_methods(self):
776782
assert not isinstance(steps, NUTS)
777783

778784
# add back nuts
779-
pm.STEP_METHODS = [*step_methods, NUTS]
785+
step_methods.append(NUTS)
780786

781787
with pm.Model() as model:
782788
pm.Normal("x", 0, 1)

tests/step_methods/test_compound.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
Slice,
2727
)
2828
from pymc.step_methods.compound import (
29-
BlockedStep,
3029
StatsBijection,
3130
flatten_steps,
3231
get_stats_dtypes_shapes_from_steps,
@@ -38,10 +37,7 @@
3837

3938

4039
def test_all_stepmethods_emit_tune_stat():
41-
attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)]
42-
step_types = [
43-
attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep)
44-
]
40+
step_types = pm.step_methods.STEP_METHODS
4541
assert len(step_types) > 5
4642
for cls in step_types:
4743
assert "tune" in cls.stats_dtypes_shapes

0 commit comments

Comments
 (0)