14
14
15
15
"""Functions for MCMC sampling."""
16
16
17
+ import contextlib
17
18
import logging
18
19
import pickle
19
20
import sys
20
21
import time
21
22
import warnings
22
23
23
- from collections .abc import Iterator , Mapping , Sequence
24
+ from collections .abc import Callable , Iterator , Mapping , Sequence
24
25
from typing import (
25
26
Any ,
26
27
Literal ,
37
38
from rich .console import Console
38
39
from rich .progress import Progress
39
40
from rich .theme import Theme
41
+ from threadpoolctl import threadpool_limits
40
42
from typing_extensions import Protocol
41
43
42
44
import pymc as pm
@@ -396,6 +398,7 @@ def sample(
396
398
nuts_sampler_kwargs : dict [str , Any ] | None = None ,
397
399
callback = None ,
398
400
mp_ctx = None ,
401
+ blas_cores : int | None | Literal ["auto" ] = "auto" ,
399
402
** kwargs ,
400
403
) -> InferenceData : ...
401
404
@@ -427,6 +430,7 @@ def sample(
427
430
callback = None ,
428
431
mp_ctx = None ,
429
432
model : Model | None = None ,
433
+ blas_cores : int | None | Literal ["auto" ] = "auto" ,
430
434
** kwargs ,
431
435
) -> MultiTrace : ...
432
436
@@ -456,6 +460,7 @@ def sample(
456
460
nuts_sampler_kwargs : dict [str , Any ] | None = None ,
457
461
callback = None ,
458
462
mp_ctx = None ,
463
+ blas_cores : int | None | Literal ["auto" ] = "auto" ,
459
464
model : Model | None = None ,
460
465
** kwargs ,
461
466
) -> InferenceData | MultiTrace :
@@ -499,6 +504,13 @@ def sample(
499
504
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
500
505
This requires the chosen sampler to be installed.
501
506
All samplers, except "pymc", require the full model to be continuous.
507
+ blas_cores: int or "auto" or None, default = "auto"
508
+ The total number of threads blas and openmp functions should use during sampling. If set to None,
509
+ this will keep the default behavior of whatever blas implementation is used at runtime.
510
+ Setting it to "auto" will set it so that the total number of active blas threads is the
511
+ same as the `cores` argument. If set to an integer, the sampler will try to use that total
512
+ number of blas threads. If `blas_cores` is not divisible by `cores`, it might get rounded
513
+ down.
502
514
initvals : optional, dict, array of dict
503
515
Dict or list of dicts with initial value strategies to use instead of the defaults from
504
516
`Model.initial_values`. The keys should be names of transformed random variables.
@@ -644,6 +656,37 @@ def sample(
644
656
if chains is None :
645
657
chains = max (2 , cores )
646
658
659
+ if blas_cores == "auto" :
660
+ blas_cores = cores
661
+
662
+ cores = min (cores , chains )
663
+
664
+ if cores < 1 :
665
+ raise ValueError ("`cores` must be larger or equal to one" )
666
+
667
+ if chains < 1 :
668
+ raise ValueError ("`chains` must be larger or equal to one" )
669
+
670
+ if blas_cores is not None and blas_cores < 1 :
671
+ raise ValueError ("`blas_cores` must be larger or equal to one" )
672
+
673
+ num_blas_cores_per_chain : int | None
674
+ joined_blas_limiter : Callable [[], Any ]
675
+
676
+ if blas_cores is None :
677
+ joined_blas_limiter = contextlib .nullcontext
678
+ num_blas_cores_per_chain = None
679
+ elif isinstance (blas_cores , int ):
680
+
681
+ def joined_blas_limiter ():
682
+ return threadpool_limits (limits = blas_cores )
683
+
684
+ num_blas_cores_per_chain = blas_cores // cores
685
+ else :
686
+ raise ValueError (
687
+ f"Invalid argument `blas_cores`, must be int, 'auto' or None: { blas_cores } "
688
+ )
689
+
647
690
if random_seed == - 1 :
648
691
random_seed = None
649
692
random_seed_list = _get_seeds_per_chain (random_seed , chains )
@@ -685,21 +728,22 @@ def sample(
685
728
raise ValueError (
686
729
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
687
730
)
688
- return _sample_external_nuts (
689
- sampler = nuts_sampler ,
690
- draws = draws ,
691
- tune = tune ,
692
- chains = chains ,
693
- target_accept = kwargs .pop ("nuts" , {}).get ("target_accept" , 0.8 ),
694
- random_seed = random_seed ,
695
- initvals = initvals ,
696
- model = model ,
697
- var_names = var_names ,
698
- progressbar = progressbar ,
699
- idata_kwargs = idata_kwargs ,
700
- nuts_sampler_kwargs = nuts_sampler_kwargs ,
701
- ** kwargs ,
702
- )
731
+ with joined_blas_limiter ():
732
+ return _sample_external_nuts (
733
+ sampler = nuts_sampler ,
734
+ draws = draws ,
735
+ tune = tune ,
736
+ chains = chains ,
737
+ target_accept = kwargs .pop ("nuts" , {}).get ("target_accept" , 0.8 ),
738
+ random_seed = random_seed ,
739
+ initvals = initvals ,
740
+ model = model ,
741
+ var_names = var_names ,
742
+ progressbar = progressbar ,
743
+ idata_kwargs = idata_kwargs ,
744
+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
745
+ ** kwargs ,
746
+ )
703
747
704
748
if isinstance (step , list ):
705
749
step = CompoundStep (step )
@@ -708,18 +752,19 @@ def sample(
708
752
nuts_kwargs = kwargs .pop ("nuts" )
709
753
[kwargs .setdefault (k , v ) for k , v in nuts_kwargs .items ()]
710
754
_log .info ("Auto-assigning NUTS sampler..." )
711
- initial_points , step = init_nuts (
712
- init = init ,
713
- chains = chains ,
714
- n_init = n_init ,
715
- model = model ,
716
- random_seed = random_seed_list ,
717
- progressbar = progressbar ,
718
- jitter_max_retries = jitter_max_retries ,
719
- tune = tune ,
720
- initvals = initvals ,
721
- ** kwargs ,
722
- )
755
+ with joined_blas_limiter ():
756
+ initial_points , step = init_nuts (
757
+ init = init ,
758
+ chains = chains ,
759
+ n_init = n_init ,
760
+ model = model ,
761
+ random_seed = random_seed_list ,
762
+ progressbar = progressbar ,
763
+ jitter_max_retries = jitter_max_retries ,
764
+ tune = tune ,
765
+ initvals = initvals ,
766
+ ** kwargs ,
767
+ )
723
768
724
769
if initial_points is None :
725
770
# Time to draw/evaluate numeric start points for each chain.
@@ -756,7 +801,8 @@ def sample(
756
801
)
757
802
758
803
sample_args = {
759
- "draws" : draws + tune , # FIXME: Why is tune added to draws?
804
+ # draws is now the total number of draws, including tuning
805
+ "draws" : draws + tune ,
760
806
"step" : step ,
761
807
"start" : initial_points ,
762
808
"traces" : traces ,
@@ -772,6 +818,7 @@ def sample(
772
818
}
773
819
parallel_args = {
774
820
"mp_ctx" : mp_ctx ,
821
+ "blas_cores" : num_blas_cores_per_chain ,
775
822
}
776
823
777
824
sample_args .update (kwargs )
@@ -817,11 +864,15 @@ def sample(
817
864
if has_population_samplers :
818
865
_log .info (f"Population sampling ({ chains } chains)" )
819
866
_print_step_hierarchy (step )
820
- _sample_population (initial_points = initial_points , parallelize = cores > 1 , ** sample_args )
867
+ with joined_blas_limiter ():
868
+ _sample_population (
869
+ initial_points = initial_points , parallelize = cores > 1 , ** sample_args
870
+ )
821
871
else :
822
872
_log .info (f"Sequential sampling ({ chains } chains in 1 job)" )
823
873
_print_step_hierarchy (step )
824
- _sample_many (** sample_args )
874
+ with joined_blas_limiter ():
875
+ _sample_many (** sample_args )
825
876
826
877
t_sampling = time .time () - t_start
827
878
@@ -1139,6 +1190,7 @@ def _mp_sample(
1139
1190
traces : Sequence [IBaseTrace ],
1140
1191
model : Model | None = None ,
1141
1192
callback : SamplingIteratorCallback | None = None ,
1193
+ blas_cores : int | None = None ,
1142
1194
mp_ctx = None ,
1143
1195
** kwargs ,
1144
1196
) -> None :
@@ -1190,6 +1242,7 @@ def _mp_sample(
1190
1242
step_method = step ,
1191
1243
progressbar = progressbar ,
1192
1244
progressbar_theme = progressbar_theme ,
1245
+ blas_cores = blas_cores ,
1193
1246
mp_ctx = mp_ctx ,
1194
1247
)
1195
1248
try :
0 commit comments