Skip to content

Commit 5007d97

Browse files
Automatically add SpecifyShape Op when full-length shape is given
1 parent 3ce51ab commit 5007d97

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
### New Features
1010
- The `CAR` distribution has been added to allow for use of conditional autoregressions, which are often used in spatial and network models.
1111
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
12-
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
12+
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. A `SpecifyShape` `Op` is added automatically unless `Ellipsis` is used. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
1313
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
1414
- The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
1515
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.

pymc3/aesaraf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
4646
from aesara.tensor.elemwise import Elemwise
4747
from aesara.tensor.random.op import RandomVariable
48+
from aesara.tensor.shape import SpecifyShape
4849
from aesara.tensor.sharedvar import SharedVariable
4950
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
5051
from aesara.tensor.var import TensorVariable
@@ -146,6 +147,8 @@ def change_rv_size(
146147
Expand the existing size by `new_size`.
147148
148149
"""
150+
if isinstance(rv_var.owner.op, SpecifyShape):
151+
rv_var = rv_var.owner.inputs[0]
149152
rv_node = rv_var.owner
150153
rng, size, dtype, *dist_params = rv_node.inputs
151154
name = rv_var.name

pymc3/distributions/distribution.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from aesara.graph.basic import Variable
3030
from aesara.tensor.random.op import RandomVariable
31+
from aesara.tensor.shape import SpecifyShape, specify_shape
3132

3233
from pymc3.aesaraf import change_rv_size, pandas_to_array
3334
from pymc3.distributions import _logcdf, _logp
@@ -253,6 +254,13 @@ def __new__(
253254
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
254255
n_implied = rv_out.ndim
255256

257+
# `.dist()` may automatically wrap the RV in a SpecifyShape Op, which brings informative
258+
# error messages earlier in model construction.
259+
# Here, however, the underlying RV must be used - a new SpecifyShape Op can be added at the end.
260+
assert_shape = None
261+
if isinstance(rv_out.owner.op, SpecifyShape):
262+
rv_out, assert_shape = rv_out.owner.inputs
263+
256264
# `dims` are only available with this API, because `.dist()` can be used
257265
# without a modelcontext and dims are not tracked at the Aesara level.
258266
if dims is not None:
@@ -292,7 +300,15 @@ def __new__(
292300
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
293301
rv_out.tag.test_value = testval
294302

295-
return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
303+
rv_registered = model.register_rv(
304+
rv_out, name, observed, total_size, dims=dims, transform=transform
305+
)
306+
307+
# Wrapping in specify_shape now does not break transforms:
308+
if assert_shape is not None:
309+
rv_registered = specify_shape(rv_registered, assert_shape)
310+
311+
return rv_registered
296312

297313
@classmethod
298314
def dist(
@@ -314,6 +330,9 @@ def dist(
314330
315331
Ellipsis (...) may be used in the last position of the tuple,
316332
and automatically expand to the shape implied by RV inputs.
333+
334+
Without Ellipsis, a `SpecifyShape` Op is automatically applied,
335+
constraining this model variable to exactly the specified shape.
317336
size : int, tuple, Variable, optional
318337
A scalar or tuple for replicating the RV in addition
319338
to its implied shape/dimensionality.
@@ -330,6 +349,7 @@ def dist(
330349
raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")
331350

332351
shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
352+
assert_shape = None
333353

334354
# Create the RV without specifying size or testval.
335355
# The size will be expanded later (if necessary) and only then the testval fits.
@@ -338,13 +358,16 @@ def dist(
338358
if shape is None and size is None:
339359
size = ()
340360
elif shape is not None:
361+
# SpecifyShape is automatically applied for symbolic and non-Ellipsis shapes
341362
if isinstance(shape, Variable):
363+
assert_shape = shape
342364
size = ()
343365
else:
344366
if Ellipsis in shape:
345367
size = tuple(shape[:-1])
346368
else:
347369
size = tuple(shape[: len(shape) - rv_native.ndim])
370+
assert_shape = shape
348371
# no-op conditions:
349372
# `elif size is not None` (User already specified how to expand the RV)
350373
# `else` (Unreachable)
@@ -354,6 +377,9 @@ def dist(
354377
else:
355378
rv_out = rv_native
356379

380+
if assert_shape is not None:
381+
rv_out = specify_shape(rv_out, shape=assert_shape)
382+
357383
if testval is not None:
358384
rv_out.tag.test_value = testval
359385

pymc3/tests/test_logp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_logpt_incsubtensor(indices, shape):
8686
sigma = 0.001
8787
rng = aesara.shared(np.random.RandomState(232), borrow=True)
8888

89-
a = Normal.dist(mu, sigma, shape=shape, rng=rng)
89+
a = Normal.dist(mu, sigma, rng=rng)
9090
a.name = "a"
9191

9292
a_idx = at.set_subtensor(a[indices], data)

pymc3/tests/test_shape_handling.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,39 @@ def test_dist_api_works(self):
350350
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
351351
assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3)
352352

353+
def test_auto_assert_shape(self):
354+
with pytest.raises(AssertionError, match="will never match"):
355+
pm.Normal.dist(mu=[1, 2], shape=[])
356+
357+
mu = at.vector(name="mu_input")
358+
rv = pm.Normal.dist(mu=mu, shape=[3, 4])
359+
f = aesara.function([mu], rv, mode=aesara.Mode("py"))
360+
assert f([1, 2, 3, 4]).shape == (3, 4)
361+
362+
with pytest.raises(AssertionError, match=r"Got shape \(3, 2\), expected \(3, 4\)."):
363+
f([1, 2])
364+
365+
# The `shape` can be symbolic!
366+
s = at.vector(dtype="int32")
367+
rv = pm.Uniform.dist(2, [4, 5], shape=s)
368+
f = aesara.function([s], rv, mode=aesara.Mode("py"))
369+
f(
370+
[
371+
2,
372+
]
373+
)
374+
with pytest.raises(
375+
AssertionError,
376+
match=r"Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(3, 4\).",
377+
):
378+
f([3, 4])
379+
with pytest.raises(
380+
AssertionError,
381+
match=r"Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).",
382+
):
383+
f([])
384+
pass
385+
353386
def test_lazy_flavors(self):
354387

355388
_validate_shape_dims_size(shape=5)

0 commit comments

Comments
 (0)