from collections import defaultdict
from copy import copy, deepcopy
-from typing import Dict, Iterable, List, Optional, Sequence, Set, Union, cast
+from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast

import aesara.gradient as tg
import cloudpickle
@@ -432,25 +432,11 @@ def sample(
            "Cannot sample from the model, since the model does not contain any free variables."
        )

-    start = deepcopy(initvals)
-    model_initial_point = model.initial_point
-    if start is None:
-        model.check_start_vals(model_initial_point)
-    else:
-        if isinstance(start, dict):
-            model.update_start_vals(start, model.initial_point)
-        else:
-            for chain_start_vals in start:
-                model.update_start_vals(chain_start_vals, model.initial_point)
-        model.check_start_vals(start)
-
    if cores is None:
        cores = min(4, _cpu_count())

    if chains is None:
        chains = max(2, cores)
-    if isinstance(start, dict):
-        start = [start] * chains
    if random_seed == -1:
        random_seed = None
    if chains == 1 and isinstance(random_seed, int):
@@ -476,10 +462,6 @@ def sample(
            stacklevel=2,
        )

-    if start is not None:
-        for start_vals in start:
-            _check_start_shape(model, start_vals)
-
    # small trace warning
    if draws == 0:
        msg = "Tuning was enabled throughout the whole trace."
@@ -490,11 +472,12 @@ def sample(
    draws += tune

+    initial_points = None
    if step is None and init is not None and all_continuous(model.value_vars, model):
        try:
            # By default, try to use NUTS
            _log.info("Auto-assigning NUTS sampler...")
-            start_, step = init_nuts(
+            initial_points, step = init_nuts(
                init=init,
                chains=chains,
                n_init=n_init,
@@ -503,31 +486,40 @@ def sample(
                progressbar=progressbar,
                jitter_max_retries=jitter_max_retries,
                tune=tune,
+                initvals=initvals,
                **kwargs,
            )
-            if start is None:
-                start = start_
-            model.check_start_vals(start)
        except (AttributeError, NotImplementedError, tg.NullTypeGradError):
            # gradient computation failed
-            _log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.")
+            _log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
            _log.debug("Exception in init nuts", exc_info=True)
            step = assign_step_methods(model, step, step_kwargs=kwargs)
-            start = model_initial_point
    else:
-        start = model_initial_point
        step = assign_step_methods(model, step, step_kwargs=kwargs)

    if isinstance(step, list):
        step = CompoundStep(step)

-    if isinstance(start, dict):
-        start = [start] * chains
+    if initial_points is None:
+        initvals = initvals or {}
+        if isinstance(initvals, dict):
+            initvals = [initvals] * chains
+        initial_points = []
+        mip = model.initial_point
+        for ivals in initvals:
+            ivals = deepcopy(ivals)
+            model.update_start_vals(ivals, mip)
+            initial_points.append(ivals)
+
+    # One final check that shapes and logps at the starting points are okay.
+    for ip in initial_points:
+        model.check_start_vals(ip)
+        _check_start_shape(model, ip)

    sample_args = {
        "draws": draws,
        "step": step,
-        "start": start,
+        "start": initial_points,
        "trace": trace,
        "chain": chain_idx,
        "chains": chains,
@@ -579,7 +571,7 @@ def sample(
        )
        _log.info(f"Population sampling ({chains} chains)")

-        initial_point_model_size = sum(start[0][n.name].size for n in model.value_vars)
+        initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars)

        if has_demcmc and chains < 3:
            raise ValueError(
@@ -664,31 +656,41 @@ def sample(
        return trace


-def _check_start_shape(model, start):
-    if not isinstance(start, dict):
-        raise TypeError("start argument must be a dict or an array-like of dicts")
-
-    # Filter "non-input" variables
-    initial_point = model.initial_point
-    start = {k: v for k, v in start.items() if k in initial_point}
+def _check_start_shape(model, start: PointType):
+    """Checks that the prior evaluations and initial points have identical shapes.

+    Parameters
+    ----------
+    model : pm.Model
+        The current model in context.
+    start : dict
+        The complete dictionary mapping (transformed) variable names to numeric initial values.
+    """
    e = ""
    for var in model.basic_RVs:
-        var_shape = model.fastfn(var.shape)(start)
-        if var.name in start.keys():
-            start_var_shape = np.shape(start[var.name])
-            if start_var_shape:
-                if not np.array_equal(var_shape, start_var_shape):
-                    e += "\nExpected shape {} for var '{}', got: {}".format(
-                        tuple(var_shape), var.name, start_var_shape
-                    )
-            # if start var has no shape
+        try:
+            var_shape = model.fastfn(var.shape)(start)
+            if var.name in start.keys():
+                start_var_shape = np.shape(start[var.name])
+                if start_var_shape:
+                    if not np.array_equal(var_shape, start_var_shape):
+                        e += "\nExpected shape {} for var '{}', got: {}".format(
+                            tuple(var_shape), var.name, start_var_shape
+                        )
+                # if start var has no shape
+                else:
+                    # if model var has a specified shape
+                    if var_shape.size > 0:
+                        e += "\nExpected shape {} for var '{}', got scalar {}".format(
+                            tuple(var_shape), var.name, start[var.name]
+                        )
+        except NotImplementedError as ex:
+            if ex.args[0].startswith("Cannot sample"):
+                _log.warning(
+                    f"Unable to check start shape of {var} because the RV does not implement random sampling."
+                )
            else:
-                # if model var has a specified shape
-                if var_shape.size > 0:
-                    e += "\nExpected shape {} for var '{}', got scalar {}".format(
-                        tuple(var_shape), var.name, start[var.name]
-                    )
+                raise

    if e != "":
        raise ValueError(f"Bad shape for start argument:{e}")
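A short sketch of what the stricter per-variable check surfaces; the model is illustrative, and the import path assumes the private helper stays in `pymc/sampling.py` as edited above:

```python
import numpy as np
import pymc as pm
from pymc.sampling import _check_start_shape  # private helper from this diff

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0, size=3)

start = dict(model.initial_point)
start["x"] = np.zeros(2)  # wrong: the model expects shape (3,)
_check_start_shape(model, start)
# Expected to raise something like:
#   ValueError: Bad shape for start argument:
#   Expected shape (3,) for var 'x', got: (2,)
```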
@@ -943,7 +945,7 @@ def iter_sample(
def _iter_sample(
    draws,
    step,
-    start: Optional[PointType],
+    start: PointType,
    trace: Optional[Union[BaseTrace, List[str]]] = None,
    chain=0,
    tune=None,
@@ -961,6 +963,7 @@ def _iter_sample(
        Step function
    start : dict
        Starting point in parameter space (or partial point).
+        Must contain numeric (transformed) initial values for all (transformed) free variables.
    trace : backend or list
        This should be a backend instance, or a list of variables to track.
        If None or a list of variables, the NDArray backend is used.
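To illustrate the new requirement, a complete start point for a model with free variables `mu` and `sigma` would look like this; the `sigma_log__` name assumes the default log transform:

```python
import numpy as np

# Every (transformed) free variable gets a numeric value; no key may be missing.
start = {
    "mu": np.array(0.0),
    "sigma_log__": np.array(0.0),  # transformed value of `sigma`
}
```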
@@ -993,10 +996,7 @@ def _iter_sample(
    except TypeError:
        pass

-    if start is None:
-        start = {}
-    model.update_start_vals(start, model.initial_point)
-    point = Point(start, model=model, filter_model_vars=True)
+    point = start

    if step.generates_stats and strace.supports_sampler_stats:
        strace.setup(draws, chain, step.stats_dtypes)
@@ -1257,9 +1257,6 @@ def _prepare_iter_population(
    # 1. prepare a BaseTrace for each chain
    traces = [_choose_backend(None, model=model) for chain in chains]
-    for c, strace in enumerate(traces):
-        # initialize the trace size and variable transforms
-        model.update_start_vals(start[c], model.initial_point)

    # 2. create a population (points) that tracks each chain
    #    it is updated as the chains are advanced
@@ -1422,6 +1419,7 @@ def _mp_sample(
        Random seeds for each chain.
    start : list
        Starting points for each chain.
+        Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
    progressbar : bool
        Whether or not to display a progress bar in the command line.
    trace : BaseTrace, list, or None
@@ -1452,10 +1450,6 @@ def _mp_sample(
    else:
        strace = _choose_backend(None, model=model)

-    # for user supplied start value, fill-in missing value if the supplied
-    # dict does not contain all parameters
-    model.update_start_vals(start[idx - chain], model.initial_point)
-
    if step.generates_stats and strace.supports_sampler_stats:
        strace.setup(draws + tune, idx, step.stats_dtypes)
    else:
@@ -2053,8 +2047,10 @@ def init_nuts(
    progressbar=True,
    jitter_max_retries=10,
    tune=None,
+    *,
+    initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None,
    **kwargs,
-):
+) -> Tuple[Sequence[PointType], NUTS]:
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed is extremely dependent on the
NUTS convergence and sampling speed is extremely dependent on the
@@ -2089,6 +2085,9 @@ def init_nuts(
2089
2085
2090
2086
chains : int
2091
2087
Number of jobs to start.
2088
+ initvals : optional, dict or list of dicts
2089
+ Dict or list of dicts with initial values to use instead of the defaults from `Model.initial_values`.
2090
+ The keys should be names of transformed random variables.
2092
2091
n_init : int
2093
2092
Number of iterations of initializer. Only works for 'ADVI' init methods.
2094
2093
model : Model (optional if in ``with`` context)
@@ -2103,8 +2102,8 @@ def init_nuts(
    Returns
    -------
-    start : ``pymc.model.Point``
-        Starting point for sampler
+    initial_points : list
+        Starting points for each chain.
    nuts_sampler : ``pymc.step_methods.NUTS``
        Instantiated and initialized NUTS sampler object
    """
@@ -2135,6 +2134,8 @@ def init_nuts(
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
    ]

+    # TODO: Consider `initvals` for selecting the starting point.
+
    apoint = DictToArrayBijection.map(model.initial_point)

    if init == "adapt_diag":
@@ -2238,4 +2239,25 @@ def init_nuts(
    step = pm.NUTS(potential=potential, model=model, **kwargs)

-    return start, step
+    # The "start" dict determined from initialization methods does not always respect the support of variables.
+    # The next block combines it with the user-provided initvals such that initvals take priority.
+    if initvals is None or isinstance(initvals, dict):
+        initvals = [initvals or {}] * chains
+    if isinstance(start, dict):
+        start = [start] * chains
+    mip = model.initial_point
+    initial_points = []
+    for st, iv in zip(start, initvals):
+        from_init = deepcopy(st)
+        model.update_start_vals(from_init, mip)
+
+        from_user = deepcopy(iv)
+        model.update_start_vals(from_user, mip)
+
+        initial_points.append(
+            {
+                **from_init,
+                **from_user,  # prioritize user-provided
+            }
+        )
+    return initial_points, step
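The priority rule in this final block is plain dict merging, where later keys win. A self-contained sketch of just that rule (the real code also completes `from_user` against `model.initial_point` before merging):

```python
from_init = {"mu": 0.1, "sigma_log__": -0.3}  # completed from the init method
from_user = {"mu": 1.0}                       # user-provided initvals (partial)

merged = {**from_init, **from_user}
assert merged == {"mu": 1.0, "sigma_log__": -0.3}  # user value wins for "mu"
```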