diff --git a/docs/source/api/shape_utils.rst b/docs/source/api/shape_utils.rst
index 7c6916b060..2908cb9c79 100644
--- a/docs/source/api/shape_utils.rst
+++ b/docs/source/api/shape_utils.rst
@@ -14,10 +14,6 @@ This module introduces functions that are made aware of the requested `size_tupl
    :toctree: generated/
 
    to_tuple
-   shapes_broadcasting
    broadcast_dist_samples_shape
-   get_broadcastable_dist_samples
-   broadcast_distribution_samples
-   broadcast_dist_samples_to
    rv_size_is_none
    change_dist_size
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index de051b7d3f..7f873a2e4a 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -59,7 +59,7 @@
 )
 from pymc.distributions.shape_utils import (
     _change_dist_size,
-    broadcast_dist_samples_to,
+    broadcast_dist_samples_shape,
     change_dist_size,
     get_support_shape,
     rv_size_is_none,
@@ -1651,7 +1651,9 @@ def rng_fn(cls, rng, mu, rowchol, colchol, size=None):
         output_shape = size + dist_shape
 
         # Broadcasting all parameters
-        (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
+        shapes = [mu.shape, output_shape]
+        broadcastable_shape = broadcast_dist_samples_shape(shapes, size=size)
+        mu = np.broadcast_to(mu, shape=broadcastable_shape)
         rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
         colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
 
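Note: the multivariate.py hunk above replaces the removed ``broadcast_dist_samples_to`` helper with the retained ``broadcast_dist_samples_shape`` plus a plain ``np.broadcast_to`` call. The sketch below is illustrative only and is not part of the patch; the values of ``size``, ``mu``, and ``output_shape`` are hypothetical stand-ins for the actual MatrixNormal draws handled in that hunk.

.. code-block:: python

    # Illustrative sketch only, not part of the patch.
    import numpy as np

    from pymc.distributions.shape_utils import broadcast_dist_samples_shape

    size = (100,)                 # hypothetical number of requested draws
    mu = np.zeros((3, 4))         # hypothetical parameter values
    output_shape = size + (3, 4)  # size prepended to the distribution shape

    # Previously: (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
    shapes = [mu.shape, output_shape]
    broadcastable_shape = broadcast_dist_samples_shape(shapes, size=size)
    mu = np.broadcast_to(mu, shape=broadcastable_shape)
    assert mu.shape == output_shape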
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index a4ee80ca12..2987ec444c 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -38,12 +38,8 @@
 from pymc.pytensorf import convert_observed_data
 
 __all__ = [
-    "to_tuple",
-    "shapes_broadcasting",
     "broadcast_dist_samples_shape",
-    "get_broadcastable_dist_samples",
-    "broadcast_distribution_samples",
-    "broadcast_dist_samples_to",
+    "to_tuple",
     "rv_size_is_none",
     "change_dist_size",
 ]
@@ -91,47 +87,6 @@
     return tuple(out)
 
 
-def shapes_broadcasting(*args, raise_exception=False):
-    """Return the shape resulting from broadcasting multiple shapes.
-    Represents numpy's broadcasting rules.
-
-    Parameters
-    ----------
-    *args: array-like of int
-        Tuples or arrays or lists representing the shapes of arrays to be
-        broadcast.
-    raise_exception: bool (optional)
-        Controls whether to raise an exception or simply return `None` if
-        the broadcasting fails.
-
-    Returns
-    -------
-    Resulting shape. If broadcasting is not possible and `raise_exception` is
-    False, then `None` is returned. If `raise_exception` is `True`, a
-    `ValueError` is raised.
-    """
-    x = list(_check_shape_type(args[0])) if args else ()
-    for arg in args[1:]:
-        y = list(_check_shape_type(arg))
-        if len(x) < len(y):
-            x, y = y, x
-        if len(y) > 0:
-            x[-len(y) :] = [
-                j if i == 1 else i if j == 1 else i if i == j else 0
-                for i, j in zip(x[-len(y) :], y)
-            ]
-        if not all(x):
-            if raise_exception:
-                raise ValueError(
-                    "Supplied shapes {} do not broadcast together".format(
-                        ", ".join([f"{a}" for a in args])
-                    )
-                )
-            else:
-                return None
-    return tuple(x)
-
-
 def broadcast_dist_samples_shape(shapes, size=None):
     """Apply shape broadcasting to shape tuples but assuming that the shapes
     correspond to draws from random variables, with the `size` tuple possibly
@@ -152,7 +107,6 @@
     Examples
     --------
     .. code-block:: python
-
         size = 100
         shape0 = (size,)
         shape1 = (size, 5)
@@ -160,9 +114,7 @@
 
         out = broadcast_dist_samples_shape([shape0, shape1, shape2], size=size)
         assert out == (size, 4, 5)
-
     .. code-block:: python
-
         size = 100
         shape0 = (size,)
         shape1 = (5,)
@@ -170,9 +122,7 @@
 
         out = broadcast_dist_samples_shape([shape0, shape1, shape2], size=size)
         assert out == (size, 4, 5)
-
     .. code-block:: python
-
         size = 100
         shape0 = (1,)
         shape1 = (5,)
@@ -182,7 +132,7 @@
         assert out == (4, 5)
     """
     if size is None:
-        broadcasted_shape = shapes_broadcasting(*shapes)
+        broadcasted_shape = np.broadcast_shapes(*shapes)
         if broadcasted_shape is None:
             raise ValueError(
                 "Cannot broadcast provided shapes {} given size: {}".format(
@@ -195,7 +145,7 @@
     # samples shapes without the size prepend
     sp_shapes = [s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in shapes]
     try:
-        broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
+        broadcast_shape = np.broadcast_shapes(*sp_shapes)
     except ValueError:
         raise ValueError(
             "Cannot broadcast provided shapes {} given size: {}".format(
@@ -215,212 +165,7 @@
         else:
             p_shape = shape
         broadcastable_shapes.append(p_shape)
-    return shapes_broadcasting(*broadcastable_shapes, raise_exception=True)
-
-
-def get_broadcastable_dist_samples(
-    samples, size=None, must_bcast_with=None, return_out_shape=False
-):
-    """Get a view of the samples drawn from distributions which adds new axes
-    in between the `size` prepend and the distribution's `shape`. These views
-    should be able to broadcast the samples from the distrubtions taking into
-    account the `size` (i.e. the number of samples) of the draw, which is
-    prepended to the sample's `shape`. Optionally, one can supply an extra
-    `must_bcast_with` to try to force samples to be able to broadcast with a
-    given shape. A `ValueError` is raised if it is not possible to broadcast
-    the provided samples.
-
-    Parameters
-    ----------
-    samples: Iterable of ndarrays holding the sampled values
-    size: None, int or tuple (optional)
-        size of the sample set requested.
-    must_bcast_with: None, int or tuple (optional)
-        Tuple shape to which the samples must be able to broadcast
-    return_out_shape: bool (optional)
-        If `True`, this function also returns the output's shape and not only
-        samples views.
-
-    Returns
-    -------
-    broadcastable_samples: List of the broadcasted sample arrays
-    broadcast_shape: If `return_out_shape` is `True`, the resulting broadcast
-        shape is returned.
-
-    Examples
-    --------
-    .. code-block:: python
-
-        must_bcast_with = (3, 1, 5)
-        size = 100
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(size, 5)
-        sample2 = np.random.randn(size, 4, 5)
-        out = broadcast_dist_samples_to(
-            [sample0, sample1, sample2],
-            size=size,
-            must_bcast_with=must_bcast_with,
-        )
-        assert out[0].shape == (size, 1, 1, 1)
-        assert out[1].shape == (size, 1, 1, 5)
-        assert out[2].shape == (size, 1, 4, 5)
-        assert np.all(sample0[:, None, None, None] == out[0])
-        assert np.all(sample1[:, None, None] == out[1])
-        assert np.all(sample2[:, None] == out[2])
-
-    .. code-block:: python
-
-        size = 100
-        must_bcast_with = (3, 1, 5)
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(5)
-        sample2 = np.random.randn(4, 5)
-        out = broadcast_dist_samples_to(
-            [sample0, sample1, sample2],
-            size=size,
-            must_bcast_with=must_bcast_with,
-        )
-        assert out[0].shape == (size, 1, 1, 1)
-        assert out[1].shape == (5,)
-        assert out[2].shape == (4, 5)
-        assert np.all(sample0[:, None, None, None] == out[0])
-        assert np.all(sample1 == out[1])
-        assert np.all(sample2 == out[2])
-    """
-    samples = [np.asarray(p) for p in samples]
-    _size = to_tuple(size)
-    must_bcast_with = to_tuple(must_bcast_with)
-    # Raw samples shapes
-    p_shapes = [p.shape for p in samples] + [_check_shape_type(must_bcast_with)]
-    out_shape = broadcast_dist_samples_shape(p_shapes, size=size)
-    # samples shapes without the size prepend
-    sp_shapes = [
-        s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in p_shapes
-    ]
-    broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
-    broadcastable_samples = []
-    for param, p_shape, sp_shape in zip(samples, p_shapes, sp_shapes):
-        if _size == p_shape[: min([len(_size), len(p_shape)])]:
-            # If size prepends the shape, then we have to add broadcasting axis
-            # in the middle
-            slicer_head = [slice(None)] * len(_size)
-            slicer_tail = [np.newaxis] * (len(broadcast_shape) - len(sp_shape)) + [
-                slice(None)
-            ] * len(sp_shape)
-        else:
-            # If size does not prepend the shape, then we have leave the
-            # parameter as is
-            slicer_head = []
-            slicer_tail = [slice(None)] * len(sp_shape)
-        broadcastable_samples.append(param[tuple(slicer_head + slicer_tail)])
-    if return_out_shape:
-        return broadcastable_samples, out_shape
-    else:
-        return broadcastable_samples
-
-
-def broadcast_distribution_samples(samples, size=None):
-    """Broadcast samples drawn from distributions taking into account the
-    size (i.e. the number of samples) of the draw, which is prepended to
-    the sample's shape.
-
-    Parameters
-    ----------
-    samples: Iterable of ndarrays holding the sampled values
-    size: None, int or tuple (optional)
-        size of the sample set requested.
-
-    Returns
-    -------
-    List of broadcasted sample arrays
-
-    Examples
-    --------
-    .. code-block:: python
-
-        size = 100
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(size, 5)
-        sample2 = np.random.randn(size, 4, 5)
-        out = broadcast_distribution_samples([sample0, sample1, sample2],
-                                             size=size)
-        assert all((o.shape == (size, 4, 5) for o in out))
-        assert np.all(sample0[:, None, None] == out[0])
-        assert np.all(sample1[:, None, :] == out[1])
-        assert np.all(sample2 == out[2])
-
-    .. code-block:: python
-
-        size = 100
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(5)
-        sample2 = np.random.randn(4, 5)
-        out = broadcast_distribution_samples([sample0, sample1, sample2],
-                                             size=size)
-        assert all((o.shape == (size, 4, 5) for o in out))
-        assert np.all(sample0[:, None, None] == out[0])
-        assert np.all(sample1 == out[1])
-        assert np.all(sample2 == out[2])
-    """
-    return np.broadcast_arrays(*get_broadcastable_dist_samples(samples, size=size))
-
-
-def broadcast_dist_samples_to(to_shape, samples, size=None):
-    """Broadcast samples drawn from distributions to a given shape, taking into
-    account the size (i.e. the number of samples) of the draw, which is
-    prepended to the sample's shape.
-
-    Parameters
-    ----------
-    to_shape: Tuple shape onto which the samples must be able to broadcast
-    samples: Iterable of ndarrays holding the sampled values
-    size: None, int or tuple (optional)
-        size of the sample set requested.
-
-    Returns
-    -------
-    List of the broadcasted sample arrays
-
-    Examples
-    --------
-    .. code-block:: python
-
-        to_shape = (3, 1, 5)
-        size = 100
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(size, 5)
-        sample2 = np.random.randn(size, 4, 5)
-        out = broadcast_dist_samples_to(
-            to_shape,
-            [sample0, sample1, sample2],
-            size=size
-        )
-        assert np.all((o.shape == (size, 3, 4, 5) for o in out))
-        assert np.all(sample0[:, None, None, None] == out[0])
-        assert np.all(sample1[:, None, None] == out[1])
-        assert np.all(sample2[:, None] == out[2])
-
-    .. code-block:: python
-
-        size = 100
-        to_shape = (3, 1, 5)
-        sample0 = np.random.randn(size)
-        sample1 = np.random.randn(5)
-        sample2 = np.random.randn(4, 5)
-        out = broadcast_dist_samples_to(
-            to_shape,
-            [sample0, sample1, sample2],
-            size=size
-        )
-        assert np.all((o.shape == (size, 3, 4, 5) for o in out))
-        assert np.all(sample0[:, None, None, None] == out[0])
-        assert np.all(sample1 == out[1])
-        assert np.all(sample2 == out[2])
-    """
-    samples, to_shape = get_broadcastable_dist_samples(
-        samples, size=size, must_bcast_with=to_shape, return_out_shape=True
-    )
-    return [np.broadcast_to(o, to_shape) for o in samples]
+    return np.broadcast_shapes(*broadcastable_shapes)
 
 
 # User-provided can be lazily specified as scalars
diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py
index 82e8a36b26..2c2598145b 100644
--- a/tests/distributions/test_shape_utils.py
+++ b/tests/distributions/test_shape_utils.py
@@ -30,17 +30,13 @@
 from pymc import ShapeError
 from pymc.distributions.shape_utils import (
     broadcast_dist_samples_shape,
-    broadcast_dist_samples_to,
-    broadcast_distribution_samples,
     change_dist_size,
     convert_dims,
     convert_shape,
     convert_size,
-    get_broadcastable_dist_samples,
     get_support_shape,
     get_support_shape_1d,
     rv_size_is_none,
-    shapes_broadcasting,
     to_tuple,
 )
 from pymc.model import Model
@@ -90,67 +86,18 @@ def fixture_exception_handling(request):
     return request.param
 
 
-@pytest.fixture()
-def samples_to_broadcast(fixture_sizes, fixture_shapes):
-    samples = [np.empty(s) for s in fixture_shapes]
-    try:
-        broadcast_shape = broadcast_dist_samples_shape(fixture_shapes, size=fixture_sizes)
-    except ValueError:
-        broadcast_shape = None
-    return fixture_sizes, samples, broadcast_shape
-
-
-@pytest.fixture(params=test_to_shapes, ids=str)
-def samples_to_broadcast_to(request, samples_to_broadcast):
-    to_shape = request.param
-    size, samples, broadcast_shape = samples_to_broadcast
-    if broadcast_shape is not None:
-        try:
-            broadcast_shape = broadcast_dist_samples_shape(
-                [broadcast_shape, to_tuple(to_shape)], size=size
-            )
-        except ValueError:
-            broadcast_shape = None
-    return to_shape, size, samples, broadcast_shape
-
-
 class TestShapesBroadcasting:
-    @pytest.mark.parametrize(
-        "bad_input",
-        [None, [None], "asd", 3.6, {1: 2}, {3}, [8, [8]], "3", ["3"], np.array([[2]])],
-        ids=str,
-    )
-    def test_type_check_raises(self, bad_input):
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore", ".*ragged nested sequences.*", np.VisibleDeprecationWarning
-            )
-            with pytest.raises(TypeError):
-                shapes_broadcasting(bad_input, tuple(), raise_exception=True)
-            with pytest.raises(TypeError):
-                shapes_broadcasting(bad_input, tuple(), raise_exception=False)
-
-    def test_type_check_success(self):
-        inputs = [3, 3.0, tuple(), [3], (3,), np.array(3), np.array([3])]
-        out = shapes_broadcasting(*inputs)
-        assert out == (3,)
-
-    def test_broadcasting(self, fixture_shapes, fixture_exception_handling):
+    def test_broadcasting(self, fixture_shapes):
         shapes = fixture_shapes
-        raise_exception = fixture_exception_handling
         try:
             expected_out = np.broadcast(*(np.empty(s) for s in shapes)).shape
         except ValueError:
             expected_out = None
         if expected_out is None:
-            if raise_exception:
-                with pytest.raises(ValueError):
-                    shapes_broadcasting(*shapes, raise_exception=raise_exception)
-            else:
-                out = shapes_broadcasting(*shapes, raise_exception=raise_exception)
-                assert out is None
+            with pytest.raises(ValueError):
+                np.broadcast_shapes(*shapes)
         else:
-            out = shapes_broadcasting(*shapes, raise_exception=raise_exception)
+            out = np.broadcast_shapes(*shapes)
             assert out == expected_out
 
     def test_broadcast_dist_samples_shape(self, fixture_sizes, fixture_shapes):
@@ -176,48 +123,6 @@
         assert out == expected_out
 
 
-class TestSamplesBroadcasting:
-    def test_broadcast_distribution_samples(self, samples_to_broadcast):
-        size, samples, broadcast_shape = samples_to_broadcast
-        if broadcast_shape is not None:
-            outs = broadcast_distribution_samples(samples, size=size)
-            assert all(o.shape == broadcast_shape for o in outs)
-        else:
-            with pytest.raises(ValueError):
-                broadcast_distribution_samples(samples, size=size)
-
-    def test_get_broadcastable_dist_samples(self, samples_to_broadcast):
-        size, samples, broadcast_shape = samples_to_broadcast
-        if broadcast_shape is not None:
-            size_ = to_tuple(size)
-            outs, out_shape = get_broadcastable_dist_samples(
-                samples, size=size, return_out_shape=True
-            )
-            assert out_shape == broadcast_shape
-            for i, o in zip(samples, outs):
-                ishape = i.shape
-                if ishape[: min([len(size_), len(ishape)])] == size_:
-                    expected_shape = (
-                        size_ + (1,) * (len(broadcast_shape) - len(ishape)) + ishape[len(size_) :]
-                    )
-                else:
-                    expected_shape = ishape
-                assert o.shape == expected_shape
-            assert shapes_broadcasting(*(o.shape for o in outs)) == broadcast_shape
-        else:
-            with pytest.raises(ValueError):
-                get_broadcastable_dist_samples(samples, size=size)
-
-    def test_broadcast_dist_samples_to(self, samples_to_broadcast_to):
-        to_shape, size, samples, broadcast_shape = samples_to_broadcast_to
-        if broadcast_shape is not None:
-            outs = broadcast_dist_samples_to(to_shape, samples, size=size)
-            assert all(o.shape == broadcast_shape for o in outs)
-        else:
-            with pytest.raises(ValueError):
-                broadcast_dist_samples_to(to_shape, samples, size=size)
-
-
 class TestSizeShapeDimsObserved:
     @pytest.mark.parametrize("param_shape", [(), (2,)])
     @pytest.mark.parametrize("batch_shape", [(), (3,)])
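Note: the updated test calls ``np.broadcast_shapes`` directly in place of the removed ``shapes_broadcasting`` helper, relying on NumPy to raise ``ValueError`` for incompatible shapes rather than optionally returning ``None``. The check below is illustrative only and is not part of the patch.

.. code-block:: python

    # Illustrative check only, not part of the patch.
    import numpy as np

    # The broadcast shape follows NumPy's broadcasting rules.
    assert np.broadcast_shapes((100, 1, 5), (100, 4, 1)) == (100, 4, 5)

    # Incompatible shapes always raise ValueError; the old helper's
    # raise_exception=False code path (returning None) is gone.
    try:
        np.broadcast_shapes((3,), (4,))
    except ValueError:
        pass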