|
20 | 20 | import time
|
21 | 21 | import warnings
|
22 | 22 |
|
23 |
| -from collections import defaultdict |
24 | 23 | from typing import (
|
25 | 24 | Any,
|
26 | 25 | Dict,
|
27 | 26 | Iterator,
|
28 | 27 | List,
|
29 | 28 | Literal,
|
| 29 | + Mapping, |
30 | 30 | Optional,
|
31 | 31 | Sequence,
|
| 32 | + Set, |
32 | 33 | Tuple,
|
| 34 | + Type, |
33 | 35 | Union,
|
34 | 36 | overload,
|
35 | 37 | )
|
|
39 | 41 |
|
40 | 42 | from arviz import InferenceData
|
41 | 43 | from fastprogress.fastprogress import progress_bar
|
| 44 | +from pytensor.graph.basic import Variable |
42 | 45 | from typing_extensions import Protocol, TypeAlias
|
43 | 46 |
|
44 | 47 | import pymc as pm
|
@@ -90,7 +93,10 @@ def __call__(self, trace: IBaseTrace, draw: Draw):
|
90 | 93 |
|
91 | 94 |
|
92 | 95 | def instantiate_steppers(
|
93 |
| - model, steps: List[Step], selected_steps, step_kwargs=None |
| 96 | + model: Model, |
| 97 | + steps: List[Step], |
| 98 | + selected_steps: Mapping[Type[BlockedStep], List[Any]], |
| 99 | + step_kwargs: Optional[Dict[str, Dict]] = None, |
94 | 100 | ) -> Union[Step, List[Step]]:
|
95 | 101 | """Instantiate steppers assigned to the model variables.
|
96 | 102 |
|
@@ -122,22 +128,36 @@ def instantiate_steppers(
|
122 | 128 | used_keys = set()
|
123 | 129 | for step_class, vars in selected_steps.items():
|
124 | 130 | if vars:
|
125 |
| - args = step_kwargs.get(step_class.name, {}) |
126 |
| - used_keys.add(step_class.name) |
| 131 | + name = getattr(step_class, "name") |
| 132 | + args = step_kwargs.get(name, {}) |
| 133 | + used_keys.add(name) |
127 | 134 | step = step_class(vars=vars, model=model, **args)
|
128 | 135 | steps.append(step)
|
129 | 136 |
|
130 | 137 | unused_args = set(step_kwargs).difference(used_keys)
|
131 | 138 | if unused_args:
|
132 |
| - raise ValueError("Unused step method arguments: %s" % unused_args) |
| 139 | + s = "s" if len(unused_args) > 1 else "" |
| 140 | + example_arg = sorted(unused_args)[0] |
| 141 | + example_step = (list(selected_steps.keys()) or pm.STEP_METHODS)[0] |
| 142 | + example_step_name = getattr(example_step, "name") |
| 143 | + raise ValueError( |
| 144 | + f"Invalid key{s} found in step_kwargs: {unused_args}. " |
| 145 | + "Keys must be step names and values valid kwargs for that stepper. " |
| 146 | + f'Did you mean {{"{example_step_name}": {{"{example_arg}": ...}}}}?' |
| 147 | + ) |
133 | 148 |
|
134 | 149 | if len(steps) == 1:
|
135 | 150 | return steps[0]
|
136 | 151 |
|
137 | 152 | return steps
|
138 | 153 |
|
139 | 154 |
|
140 |
| -def assign_step_methods(model, step=None, methods=None, step_kwargs=None): |
| 155 | +def assign_step_methods( |
| 156 | + model: Model, |
| 157 | + step: Optional[Union[Step, Sequence[Step]]] = None, |
| 158 | + methods: Optional[Sequence[Type[BlockedStep]]] = None, |
| 159 | + step_kwargs: Optional[Dict[str, Any]] = None, |
| 160 | +) -> Union[Step, List[Step]]: |
141 | 161 | """Assign model variables to appropriate step methods.
|
142 | 162 |
|
143 | 163 | Passing a specified model will auto-assign its constituent stochastic
|
@@ -167,49 +187,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
|
167 | 187 | methods : list
|
168 | 188 | List of step methods associated with the model's variables.
|
169 | 189 | """
|
170 |
| - steps = [] |
171 |
| - assigned_vars = set() |
172 |
| - |
173 |
| - if methods is None: |
174 |
| - methods = pm.STEP_METHODS |
| 190 | + steps: List[Step] = [] |
| 191 | + assigned_vars: Set[Variable] = set() |
175 | 192 |
|
176 | 193 | if step is not None:
|
177 |
| - try: |
178 |
| - steps += list(step) |
179 |
| - except TypeError: |
| 194 | + if isinstance(step, (BlockedStep, CompoundStep)): |
180 | 195 | steps.append(step)
|
| 196 | + else: |
| 197 | + steps.extend(step) |
181 | 198 | for step in steps:
|
182 | 199 | for var in step.vars:
|
183 | 200 | if var not in model.value_vars:
|
184 | 201 | raise ValueError(
|
185 |
| - f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables." |
| 202 | + f"{var} assigned to {step} sampler is not a value variable in the model. " |
| 203 | + "You can use `util.get_value_vars_from_user_vars` to parse user provided variables." |
186 | 204 | )
|
187 | 205 | assigned_vars = assigned_vars.union(set(step.vars))
|
188 | 206 |
|
189 | 207 | # Use competence classmethods to select step methods for remaining
|
190 | 208 | # variables
|
191 |
| - selected_steps = defaultdict(list) |
| 209 | + methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS) |
| 210 | + selected_steps: Dict[Type[BlockedStep], List] = {} |
192 | 211 | model_logp = model.logp()
|
193 | 212 |
|
194 | 213 | for var in model.value_vars:
|
195 | 214 | if var not in assigned_vars:
|
196 | 215 | # determine if a gradient can be computed
|
197 |
| - has_gradient = var.dtype not in discrete_types |
| 216 | + has_gradient = getattr(var, "dtype") not in discrete_types |
198 | 217 | if has_gradient:
|
199 | 218 | try:
|
200 |
| - tg.grad(model_logp, var) |
| 219 | + tg.grad(model_logp, var) # type: ignore |
201 | 220 | except (NotImplementedError, tg.NullTypeGradError):
|
202 | 221 | has_gradient = False
|
203 | 222 |
|
204 | 223 | # select the best method
|
205 | 224 | rv_var = model.values_to_rvs[var]
|
206 | 225 | selected = max(
|
207 |
| - methods, |
208 |
| - key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( |
| 226 | + methods_list, |
| 227 | + key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore |
209 | 228 | var, has_gradient
|
210 | 229 | ),
|
211 | 230 | )
|
212 |
| - selected_steps[selected].append(var) |
| 231 | + selected_steps.setdefault(selected, []).append(var) |
213 | 232 |
|
214 | 233 | return instantiate_steppers(model, steps, selected_steps, step_kwargs)
|
215 | 234 |
|
|
0 commit comments