Skip to content

Commit a32c5e7

Browse files
jahallJoseph Hall
andauthored
More informative error message for unused step sampler arguments (#6738)
Co-authored-by: Joseph Hall <[email protected]>
1 parent 261862d commit a32c5e7

File tree

3 files changed

+43
-26
lines changed

3 files changed

+43
-26
lines changed

pymc/sampling/mcmc.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@
2020
import time
2121
import warnings
2222

23-
from collections import defaultdict
2423
from typing import (
2524
Any,
2625
Dict,
2726
Iterator,
2827
List,
2928
Literal,
29+
Mapping,
3030
Optional,
3131
Sequence,
32+
Set,
3233
Tuple,
34+
Type,
3335
Union,
3436
overload,
3537
)
@@ -39,6 +41,7 @@
3941

4042
from arviz import InferenceData
4143
from fastprogress.fastprogress import progress_bar
44+
from pytensor.graph.basic import Variable
4245
from typing_extensions import Protocol, TypeAlias
4346

4447
import pymc as pm
@@ -90,7 +93,10 @@ def __call__(self, trace: IBaseTrace, draw: Draw):
9093

9194

9295
def instantiate_steppers(
93-
model, steps: List[Step], selected_steps, step_kwargs=None
96+
model: Model,
97+
steps: List[Step],
98+
selected_steps: Mapping[Type[BlockedStep], List[Any]],
99+
step_kwargs: Optional[Dict[str, Dict]] = None,
94100
) -> Union[Step, List[Step]]:
95101
"""Instantiate steppers assigned to the model variables.
96102
@@ -122,22 +128,36 @@ def instantiate_steppers(
122128
used_keys = set()
123129
for step_class, vars in selected_steps.items():
124130
if vars:
125-
args = step_kwargs.get(step_class.name, {})
126-
used_keys.add(step_class.name)
131+
name = getattr(step_class, "name")
132+
args = step_kwargs.get(name, {})
133+
used_keys.add(name)
127134
step = step_class(vars=vars, model=model, **args)
128135
steps.append(step)
129136

130137
unused_args = set(step_kwargs).difference(used_keys)
131138
if unused_args:
132-
raise ValueError("Unused step method arguments: %s" % unused_args)
139+
s = "s" if len(unused_args) > 1 else ""
140+
example_arg = sorted(unused_args)[0]
141+
example_step = (list(selected_steps.keys()) or pm.STEP_METHODS)[0]
142+
example_step_name = getattr(example_step, "name")
143+
raise ValueError(
144+
f"Invalid key{s} found in step_kwargs: {unused_args}. "
145+
"Keys must be step names and values valid kwargs for that stepper. "
146+
f'Did you mean {{"{example_step_name}": {{"{example_arg}": ...}}}}?'
147+
)
133148

134149
if len(steps) == 1:
135150
return steps[0]
136151

137152
return steps
138153

139154

140-
def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
155+
def assign_step_methods(
156+
model: Model,
157+
step: Optional[Union[Step, Sequence[Step]]] = None,
158+
methods: Optional[Sequence[Type[BlockedStep]]] = None,
159+
step_kwargs: Optional[Dict[str, Any]] = None,
160+
) -> Union[Step, List[Step]]:
141161
"""Assign model variables to appropriate step methods.
142162
143163
Passing a specified model will auto-assign its constituent stochastic
@@ -167,49 +187,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
167187
methods : list
168188
List of step methods associated with the model's variables.
169189
"""
170-
steps = []
171-
assigned_vars = set()
172-
173-
if methods is None:
174-
methods = pm.STEP_METHODS
190+
steps: List[Step] = []
191+
assigned_vars: Set[Variable] = set()
175192

176193
if step is not None:
177-
try:
178-
steps += list(step)
179-
except TypeError:
194+
if isinstance(step, (BlockedStep, CompoundStep)):
180195
steps.append(step)
196+
else:
197+
steps.extend(step)
181198
for step in steps:
182199
for var in step.vars:
183200
if var not in model.value_vars:
184201
raise ValueError(
185-
f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
202+
f"{var} assigned to {step} sampler is not a value variable in the model. "
203+
"You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
186204
)
187205
assigned_vars = assigned_vars.union(set(step.vars))
188206

189207
# Use competence classmethods to select step methods for remaining
190208
# variables
191-
selected_steps = defaultdict(list)
209+
methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS)
210+
selected_steps: Dict[Type[BlockedStep], List] = {}
192211
model_logp = model.logp()
193212

194213
for var in model.value_vars:
195214
if var not in assigned_vars:
196215
# determine if a gradient can be computed
197-
has_gradient = var.dtype not in discrete_types
216+
has_gradient = getattr(var, "dtype") not in discrete_types
198217
if has_gradient:
199218
try:
200-
tg.grad(model_logp, var)
219+
tg.grad(model_logp, var) # type: ignore
201220
except (NotImplementedError, tg.NullTypeGradError):
202221
has_gradient = False
203222

204223
# select the best method
205224
rv_var = model.values_to_rvs[var]
206225
selected = max(
207-
methods,
208-
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
226+
methods_list,
227+
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore
209228
var, has_gradient
210229
),
211230
)
212-
selected_steps[selected].append(var)
231+
selected_steps.setdefault(selected, []).append(var)
213232

214233
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
215234

pymc/step_methods/compound.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def reset_tuning(self):
246246
method.reset_tuning()
247247

248248
@property
249-
def vars(self):
249+
def vars(self) -> List[Variable]:
250250
return [var for method in self.methods for var in method.vars]
251251

252252

tests/sampling/test_mcmc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,11 @@ def test_sample_init(self):
180180

181181
def test_sample_args(self):
182182
with self.model:
183-
with pytest.raises(ValueError) as excinfo:
183+
with pytest.raises(ValueError, match=r"'foo'"):
184184
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo=1)
185-
assert "'foo'" in str(excinfo.value)
186185

187-
with pytest.raises(ValueError) as excinfo:
186+
with pytest.raises(ValueError, match=r"'foo'") as excinfo:
188187
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={})
189-
assert "foo" in str(excinfo.value)
190188

191189
def test_parallel_start(self):
192190
with self.model:

0 commit comments

Comments
 (0)