Skip to content

Commit 1b32070

Browse files
authored
Create helper pm.draw() to take draws for a given variable (#5340)
1 parent f6b930b commit 1b32070

File tree

4 files changed

+127
-6
lines changed

4 files changed

+127
-6
lines changed

RELEASE-NOTES.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ All of the above apply to:
3131

3232
- ⚠ The library is now named, installed and imported as "pymc". For example: `pip install pymc`.
3333
- ⚠ Theano-PyMC has been replaced with Aesara, so all external references to `theano`, `tt`, and `pymc3.theanof` need to be replaced with `aesara`, `at`, and `pymc.aesaraf` (see [4471](https://github.com/pymc-devs/pymc/pull/4471)).
34-
- `pm.Distribution(...).logp(x)` is now `pm.logp(pm.Distribution(...), x)`
35-
- `pm.Distribution(...).logcdf(x)` is now `pm.logcdf(pm.Distribution(...), x)`
36-
- `pm.Distribution(...).random()` is now `pm.Distribution(...).eval()`
37-
- `pm.draw_values(...)` and `pm.generate_samples(...)` were removed. The tensors can now be evaluated with `.eval()`.
34+
- `pm.Distribution(...).logp(x)` is now `pm.logp(pm.Distribution(...), x)`.
35+
- `pm.Distribution(...).logcdf(x)` is now `pm.logcdf(pm.Distribution(...), x)`.
36+
- `pm.Distribution(...).random(size=x)` is now `pm.draw(pm.Distribution(...), draws=x)`.
37+
- `pm.draw_values(...)` and `pm.generate_samples(...)` were removed.
3838
- `pm.fast_sample_posterior_predictive` was removed.
3939
- `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)).
4040
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).

pymc/aesaraf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def compile_pymc(inputs, outputs, mode=None, **kwargs):
967967

968968
# Set the default update of a NoDistribution RNG so that it is automatically
969969
# updated after every function call
970-
output_to_list = outputs if isinstance(outputs, list) else [outputs]
970+
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
971971
for rv in (
972972
node
973973
for node in walk_model(output_to_list, walk_past_rvs=True)

pymc/sampling.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import xarray
4343

4444
from aesara.compile.mode import Mode
45-
from aesara.graph.basic import Constant
45+
from aesara.graph.basic import Constant, Variable
4646
from aesara.tensor.sharedvar import SharedVariable
4747
from arviz import InferenceData
4848
from fastprogress.fastprogress import progress_bar
@@ -96,6 +96,7 @@
9696
"sample_posterior_predictive_w",
9797
"init_nuts",
9898
"sample_prior_predictive",
99+
"draw",
99100
]
100101

101102
STEP_METHODS = (
@@ -2093,6 +2094,70 @@ def sample_prior_predictive(
20932094
return pm.to_inference_data(prior=prior, **ikwargs)
20942095

20952096

2097+
def draw(
2098+
vars: Union[Variable, Sequence[Variable]],
2099+
draws: int = 1,
2100+
mode: Optional[Union[str, Mode]] = None,
2101+
**kwargs,
2102+
) -> Union[np.ndarray, List[np.ndarray]]:
2103+
"""Draw samples for one variable or a list of variables
2104+
2105+
Parameters
2106+
----------
2107+
vars
2108+
A variable or a list of variables for which to draw samples.
2109+
draws : int
2110+
Number of samples needed to draw. Detaults to 500.
2111+
mode
2112+
The mode used by ``aesara.function`` to compile the graph.
2113+
**kwargs
2114+
Keyword arguments for :func:`pymc.aesara.compile_pymc`
2115+
2116+
Returns
2117+
-------
2118+
List[np.ndarray]
2119+
A list of numpy arrays.
2120+
2121+
Examples
2122+
--------
2123+
.. code-block:: python
2124+
2125+
import pymc as pm
2126+
2127+
# Draw samples for one variable
2128+
with pm.Model():
2129+
x = pm.Normal("x")
2130+
x_draws = pm.draw(x, draws=100)
2131+
print(x_draws.shape)
2132+
2133+
# Draw 1000 samples for several variables
2134+
with pm.Model():
2135+
x = pm.Normal("x")
2136+
y = pm.Normal("y", shape=10)
2137+
z = pm.Uniform("z", shape=5)
2138+
num_draws = 1000
2139+
# Draw samples of a list variables
2140+
draws = pm.draw([x, y, z], draws=num_draws)
2141+
assert draws[0].shape == (num_draws,)
2142+
assert draws[1].shape == (num_draws, 10)
2143+
assert draws[2].shape == (num_draws, 5)
2144+
"""
2145+
2146+
draw_fn = compile_pymc(inputs=[], outputs=vars, mode=mode, **kwargs)
2147+
2148+
if draws == 1:
2149+
return draw_fn()
2150+
2151+
# Single variable output
2152+
if not isinstance(vars, (list, tuple)):
2153+
drawn_values = (draw_fn() for _ in range(draws))
2154+
return np.stack(drawn_values)
2155+
2156+
# Multiple variable output
2157+
drawn_values = zip(*(draw_fn() for _ in range(draws)))
2158+
return [np.stack(v) for v in drawn_values]
2159+
2160+
20962161
def _init_jitter(
20972162
model: Model,
20982163
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],

pymc/tests/test_sampling.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,3 +1212,59 @@ def test_sample_deterministic():
12121212
idata = pm.sample(chains=1, draws=50, compute_convergence_checks=False)
12131213

12141214
np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100)
1215+
1216+
1217+
class TestDraw(SeededTest):
1218+
def test_univariate(self):
1219+
with pm.Model():
1220+
x = pm.Normal("x")
1221+
1222+
x_draws = pm.draw(x)
1223+
assert x_draws.shape == ()
1224+
1225+
(x_draws,) = pm.draw([x])
1226+
assert x_draws.shape == ()
1227+
1228+
x_draws = pm.draw(x, draws=10)
1229+
assert x_draws.shape == (10,)
1230+
1231+
(x_draws,) = pm.draw([x], draws=10)
1232+
assert x_draws.shape == (10,)
1233+
1234+
def test_multivariate(self):
1235+
with pm.Model():
1236+
mln = pm.Multinomial("mln", n=5, p=np.array([0.25, 0.25, 0.25, 0.25]))
1237+
1238+
mln_draws = pm.draw(mln, draws=1)
1239+
assert mln_draws.shape == (4,)
1240+
1241+
(mln_draws,) = pm.draw([mln], draws=1)
1242+
assert mln_draws.shape == (4,)
1243+
1244+
mln_draws = pm.draw(mln, draws=10)
1245+
assert mln_draws.shape == (10, 4)
1246+
1247+
(mln_draws,) = pm.draw([mln], draws=10)
1248+
assert mln_draws.shape == (10, 4)
1249+
1250+
def test_multiple_variables(self):
1251+
with pm.Model():
1252+
x = pm.Normal("x")
1253+
y = pm.Normal("y", shape=10)
1254+
z = pm.Uniform("z", shape=5)
1255+
w = pm.Dirichlet("w", a=[1, 1, 1])
1256+
1257+
num_draws = 100
1258+
draws = pm.draw((x, y, z, w), draws=num_draws)
1259+
assert draws[0].shape == (num_draws,)
1260+
assert draws[1].shape == (num_draws, 10)
1261+
assert draws[2].shape == (num_draws, 5)
1262+
assert draws[3].shape == (num_draws, 3)
1263+
1264+
def test_draw_different_samples(self):
1265+
with pm.Model():
1266+
x = pm.Normal("x")
1267+
1268+
x_draws_1 = pm.draw(x, 100)
1269+
x_draws_2 = pm.draw(x, 100)
1270+
assert not np.all(np.isclose(x_draws_1, x_draws_2))

0 commit comments

Comments
 (0)