Skip to content

Commit 85cfc99

Browse files
committed
Refactor get_support_shape
* Do not call `at.stack` until the end * Allow to pass single shape_offset
1 parent b3fb802 commit 85cfc99

File tree

2 files changed

+53
-35
lines changed

2 files changed

+53
-35
lines changed

pymc/distributions/shape_utils.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import warnings
2121

2222
from functools import singledispatch
23-
from typing import Any, Optional, Sequence, Tuple, Union
23+
from typing import Any, Optional, Sequence, Tuple, Union, cast
2424

2525
import numpy as np
2626

@@ -671,7 +671,7 @@ def get_support_shape(
671671
observed: Optional[Any] = None,
672672
support_shape_offset: Sequence[int] = None,
673673
ndim_supp: int = 1,
674-
):
674+
) -> Optional[TensorVariable]:
675675
"""Extract the support shapes from shape / dims / observed information
676676
677677
Parameters
@@ -702,46 +702,61 @@ def get_support_shape(
702702
raise NotImplementedError("ndim_supp must be bigger than 0")
703703
if support_shape_offset is None:
704704
support_shape_offset = [0] * ndim_supp
705-
inferred_support_shape = None
705+
elif isinstance(support_shape_offset, int):
706+
support_shape_offset = [support_shape_offset] * ndim_supp
707+
inferred_support_shape: Optional[Sequence[Union[int, np.ndarray, Variable]]] = None
706708

707709
if shape is not None:
708710
shape = to_tuple(shape)
709711
assert isinstance(shape, tuple)
710-
inferred_support_shape = at.stack(
711-
[shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)]
712-
)
712+
if len(shape) < ndim_supp:
713+
raise ValueError(
714+
f"Number of shape dimensions is too small for ndim_supp of {ndim_supp}"
715+
)
716+
inferred_support_shape = [
717+
shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)
718+
]
713719

714720
if inferred_support_shape is None and dims is not None:
715721
dims = convert_dims(dims)
716722
assert isinstance(dims, tuple)
723+
if len(dims) < ndim_supp:
724+
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
717725
model = modelcontext(None)
718-
inferred_support_shape = at.stack(
719-
[
720-
model.dim_lengths[dims[i]] - support_shape_offset[i] # type: ignore
721-
for i in np.arange(-ndim_supp, 0)
722-
]
723-
)
726+
inferred_support_shape = [
727+
model.dim_lengths[dims[i]] - support_shape_offset[i] # type: ignore
728+
for i in np.arange(-ndim_supp, 0)
729+
]
724730

725731
if inferred_support_shape is None and observed is not None:
726732
observed = convert_observed_data(observed)
727-
inferred_support_shape = at.stack(
728-
[observed.shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)]
729-
)
733+
if observed.ndim < ndim_supp:
734+
raise ValueError(
735+
f"Number of observed dimensions is too small for ndim_supp of {ndim_supp}"
736+
)
737+
inferred_support_shape = [
738+
observed.shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)
739+
]
730740

731-
if inferred_support_shape is None:
741+
# We did not learn anything
742+
if inferred_support_shape is None and support_shape is None:
743+
return None
744+
# Only source of information was the originally provided support_shape
745+
elif inferred_support_shape is None:
732746
inferred_support_shape = support_shape
733-
# If there are two sources of information for the support shapes, assert they are consistent:
747+
# There were two sources of support_shape, make sure they are consistent
734748
elif support_shape is not None:
735-
inferred_support_shape = at.stack(
736-
[
749+
inferred_support_shape = [
750+
cast(
751+
Variable,
737752
Assert(msg="support_shape does not match respective shape dimension")(
738753
inferred, at.eq(inferred, explicit)
739-
)
740-
for inferred, explicit in zip(inferred_support_shape, support_shape)
741-
]
742-
)
754+
),
755+
)
756+
for inferred, explicit in zip(inferred_support_shape, support_shape)
757+
]
743758

744-
return inferred_support_shape
759+
return at.stack(inferred_support_shape)
745760

746761

747762
def get_support_shape_1d(
@@ -751,21 +766,18 @@ def get_support_shape_1d(
751766
dims: Optional[Dims] = None,
752767
observed: Optional[Any] = None,
753768
support_shape_offset: int = 0,
754-
):
769+
) -> Optional[TensorVariable]:
755770
"""Helper function for cases when you just care about one dimension."""
756-
if support_shape is not None:
757-
support_shape_tuple = (support_shape,)
758-
else:
759-
support_shape_tuple = None
760-
761771
support_shape_tuple = get_support_shape(
762-
support_shape_tuple,
772+
support_shape=(support_shape,) if support_shape is not None else None,
763773
shape=shape,
764774
dims=dims,
765775
observed=observed,
766776
support_shape_offset=(support_shape_offset,),
767777
)
768-
if support_shape_tuple is not None:
769-
(support_shape,) = support_shape_tuple
770778

771-
return support_shape
779+
if support_shape_tuple is not None:
780+
(support_shape_,) = support_shape_tuple
781+
return support_shape_
782+
else:
783+
return None

pymc/tests/distributions/test_multivariate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1466,7 +1466,13 @@ def test_zsn_shape(self, zerosum_axes):
14661466
@pytest.mark.parametrize(
14671467
"error, match, shape, support_shape, zerosum_axes",
14681468
[
1469-
(IndexError, "index out of range", (3, 4, 5), None, 4),
1469+
(
1470+
ValueError,
1471+
"Number of shape dimensions is too small for ndim_supp of 4",
1472+
(3, 4, 5),
1473+
None,
1474+
4,
1475+
),
14701476
(AssertionError, "does not match", (3, 4), (3,), None), # support_shape should be 4
14711477
(
14721478
AssertionError,

0 commit comments

Comments
 (0)