Skip to content

Commit 4049c28

Browse files
committed
feat: Add blas_cores argument to pm.sample
1 parent b59a9eb commit 4049c28

File tree

5 files changed

+119
-47
lines changed

5 files changed

+119
-47
lines changed

pymc/sampling/mcmc.py

Lines changed: 84 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
"""Functions for MCMC sampling."""
1616

17+
import contextlib
1718
import logging
1819
import pickle
1920
import sys
2021
import time
2122
import warnings
2223

23-
from collections.abc import Iterator, Mapping, Sequence
24+
from collections.abc import Callable, Iterator, Mapping, Sequence
2425
from typing import (
2526
Any,
2627
Literal,
@@ -37,6 +38,7 @@
3738
from rich.console import Console
3839
from rich.progress import Progress
3940
from rich.theme import Theme
41+
from threadpoolctl import threadpool_limits
4042
from typing_extensions import Protocol
4143

4244
import pymc as pm
@@ -396,6 +398,7 @@ def sample(
396398
nuts_sampler_kwargs: dict[str, Any] | None = None,
397399
callback=None,
398400
mp_ctx=None,
401+
blas_cores: int | None | Literal["auto"] = "auto",
399402
**kwargs,
400403
) -> InferenceData: ...
401404

@@ -427,6 +430,7 @@ def sample(
427430
callback=None,
428431
mp_ctx=None,
429432
model: Model | None = None,
433+
blas_cores: int | None | Literal["auto"] = "auto",
430434
**kwargs,
431435
) -> MultiTrace: ...
432436

@@ -456,6 +460,7 @@ def sample(
456460
nuts_sampler_kwargs: dict[str, Any] | None = None,
457461
callback=None,
458462
mp_ctx=None,
463+
blas_cores: int | None | Literal["auto"] = "auto",
459464
model: Model | None = None,
460465
**kwargs,
461466
) -> InferenceData | MultiTrace:
@@ -499,6 +504,13 @@ def sample(
499504
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
500505
This requires the chosen sampler to be installed.
501506
All samplers, except "pymc", require the full model to be continuous.
507+
blas_cores : int or "auto" or None, default "auto"
508+
The total number of threads that BLAS and OpenMP functions should use during sampling. If set to None,
509+
this will keep the default behavior of whatever BLAS implementation is used at runtime.
510+
Setting it to "auto" will make the total number of active BLAS threads the
511+
same as the `cores` argument. If set to an integer, the sampler will try to use that total
512+
number of BLAS threads. If `blas_cores` is not divisible by `cores`, the per-chain thread
513+
count (`blas_cores // cores`) is rounded down.
502514
initvals : optional, dict, array of dict
503515
Dict or list of dicts with initial value strategies to use instead of the defaults from
504516
`Model.initial_values`. The keys should be names of transformed random variables.
@@ -644,6 +656,37 @@ def sample(
644656
if chains is None:
645657
chains = max(2, cores)
646658

659+
if blas_cores == "auto":
660+
blas_cores = cores
661+
662+
cores = min(cores, chains)
663+
664+
if cores < 1:
665+
raise ValueError("`cores` must be larger or equal to one")
666+
667+
if chains < 1:
668+
raise ValueError("`chains` must be larger or equal to one")
669+
670+
if blas_cores is not None and blas_cores < 1:
671+
raise ValueError("`blas_cores` must be larger or equal to one")
672+
673+
num_blas_cores_per_chain: int | None
674+
joined_blas_limiter: Callable[[], Any]
675+
676+
if blas_cores is None:
677+
joined_blas_limiter = contextlib.nullcontext
678+
num_blas_cores_per_chain = None
679+
elif isinstance(blas_cores, int):
680+
681+
def joined_blas_limiter():
682+
return threadpool_limits(limits=blas_cores)
683+
684+
num_blas_cores_per_chain = blas_cores // cores
685+
else:
686+
raise ValueError(
687+
f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}"
688+
)
689+
647690
if random_seed == -1:
648691
random_seed = None
649692
random_seed_list = _get_seeds_per_chain(random_seed, chains)
@@ -685,21 +728,22 @@ def sample(
685728
raise ValueError(
686729
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
687730
)
688-
return _sample_external_nuts(
689-
sampler=nuts_sampler,
690-
draws=draws,
691-
tune=tune,
692-
chains=chains,
693-
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
694-
random_seed=random_seed,
695-
initvals=initvals,
696-
model=model,
697-
var_names=var_names,
698-
progressbar=progressbar,
699-
idata_kwargs=idata_kwargs,
700-
nuts_sampler_kwargs=nuts_sampler_kwargs,
701-
**kwargs,
702-
)
731+
with joined_blas_limiter():
732+
return _sample_external_nuts(
733+
sampler=nuts_sampler,
734+
draws=draws,
735+
tune=tune,
736+
chains=chains,
737+
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
738+
random_seed=random_seed,
739+
initvals=initvals,
740+
model=model,
741+
var_names=var_names,
742+
progressbar=progressbar,
743+
idata_kwargs=idata_kwargs,
744+
nuts_sampler_kwargs=nuts_sampler_kwargs,
745+
**kwargs,
746+
)
703747

704748
if isinstance(step, list):
705749
step = CompoundStep(step)
@@ -708,18 +752,19 @@ def sample(
708752
nuts_kwargs = kwargs.pop("nuts")
709753
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
710754
_log.info("Auto-assigning NUTS sampler...")
711-
initial_points, step = init_nuts(
712-
init=init,
713-
chains=chains,
714-
n_init=n_init,
715-
model=model,
716-
random_seed=random_seed_list,
717-
progressbar=progressbar,
718-
jitter_max_retries=jitter_max_retries,
719-
tune=tune,
720-
initvals=initvals,
721-
**kwargs,
722-
)
755+
with joined_blas_limiter():
756+
initial_points, step = init_nuts(
757+
init=init,
758+
chains=chains,
759+
n_init=n_init,
760+
model=model,
761+
random_seed=random_seed_list,
762+
progressbar=progressbar,
763+
jitter_max_retries=jitter_max_retries,
764+
tune=tune,
765+
initvals=initvals,
766+
**kwargs,
767+
)
723768

724769
if initial_points is None:
725770
# Time to draw/evaluate numeric start points for each chain.
@@ -756,7 +801,8 @@ def sample(
756801
)
757802

758803
sample_args = {
759-
"draws": draws + tune, # FIXME: Why is tune added to draws?
804+
# "draws" here is the total number of iterations, tuning included (draws + tune)
805+
"draws": draws + tune,
760806
"step": step,
761807
"start": initial_points,
762808
"traces": traces,
@@ -772,6 +818,7 @@ def sample(
772818
}
773819
parallel_args = {
774820
"mp_ctx": mp_ctx,
821+
"blas_cores": num_blas_cores_per_chain,
775822
}
776823

777824
sample_args.update(kwargs)
@@ -817,11 +864,15 @@ def sample(
817864
if has_population_samplers:
818865
_log.info(f"Population sampling ({chains} chains)")
819866
_print_step_hierarchy(step)
820-
_sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args)
867+
with joined_blas_limiter():
868+
_sample_population(
869+
initial_points=initial_points, parallelize=cores > 1, **sample_args
870+
)
821871
else:
822872
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
823873
_print_step_hierarchy(step)
824-
_sample_many(**sample_args)
874+
with joined_blas_limiter():
875+
_sample_many(**sample_args)
825876

826877
t_sampling = time.time() - t_start
827878

@@ -1139,6 +1190,7 @@ def _mp_sample(
11391190
traces: Sequence[IBaseTrace],
11401191
model: Model | None = None,
11411192
callback: SamplingIteratorCallback | None = None,
1193+
blas_cores: int | None = None,
11421194
mp_ctx=None,
11431195
**kwargs,
11441196
) -> None:
@@ -1190,6 +1242,7 @@ def _mp_sample(
11901242
step_method=step,
11911243
progressbar=progressbar,
11921244
progressbar_theme=progressbar_theme,
1245+
blas_cores=blas_cores,
11931246
mp_ctx=mp_ctx,
11941247
)
11951248
try:

pymc/sampling/parallel.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from rich.console import Console
3030
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
3131
from rich.theme import Theme
32+
from threadpoolctl import threadpool_limits
3233

3334
from pymc.blocking import DictToArrayBijection
3435
from pymc.exceptions import SamplingError
@@ -93,6 +94,7 @@ def __init__(
9394
draws: int,
9495
tune: int,
9596
seed,
97+
blas_cores,
9698
):
9799
self._msg_pipe = msg_pipe
98100
self._step_method = step_method
@@ -102,6 +104,7 @@ def __init__(
102104
self._at_seed = seed + 1
103105
self._draws = draws
104106
self._tune = tune
107+
self._blas_cores = blas_cores
105108

106109
def _unpickle_step_method(self):
107110
unpickle_error = (
@@ -116,22 +119,23 @@ def _unpickle_step_method(self):
116119
raise ValueError(unpickle_error)
117120

118121
def run(self):
119-
try:
120-
# We do not create this in __init__, as pickling this
121-
# would destroy the shared memory.
122-
self._unpickle_step_method()
123-
self._point = self._make_numpy_refs()
124-
self._start_loop()
125-
except KeyboardInterrupt:
126-
pass
127-
except BaseException as e:
128-
e = ExceptionWithTraceback(e, e.__traceback__)
129-
# Send is not blocking so we have to force a wait for the abort
130-
# message
131-
self._msg_pipe.send(("error", e))
132-
self._wait_for_abortion()
133-
finally:
134-
self._msg_pipe.close()
122+
with threadpool_limits(limits=self._blas_cores):
123+
try:
124+
# We do not create this in __init__, as pickling this
125+
# would destroy the shared memory.
126+
self._unpickle_step_method()
127+
self._point = self._make_numpy_refs()
128+
self._start_loop()
129+
except KeyboardInterrupt:
130+
pass
131+
except BaseException as e:
132+
e = ExceptionWithTraceback(e, e.__traceback__)
133+
# Send is not blocking so we have to force a wait for the abort
134+
# message
135+
self._msg_pipe.send(("error", e))
136+
self._wait_for_abortion()
137+
finally:
138+
self._msg_pipe.close()
135139

136140
def _wait_for_abortion(self):
137141
while True:
@@ -208,6 +212,7 @@ def __init__(
208212
chain: int,
209213
seed,
210214
start: dict[str, np.ndarray],
215+
blas_cores,
211216
mp_ctx,
212217
):
213218
self.chain = chain
@@ -256,6 +261,7 @@ def __init__(
256261
draws,
257262
tune,
258263
seed,
264+
blas_cores,
259265
),
260266
)
261267
self._process.start()
@@ -378,6 +384,7 @@ def __init__(
378384
step_method,
379385
progressbar: bool = True,
380386
progressbar_theme: Theme | None = default_progress_theme,
387+
blas_cores: int | None = None,
381388
mp_ctx=None,
382389
):
383390
if any(len(arg) != chains for arg in [seeds, start_points]):
@@ -411,6 +418,7 @@ def __init__(
411418
chain,
412419
seed,
413420
start,
421+
blas_cores,
414422
mp_ctx,
415423
)
416424
for chain, seed, start in zip(range(chains), seeds, start_points)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ pandas>=0.24.0
66
pytensor>=2.20,<2.21
77
rich>=13.7.1
88
scipy>=1.4.1
9+
threadpoolctl>=3.1.0,<4.0.0
910
typing-extensions>=3.7.4

tests/sampling/test_mcmc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,14 @@ def test_empty_model():
507507
error.match("any free variables")
508508

509509

510+
def test_blas_cores():
511+
with pm.Model():
512+
pm.Normal("a")
513+
pm.sample(blas_cores="auto", tune=10, cores=2, draws=10)
514+
pm.sample(blas_cores=None, tune=10, cores=2, draws=10)
515+
pm.sample(blas_cores=2, tune=10, cores=2, draws=10)
516+
517+
510518
def test_partial_trace_with_trace_unsupported():
511519
with pm.Model() as model:
512520
a = pm.Normal("a", mu=0, sigma=1)

tests/sampling/test_parallel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def test_explicit_sample(mp_start_method):
161161
mp_ctx=ctx,
162162
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
163163
step_method_pickled=step_method_pickled,
164+
blas_cores=None,
164165
)
165166
proc.start()
166167
while True:
@@ -193,6 +194,7 @@ def test_iterator():
193194
start_points=[start] * 3,
194195
step_method=step,
195196
progressbar=False,
197+
blas_cores=None,
196198
)
197199
with sampler:
198200
for draw in sampler:

0 commit comments

Comments
 (0)