Skip to content

Commit bcf5637

Browse files
lucianopaztwiecki
authored andcommitted
Added shape_utils doc (#3463)
* Added shape_utils to docs * Added shape_utils to api
1 parent 05e3c39 commit bcf5637

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ API Reference
1919
api/math
2020
api/data
2121
api/model
22+
api/shape_utils

docs/source/api/shape_utils.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
***********
2+
shape_utils
3+
***********
4+
5+
This submodule contains various functions that apply numpy's broadcasting rules to shape tuples, and also to samples drawn from probability distributions.
6+
7+
The main challenge when broadcasting samples drawn from a generative model, is that each random variate has a core shape. When we draw many i.i.d samples from a given RV, for example if we ask for `size_tuple` i.i.d draws, the result usually is a `size_tuple + RV_core_shape`. In the generative model's hierarchy, the downstream RVs that are conditionally dependent on our above sampled values, will get an array with a shape that is incosistent with the core shape they expect to see for their parameters. This is a problem sometimes because it prevents regular broadcasting in complex hierachical models, and thus make prior and posterior predictive sampling difficult.
8+
9+
This module introduces functions that are made aware of the requested `size_tuple` of i.i.d samples, and does the broadcasting on the core shapes, transparently ignoring or moving the i.i.d `size_tuple` prepended axes around.
10+
11+
.. currentmodule:: pymc3.distributions.shape_utils
12+
13+
.. autosummary::
14+
15+
to_tuple
16+
shapes_broadcasting
17+
broadcast_dist_samples_shape
18+
get_broadcastable_dist_samples
19+
broadcast_distribution_samples
20+
broadcast_dist_samples_to
21+
22+
.. automodule:: pymc3.distributions.shape_utils
23+
:members:

pymc3/distributions/shape_utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,18 @@
1717

1818

1919
def to_tuple(shape):
20-
"""Convert ints, arrays, and Nones to tuples"""
20+
"""Convert ints, arrays, and Nones to tuples
21+
22+
Parameters
23+
----------
24+
shape : None, int or array-like
25+
Represents the shape to convert to tuple.
26+
27+
Returns
28+
-------
29+
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
30+
returned. If it is array-like, tuple(shape) is returned.
31+
"""
2132
if shape is None:
2233
return tuple()
2334
temp = np.atleast_1d(shape)
@@ -106,6 +117,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
106117
Examples
107118
--------
108119
.. code-block:: python
120+
109121
size = 100
110122
shape0 = (size,)
111123
shape1 = (size, 5)
@@ -115,6 +127,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
115127
assert out == (size, 4, 5)
116128
117129
.. code-block:: python
130+
118131
size = 100
119132
shape0 = (size,)
120133
shape1 = (5,)
@@ -124,6 +137,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
124137
assert out == (size, 4, 5)
125138
126139
.. code-block:: python
140+
127141
size = 100
128142
shape0 = (1,)
129143
shape1 = (5,)
@@ -204,6 +218,7 @@ def get_broadcastable_dist_samples(
204218
Examples
205219
--------
206220
.. code-block:: python
221+
207222
must_bcast_with = (3, 1, 5)
208223
size = 100
209224
sample0 = np.random.randn(size)
@@ -222,6 +237,7 @@ def get_broadcastable_dist_samples(
222237
assert np.all(sample2[:, None] == out[2])
223238
224239
.. code-block:: python
240+
225241
size = 100
226242
must_bcast_with = (3, 1, 5)
227243
sample0 = np.random.randn(size)
@@ -290,6 +306,7 @@ def broadcast_distribution_samples(samples, size=None):
290306
Examples
291307
--------
292308
.. code-block:: python
309+
293310
size = 100
294311
sample0 = np.random.randn(size)
295312
sample1 = np.random.randn(size, 5)
@@ -302,6 +319,7 @@ def broadcast_distribution_samples(samples, size=None):
302319
assert np.all(sample2 == out[2])
303320
304321
.. code-block:: python
322+
305323
size = 100
306324
sample0 = np.random.randn(size)
307325
sample1 = np.random.randn(5)
@@ -335,6 +353,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
335353
Examples
336354
--------
337355
.. code-block:: python
356+
338357
to_shape = (3, 1, 5)
339358
size = 100
340359
sample0 = np.random.randn(size)
@@ -351,6 +370,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
351370
assert np.all(sample2[:, None] == out[2])
352371
353372
.. code-block:: python
373+
354374
size = 100
355375
to_shape = (3, 1, 5)
356376
sample0 = np.random.randn(size)

0 commit comments

Comments
 (0)