Skip to content

Commit ea34371

Browse files
committed
Change class checks to using isinstance
1 parent 3fb364b commit ea34371

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

pymc3/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ def _iter_sample(
903903
try:
904904
step.tune = bool(tune)
905905
if hasattr(step, 'reset_tuning'):
906-
if step.__class__.__name__ == 'CompoundStep':
906+
if isinstance(step, CompoundStep):
907907
if hasattr(step.methods[0], 'is_mlda_base'):
908908
if not step.methods[0].is_mlda_base:
909909
step.reset_tuning()

pymc3/step_methods/metropolis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from ..distributions import draw_values
2323
from .arraystep import ArrayStepShared, PopulationArrayStepShared, ArrayStep, metrop_select, Competence
24+
from .compound import CompoundStep
2425
import pymc3 as pm
2526
from pymc3.theanof import floatX
2627

@@ -1145,10 +1146,10 @@ def astep(self, q0):
11451146

11461147
# Capture latest base chain scaling stats from next step method
11471148
self.base_scaling_stats = {}
1148-
if self.next_step_method.__class__.__name__ == "CompoundStep":
1149+
if isinstance(self.next_step_method, CompoundStep):
11491150
for method in self.next_step_method.methods:
11501151
self.base_scaling_stats["base_scaling_" + method.vars[0].name] = method.scaling
1151-
elif self.next_step_method.__class__.__name__ == "Metropolis":
1152+
elif isinstance(self.next_step_method, Metropolis):
11521153
self.base_scaling_stats["base_scaling"] = self.next_step_method.scaling
11531154
else:
11541155
# next method is MLDA

0 commit comments

Comments
 (0)