Skip to content

Commit 6156cbb

Browse files
Rename start kwarg to initvals
1 parent e45e36b commit 6156cbb

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

pymc/sampling.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def sample(
253253
step=None,
254254
init="auto",
255255
n_init=200_000,
256-
start: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None,
256+
initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None,
257257
trace: Optional[Union[BaseTrace, List[str]]] = None,
258258
chain_idx=0,
259259
chains=None,
@@ -291,11 +291,10 @@ def sample(
291291
users.
292292
n_init : int
293293
Number of iterations of initializer. Only works for 'ADVI' init methods.
294-
start : dict, or array of dict
295-
Starting point in parameter space (or partial point)
296-
Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not
297-
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
298-
overwrite the default.
294+
initvals : optional, dict, array of dict
295+
Dict or list of dicts with initial values to used instead of the defaults from `Model.initial_values`.
296+
The keys should be names of transformed random variables.
297+
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
299298
trace : backend or list
300299
This should be a backend instance, or a list of variables to track.
301300
If None or a list of variables, the NDArray backend is used.
@@ -417,13 +416,23 @@ def sample(
417416
mean sd hdi_3% hdi_97%
418417
p 0.609 0.047 0.528 0.699
419418
"""
419+
if "start" in kwargs:
420+
if initvals is not None:
421+
raise ValueError("Passing both `start` and `initvals` is not supported.")
422+
warnings.warn(
423+
"The `start` kwarg was renamed to `initvals`. Please check the docstring.",
424+
DeprecationWarning,
425+
stacklevel=2,
426+
)
427+
initvals = kwargs.pop("start")
428+
420429
model = modelcontext(model)
421430
if not model.free_RVs:
422431
raise SamplingError(
423432
"Cannot sample from the model, since the model does not contain any free variables."
424433
)
425434

426-
start = deepcopy(start)
435+
start = deepcopy(initvals)
427436
model_initial_point = model.initial_point
428437
if start is None:
429438
model.check_start_vals(model_initial_point)

0 commit comments

Comments
 (0)