Skip to content

Commit bc1d1cb

Browse files
committed
Allow size broadcasting from dims and observed
1 parent e703221 commit bc1d1cb

File tree

8 files changed

+103
-118
lines changed

8 files changed

+103
-118
lines changed

pymc/distributions/distribution.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@
3333
from aesara.tensor.var import TensorVariable
3434
from typing_extensions import TypeAlias
3535

36-
from pymc.aesaraf import change_rv_size
36+
from pymc.aesaraf import change_rv_size, convert_observed_data
3737
from pymc.distributions.shape_utils import (
3838
Dims,
3939
Shape,
4040
Size,
41+
StrongDims,
4142
StrongShape,
4243
convert_dims,
4344
convert_shape,
4445
convert_size,
4546
find_size,
46-
resize_from_dims,
47-
resize_from_observed,
47+
shape_from_dims,
4848
)
4949
from pymc.printing import str_for_dist, str_for_symbolic_dist
5050
from pymc.util import UNSET
@@ -152,29 +152,28 @@ def fn(*args, **kwargs):
152152
def _make_rv_and_resize_shape(
153153
*,
154154
cls,
155-
dims: Optional[Dims],
155+
dims: Optional[StrongDims],
156156
model,
157157
observed,
158158
args,
159159
**kwargs,
160-
) -> Tuple[Variable, Optional[Dims], Optional[Union[np.ndarray, Variable]], StrongShape]:
161-
"""Creates the RV and processes dims or observed to determine a resize shape."""
162-
# Create the RV without dims information, because that's not something tracked at the Aesara level.
163-
# If necessary we'll later replicate to a different size implied by already known dims.
164-
rv_out = cls.dist(*args, **kwargs)
165-
ndim_actual = rv_out.ndim
160+
) -> Tuple[Variable, StrongShape]:
161+
"""Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
166162
resize_shape = None
163+
size_or_shape = kwargs.get("size") or kwargs.get("shape")
164+
165+
# Create the RV without dims or observed information
166+
rv_out = cls.dist(*args, **kwargs)
167167

168-
# # `dims` are only available with this API, because `.dist()` can be used
169-
# # without a modelcontext and dims are not tracked at the Aesara level.
170-
dims = convert_dims(dims)
171-
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
172-
if dims is not None:
173-
if dims_can_resize:
174-
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
175-
elif observed is not None:
176-
resize_shape, observed = resize_from_observed(observed, ndim_actual)
177-
return rv_out, dims, observed, resize_shape
168+
# Preference is given to size or shape, if not provided we use dims and observed
169+
# to resize the variable
170+
if not size_or_shape:
171+
if dims is not None:
172+
resize_shape = shape_from_dims(dims, tuple(rv_out.shape), model)
173+
elif observed is not None:
174+
resize_shape = tuple(observed.shape)
175+
176+
return rv_out, resize_shape
178177

179178

180179
class Distribution(metaclass=DistributionMeta):
@@ -254,15 +253,20 @@ def __new__(
254253
if not isinstance(name, string_types):
255254
raise TypeError(f"Name needs to be a string but got: {name}")
256255

257-
# Create the RV and process dims and observed to determine
258-
# a shape by which the created RV may need to be resized.
259-
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
256+
dims = convert_dims(dims)
257+
if observed is not None:
258+
observed = convert_observed_data(observed)
259+
260+
# Create the RV, possibly taking into consideration dims and observed to
261+
# determine its shape
262+
rv_out, resize_shape = _make_rv_and_resize_shape(
260263
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
261264
)
262265

266+
# A shape was specified only through `dims`, or implied by `observed`.
263267
if resize_shape:
264-
# A batch size was specified through `dims`, or implied by `observed`.
265-
rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
268+
resize_size = find_size(shape=resize_shape, size=None, ndim_supp=cls.rv_op.ndim_supp)
269+
rv_out = change_rv_size(rv=rv_out, new_size=resize_size, expand=False)
266270

267271
rv_out = model.register_rv(
268272
rv_out,
@@ -336,11 +340,7 @@ def dist(
336340
shape = convert_shape(shape)
337341
size = convert_size(size)
338342

339-
create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
340-
shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
341-
)
342-
# Create the RV with a `size` right away.
343-
# This is not necessarily the final result.
343+
create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp)
344344
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
345345

346346
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
@@ -448,19 +448,20 @@ def __new__(
448448
if not isinstance(name, string_types):
449449
raise TypeError(f"Name needs to be a string but got: {name}")
450450

451-
# Create the RV and process dims and observed to determine
452-
# a shape by which the created RV may need to be resized.
453-
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
451+
dims = convert_dims(dims)
452+
if observed is not None:
453+
observed = convert_observed_data(observed)
454+
455+
# Create the RV, possibly taking into consideration dims and observed to
456+
# determine its shape
457+
rv_out, resize_shape = _make_rv_and_resize_shape(
454458
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
455459
)
456460

461+
# A shape was specified only through `dims`, or implied by `observed`.
457462
if resize_shape:
458-
# A batch size was specified through `dims`, or implied by `observed`.
459-
rv_out = cls.change_size(
460-
rv=rv_out,
461-
new_size=resize_shape,
462-
expand=True,
463-
)
463+
resize_size = find_size(shape=resize_shape, size=None, ndim_supp=rv_out.tag.ndim_supp)
464+
rv_out = cls.change_size(rv=rv_out, new_size=resize_size, expand=False)
464465

465466
rv_out = model.register_rv(
466467
rv_out,
@@ -529,18 +530,17 @@ def dist(
529530
shape = convert_shape(shape)
530531
size = convert_size(size)
531532

532-
create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
533-
shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
534-
)
535-
# Create the RV with a `size` right away.
536-
# This is not necessarily the final result.
537-
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
533+
ndim_supp = cls.ndim_supp(*dist_params)
534+
create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
535+
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
536+
# This is needed for resizing from dims in `__new__`
537+
rv_out.tag.ndim_supp = ndim_supp
538538

539539
# TODO: Create new attr error stating that these are not available for DerivedDistribution
540540
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
541541
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
542542
# rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
543-
return graph
543+
return rv_out
544544

545545

546546
@singledispatch

pymc/distributions/shape_utils.py

Lines changed: 21 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,14 @@
1818
samples from probability distributions for stochastic nodes in PyMC.
1919
"""
2020

21-
from typing import Optional, Sequence, Tuple, Union, cast
21+
from typing import Optional, Sequence, Tuple, Union
2222

2323
import numpy as np
2424

2525
from aesara.graph.basic import Variable
2626
from aesara.tensor.var import TensorVariable
2727
from typing_extensions import TypeAlias
2828

29-
from pymc.aesaraf import convert_observed_data
30-
3129
__all__ = [
3230
"to_tuple",
3331
"shapes_broadcasting",
@@ -471,74 +469,46 @@ def convert_size(size: Size) -> Optional[StrongSize]:
471469
return size
472470

473471

474-
def resize_from_dims(dims: Dims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]:
475-
"""Determines a potential resize shape from a `dims` tuple.
472+
def shape_from_dims(
473+
dims: StrongDims, shape_implied: Sequence[TensorVariable], model
474+
) -> StrongShape:
475+
"""Determines shape from a `dims` tuple.
476476
477477
Parameters
478478
----------
479479
dims : array-like
480480
A vector of dimension names or None.
481-
ndim_implied : int
482-
Number of RV dimensions that were implied from its inputs alone.
481+
shape_implied : tensor_like of int
482+
Shape of RV implied from its inputs alone.
483483
model : pm.Model
484484
The current model on stack.
485485
486486
Returns
487487
-------
488-
resize_shape : array-like
489-
Shape of new dimensions that should be prepended.
490488
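shape : tuple of (int or TensorVariable)
491-
Names or None for all dimensions after resizing.
489+
Length of every RV dimension, taken from `model.dim_lengths` when the dim name is known there, otherwise from `shape_implied`.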
492490
"""
493-
sdims = cast(StrongDims, dims)
494-
495-
ndim_resize = len(sdims) - ndim_implied
491+
ndim_resize = len(dims) - len(shape_implied)
496492

497-
# All resize dims must be known already (numerically or symbolically).
498-
unknowndim_resize_dims = set(sdims[:ndim_resize]) - set(model.dim_lengths)
493+
# Dims must be known already or be inferrable from implied dimensions of the RV
494+
unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
499495
if unknowndim_resize_dims:
500496
raise KeyError(
501497
f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
502498
)
503499

504500
# The numeric/symbolic resize tuple can be created using model.dim_lengths
505-
resize_shape: Tuple[Variable, ...] = tuple(
506-
model.dim_lengths[dname] for dname in sdims[:ndim_resize]
501+
return tuple(
502+
model.dim_lengths[dname] if dname in model.dim_lengths else shape_implied[i]
503+
for i, dname in enumerate(dims)
507504
)
508-
return resize_shape, sdims
509-
510-
511-
def resize_from_observed(
512-
observed, ndim_implied: int
513-
) -> Tuple[StrongSize, Union[np.ndarray, Variable]]:
514-
"""Determines a potential resize shape from observations.
515-
516-
Parameters
517-
----------
518-
observed : scalar, array-like
519-
The value of the `observed` kwarg to the RV creation.
520-
ndim_implied : int
521-
Number of RV dimensions that were implied from its inputs alone.
522-
523-
Returns
524-
-------
525-
resize_shape : array-like
526-
Shape of new dimensions that should be prepended.
527-
observed : scalar, array-like
528-
Observations as numpy array or `Variable`.
529-
"""
530-
if not hasattr(observed, "shape"):
531-
observed = convert_observed_data(observed)
532-
ndim_resize = observed.ndim - ndim_implied
533-
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
534-
return resize_shape, observed
535505

536506

537507
def find_size(
538508
shape: Optional[StrongShape],
539509
size: Optional[StrongSize],
540510
ndim_supp: int,
541-
) -> Tuple[Optional[StrongSize], Optional[int], Optional[int], int]:
511+
) -> Optional[StrongSize]:
542512
"""Determines the size keyword argument for creating a Distribution.
543513
544514
Parameters
@@ -553,30 +523,19 @@ def find_size(
553523
554524
Returns
555525
-------
556-
create_size : int, optional
557-
The size argument to be passed to the distribution
558-
ndim_expected : int, optional
559-
Number of dimensions expected after distribution was created
560-
ndim_batch : int, optional
561-
Number of batch dimensions
562-
ndim_supp : int
563-
Number of support dimensions
526+
size : tuple of int or TensorVariable, optional
527+
The size argument for creating the Distribution
564528
"""
565529

566-
ndim_expected: Optional[int] = None
567-
ndim_batch: Optional[int] = None
568-
create_size: Optional[StrongSize] = None
530+
if size is not None:
531+
return size
569532

570533
if shape is not None:
571534
ndim_expected = len(tuple(shape))
572535
ndim_batch = ndim_expected - ndim_supp
573-
create_size = tuple(shape)[:ndim_batch]
574-
elif size is not None:
575-
ndim_expected = ndim_supp + len(tuple(size))
576-
ndim_batch = ndim_expected - ndim_supp
577-
create_size = size
536+
return tuple(shape)[:ndim_batch]
578537

579-
return create_size, ndim_expected, ndim_batch, ndim_supp
538+
return None
580539

581540

582541
def rv_size_is_none(size: Variable) -> bool:

pymc/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
from aesara.compile.sharedvalue import SharedVariable
4444
from aesara.graph.basic import Constant, Variable, graph_inputs
4545
from aesara.graph.fg import FunctionGraph
46+
from aesara.scalar import Cast
47+
from aesara.tensor.elemwise import Elemwise
4648
from aesara.tensor.random.rewriting import local_subtensor_rv_lift
4749
from aesara.tensor.sharedvar import ScalarSharedVariable
4850
from aesara.tensor.var import TensorConstant, TensorVariable
@@ -1367,6 +1369,12 @@ def register_rv(
13671369
isinstance(data, Variable)
13681370
and not isinstance(data, (GenTensorVariable, Minibatch))
13691371
and data.owner is not None
1372+
# The only Aesara operation we allow on observed data is type casting
1373+
# Although we could allow for any graph that does not depend on other RVs
1374+
and not (
1375+
isinstance(data.owner.op, Elemwise)
1376+
and isinstance(data.owner.op.scalar_op, Cast)
1377+
)
13701378
):
13711379
raise TypeError(
13721380
"Variables that depend on other nodes cannot be used for observed data."

pymc/tests/test_data_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_sample(self):
4545
with pm.Model():
4646
x_shared = pm.MutableData("x_shared", x)
4747
b = pm.Normal("b", 0.0, 10.0)
48-
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)
48+
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y, shape=x_shared.shape)
4949

5050
prior_trace0 = pm.sample_prior_predictive(1000)
5151
idata = pm.sample(1000, tune=1000, chains=1)

pymc/tests/test_distributions_random.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,6 @@ def test_density_dist_with_random(self, size):
19051905
mu,
19061906
random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
19071907
observed=np.random.randn(100, *size),
1908-
size=size,
19091908
)
19101909

19111910
assert obs.eval().shape == (100,) + size
@@ -1937,7 +1936,6 @@ def test_density_dist_with_random_multivariate(self, size):
19371936
mean=mu, cov=np.eye(len(mu)), size=size
19381937
),
19391938
observed=np.random.randn(100, *size, supp_shape),
1940-
size=size,
19411939
ndims_params=[1],
19421940
ndim_supp=1,
19431941
)

pymc/tests/test_sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,7 @@ def test_linear_model(self):
16061606
beta = pm.Normal("beta", 0, 0.1)
16071607
mu = pm.Deterministic("mu", alpha + beta * x)
16081608
sigma = pm.HalfNormal("sigma", 0.1)
1609-
obs = pm.Normal("obs", mu, sigma, observed=y)
1609+
obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape)
16101610

16111611
f = compile_forward_sampling_function(
16121612
[obs],
@@ -1624,7 +1624,7 @@ def test_linear_model(self):
16241624
beta = pm.Normal("beta", 0, 0.1)
16251625
mu = pm.Deterministic("mu", alpha + beta * x)
16261626
sigma = pm.HalfNormal("sigma", 0.1)
1627-
obs = pm.Normal("obs", mu, sigma, observed=y)
1627+
obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape)
16281628

16291629
f = compile_forward_sampling_function(
16301630
[obs],
@@ -1644,7 +1644,7 @@ def test_nested_observed_model(self):
16441644
beta = pm.Normal("beta", 0, 0.1, size=p.shape)
16451645
mu = pm.Deterministic("mu", beta[category])
16461646
sigma = pm.HalfNormal("sigma", 0.1)
1647-
pm.Normal("obs", mu, sigma, observed=y)
1647+
pm.Normal("obs", mu, sigma, observed=y, shape=mu.shape)
16481648

16491649
f = compile_forward_sampling_function(
16501650
outputs=model.observed_RVs,
@@ -1675,7 +1675,7 @@ def test_volatile_parameters(self):
16751675
mu = pm.Normal("mu", 0, 1)
16761676
nested_mu = pm.Normal("nested_mu", mu, 1, size=10)
16771677
sigma = pm.HalfNormal("sigma", 1)
1678-
pm.Normal("obs", nested_mu, sigma, observed=y)
1678+
pm.Normal("obs", nested_mu, sigma, observed=y, shape=nested_mu.shape)
16791679

16801680
f = compile_forward_sampling_function(
16811681
outputs=model.observed_RVs,

0 commit comments

Comments
 (0)