diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 95bd8c4c7d..2b61ab68ec 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -145,6 +145,7 @@ def convert_observed_data(data): def change_rv_size( rv: TensorVariable, new_size: PotentialShapeType, + new_size_dims: Optional[tuple] = (None,), expand: Optional[bool] = False, ) -> TensorVariable: """Change or expand the size of a `RandomVariable`. @@ -155,6 +156,8 @@ def change_rv_size( The old `RandomVariable` output. new_size The new size. + new_size_dims + dim names of the new size vector expand: Expand the existing size by `new_size`. @@ -166,6 +169,10 @@ def change_rv_size( elif new_size_ndim == 0: new_size = (new_size,) + # wrap None in tuple, if new_size_dims are None + if new_size_dims is None: + new_size_dims = (None,) + # Extract the RV node that is to be resized, together with its inputs, name and tag assert rv.owner.op is not None if isinstance(rv.owner.op, SpecifyShape): @@ -180,9 +187,13 @@ def change_rv_size( size = shape[: len(shape) - rv_node.op.ndim_supp] new_size = tuple(new_size) + tuple(size) + # create the name of the RV's resizing tensor + # TODO: add information where the dim is coming from (obseverd, prior, ...) + new_size_name = f"Broadcast to {new_size_dims[0]}_dim" + # Make sure the new size is a tensor. This dtype-aware conversion helps # to not unnecessarily pick up a `Cast` in some cases (see #4652). - new_size = at.as_tensor(new_size, ndim=1, dtype="int64") + new_size = at.as_tensor(new_size, ndim=1, dtype="int64", name=new_size_name) new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) new_rv = new_rv_node.outputs[-1] diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index dd5968df7e..0eaa0624c7 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -266,7 +266,9 @@ def __new__( if resize_shape: # A batch size was specified through `dims`, or implied by `observed`. - rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True) + rv_out = change_rv_size( + rv=rv_out, new_size=resize_shape, new_size_dims=dims, expand=True + ) rv_out = model.register_rv( rv_out,