Skip to content

Commit ff172bd

Browse files
committed
Allow to specify dims on the fly from observed
1 parent bc1d1cb commit ff172bd

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

pymc/distributions/distribution.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def fn(*args, **kwargs):
149149
return fn
150150

151151

152-
def _make_rv_and_resize_shape(
152+
def _make_rv_and_resize_shape_from_dims(
153153
*,
154154
cls,
155155
dims: Optional[StrongDims],
@@ -159,21 +159,23 @@ def _make_rv_and_resize_shape(
159159
**kwargs,
160160
) -> Tuple[Variable, StrongShape]:
161161
"""Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
162-
resize_shape = None
162+
resize_shape_from_dims = None
163163
size_or_shape = kwargs.get("size") or kwargs.get("shape")
164164

165-
# Create the RV without dims or observed information
165+
# Preference is given to size or shape. If not specified, we rely on dims and
166+
# finally, observed, to determine the shape of the variable. Because dims can be
167+
# specified on the fly, we need a two-step process where we first create the RV
168+
# without dims information and then resize it.
169+
if not size_or_shape and observed is not None:
170+
kwargs["shape"] = tuple(observed.shape)
171+
172+
# Create the RV without dims information
166173
rv_out = cls.dist(*args, **kwargs)
167174

168-
# Preference is given to size or shape, if not provided we use dims and observed
169-
# to resize the variable
170-
if not size_or_shape:
171-
if dims is not None:
172-
resize_shape = shape_from_dims(dims, tuple(rv_out.shape), model)
173-
elif observed is not None:
174-
resize_shape = tuple(observed.shape)
175+
if not size_or_shape and dims is not None:
176+
resize_shape_from_dims = shape_from_dims(dims, tuple(rv_out.shape), model)
175177

176-
return rv_out, resize_shape
178+
return rv_out, resize_shape_from_dims
177179

178180

179181
class Distribution(metaclass=DistributionMeta):
@@ -257,16 +259,17 @@ def __new__(
257259
if observed is not None:
258260
observed = convert_observed_data(observed)
259261

260-
# Create the RV, possibly taking into consideration dims and observed to
261-
# determine its shape
262-
rv_out, resize_shape = _make_rv_and_resize_shape(
262+
# Create the RV, without taking `dims` into consideration
263+
rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
263264
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
264265
)
265266

266-
# A shape was specified only through `dims`, or implied by `observed`.
267-
if resize_shape:
268-
resize_size = find_size(shape=resize_shape, size=None, ndim_supp=cls.rv_op.ndim_supp)
269-
rv_out = change_rv_size(rv=rv_out, new_size=resize_size, expand=False)
267+
# Resize variable based on `dims` information
268+
if resize_shape_from_dims:
269+
resize_size_from_dims = find_size(
270+
shape=resize_shape_from_dims, size=None, ndim_supp=cls.rv_op.ndim_supp
271+
)
272+
rv_out = change_rv_size(rv=rv_out, new_size=resize_size_from_dims, expand=False)
270273

271274
rv_out = model.register_rv(
272275
rv_out,
@@ -452,16 +455,17 @@ def __new__(
452455
if observed is not None:
453456
observed = convert_observed_data(observed)
454457

455-
# Create the RV, possibly taking into consideration dims and observed to
456-
# determine its shape
457-
rv_out, resize_shape = _make_rv_and_resize_shape(
458+
# Create the RV, without taking `dims` into consideration
459+
rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
458460
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
459461
)
460462

461-
# A shape was specified only through `dims`, or implied by `observed`.
462-
if resize_shape:
463-
resize_size = find_size(shape=resize_shape, size=None, ndim_supp=rv_out.tag.ndim_supp)
464-
rv_out = cls.change_size(rv=rv_out, new_size=resize_size, expand=False)
463+
# Resize variable based on `dims` information
464+
if resize_shape_from_dims:
465+
resize_size_from_dims = find_size(
466+
shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.tag.ndim_supp
467+
)
468+
rv_out = cls.change_size(rv=rv_out, new_size=resize_size_from_dims, expand=False)
465469

466470
rv_out = model.register_rv(
467471
rv_out,

pymc/tests/test_shape_handling.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,21 @@ def test_define_dims_on_the_fly(self):
320320
# The change should propagate all the way through
321321
assert effect.eval().shape == (4,)
322322

323+
def test_define_dims_on_the_fly_from_observed(self):
324+
with pm.Model() as pmodel:
325+
data = aesara.shared(np.zeros((4, 5)))
326+
x = pm.Normal("x", observed=data, dims=("patient", "trials"))
327+
assert pmodel.dim_lengths["patient"].eval() == 4
328+
assert pmodel.dim_lengths["trials"].eval() == 5
329+
330+
# Use dim to create a new RV
331+
x_noisy = pm.Normal("x_noisy", 0, dims=("patient", "trials"))
332+
assert x_noisy.eval().shape == (4, 5)
333+
334+
# Change data patient dims
335+
data.set_value(np.zeros((10, 6)))
336+
assert x_noisy.eval().shape == (10, 6)
337+
323338
def test_can_resize_data_defined_size(self):
324339
with pm.Model() as pmodel:
325340
x = pm.MutableData("x", [[1, 2, 3, 4]], dims=("first", "second"))

0 commit comments

Comments
 (0)