|
38 | 38 | from pymc.pytensorf import convert_observed_data
|
39 | 39 |
|
40 | 40 | __all__ = [
|
| 41 | + "broadcast_dist_samples_shape", |
41 | 42 | "to_tuple",
|
42 | 43 | "rv_size_is_none",
|
43 | 44 | "change_dist_size",
|
@@ -86,45 +87,85 @@ def _check_shape_type(shape):
|
86 | 87 | return tuple(out)
|
87 | 88 |
|
88 | 89 |
|
89 |
| -def shapes_broadcasting(*args, raise_exception=False): |
90 |
| - """Return the shape resulting from broadcasting multiple shapes. |
91 |
| - Represents numpy's broadcasting rules. |
| 90 | +def broadcast_dist_samples_shape(shapes, size=None): |
| 91 | + """Apply shape broadcasting to shape tuples but assuming that the shapes |
| 92 | + correspond to draws from random variables, with the `size` tuple possibly |
| 93 | + prepended to it. The `size` prepend is ignored to consider if the supplied |
| 94 | + `shapes` can broadcast or not. It is prepended to the resulting broadcasted |
| 95 | + `shapes`, if any of the shape tuples had the `size` prepend. |
92 | 96 |
|
93 | 97 | Parameters
|
94 | 98 | ----------
|
95 |
| - *args: array-like of int |
96 |
| - Tuples or arrays or lists representing the shapes of arrays to be |
97 |
| - broadcast. |
98 |
| - raise_exception: bool (optional) |
99 |
| - Controls whether to raise an exception or simply return `None` if |
100 |
| - the broadcasting fails. |
| 99 | + shapes: Iterable of tuples holding the distribution samples shapes |
| 100 | + size: None, int or tuple (optional) |
| 101 | + size of the sample set requested. |
101 | 102 |
|
102 | 103 | Returns
|
103 | 104 | -------
|
104 |
| - Resulting shape. If broadcasting is not possible and `raise_exception` is |
105 |
| - False, then `None` is returned. If `raise_exception` is `True`, a |
106 |
| - `ValueError` is raised. |
| 105 | + tuple of the resulting shape |
| 106 | +
|
| 107 | + Examples |
| 108 | + -------- |
| 109 | + .. code-block:: python |
| 110 | + size = 100 |
| 111 | + shape0 = (size,) |
| 112 | + shape1 = (size, 5) |
| 113 | + shape2 = (size, 4, 5) |
| 114 | + out = broadcast_dist_samples_shape([shape0, shape1, shape2], |
| 115 | + size=size) |
| 116 | + assert out == (size, 4, 5) |
| 117 | + .. code-block:: python |
| 118 | + size = 100 |
| 119 | + shape0 = (size,) |
| 120 | + shape1 = (5,) |
| 121 | + shape2 = (4, 5) |
| 122 | + out = broadcast_dist_samples_shape([shape0, shape1, shape2], |
| 123 | + size=size) |
| 124 | + assert out == (size, 4, 5) |
| 125 | + .. code-block:: python |
| 126 | + size = 100 |
| 127 | + shape0 = (1,) |
| 128 | + shape1 = (5,) |
| 129 | + shape2 = (4, 5) |
| 130 | + out = broadcast_dist_samples_shape([shape0, shape1, shape2], |
| 131 | + size=size) |
| 132 | + assert out == (4, 5) |
107 | 133 | """
|
108 |
| - x = list(_check_shape_type(args[0])) if args else () |
109 |
| - for arg in args[1:]: |
110 |
| - y = list(_check_shape_type(arg)) |
111 |
| - if len(x) < len(y): |
112 |
| - x, y = y, x |
113 |
| - if len(y) > 0: |
114 |
| - x[-len(y) :] = [ |
115 |
| - j if i == 1 else i if j == 1 else i if i == j else 0 |
116 |
| - for i, j in zip(x[-len(y) :], y) |
117 |
| - ] |
118 |
| - if not all(x): |
119 |
| - if raise_exception: |
120 |
| - raise ValueError( |
121 |
| - "Supplied shapes {} do not broadcast together".format( |
122 |
| - ", ".join([f"{a}" for a in args]) |
123 |
| - ) |
| 134 | + if size is None: |
| 135 | + broadcasted_shape = np.broadcast_shapes(*shapes) |
| 136 | + if broadcasted_shape is None: |
| 137 | + raise ValueError( |
| 138 | + "Cannot broadcast provided shapes {} given size: {}".format( |
| 139 | + ", ".join([f"{s}" for s in shapes]), size |
124 | 140 | )
|
125 |
| - else: |
126 |
| - return None |
127 |
| - return tuple(x) |
| 141 | + ) |
| 142 | + return broadcasted_shape |
| 143 | + shapes = [_check_shape_type(s) for s in shapes] |
| 144 | + _size = to_tuple(size) |
| 145 | + # samples shapes without the size prepend |
| 146 | + sp_shapes = [s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in shapes] |
| 147 | + try: |
| 148 | + broadcast_shape = np.broadcast_shapes(*sp_shapes) |
| 149 | + except ValueError: |
| 150 | + raise ValueError( |
| 151 | + "Cannot broadcast provided shapes {} given size: {}".format( |
| 152 | + ", ".join([f"{s}" for s in shapes]), size |
| 153 | + ) |
| 154 | + ) |
| 155 | + broadcastable_shapes = [] |
| 156 | + for shape, sp_shape in zip(shapes, sp_shapes): |
| 157 | + if _size == shape[: len(_size)]: |
| 158 | + # If size prepends the shape, then we have to add broadcasting axis |
| 159 | + # in the middle |
| 160 | + p_shape = ( |
| 161 | + shape[: len(_size)] |
| 162 | + + (1,) * (len(broadcast_shape) - len(sp_shape)) |
| 163 | + + shape[len(_size) :] |
| 164 | + ) |
| 165 | + else: |
| 166 | + p_shape = shape |
| 167 | + broadcastable_shapes.append(p_shape) |
| 168 | + return np.broadcast_shapes(*broadcastable_shapes) |
128 | 169 |
|
129 | 170 |
|
130 | 171 | # User-provided can be lazily specified as scalars
|
|
0 commit comments