Skip to content

Commit 99f17c7

Browse files
Introduce stats_dtypes_shape attribute
Closes #6503
1 parent 26f884b commit 99f17c7

File tree

3 files changed

+148
-4
lines changed

3 files changed

+148
-4
lines changed

pymc/blocking.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,18 @@
2020
from __future__ import annotations
2121

2222
from functools import partial
23-
from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar
23+
from typing import (
24+
Any,
25+
Callable,
26+
Dict,
27+
Generic,
28+
List,
29+
NamedTuple,
30+
Optional,
31+
Sequence,
32+
TypeVar,
33+
Union,
34+
)
2435

2536
import numpy as np
2637

@@ -33,6 +44,8 @@
3344
PointType: TypeAlias = Dict[str, np.ndarray]
3445
StatsDict: TypeAlias = Dict[str, Any]
3546
StatsType: TypeAlias = List[StatsDict]
47+
StatDtype: TypeAlias = Union[type, np.dtype]
48+
StatShape: TypeAlias = Optional[Sequence[Optional[int]]]
3649

3750

3851
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for

pymc/step_methods/compound.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818
@author: johnsalvatier
1919
"""
2020

21+
import warnings
22+
2123
from abc import ABC, abstractmethod
2224
from enum import IntEnum, unique
23-
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union
25+
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple, Union
2426

2527
import numpy as np
2628

2729
from pytensor.graph.basic import Variable
2830

29-
from pymc.blocking import PointType, StatsDict, StatsType
31+
from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
3032
from pymc.model import modelcontext
3133

3234
__all__ = ("Competence", "CompoundStep")
@@ -48,9 +50,61 @@ class Competence(IntEnum):
4850
IDEAL = 3
4951

5052

53+
def infer_warn_stats_info(
    stats_dtypes: List[Dict[str, StatDtype]],
    sds: Dict[str, Tuple[StatDtype, StatShape]],
    stepname: str,
) -> Tuple[List[Dict[str, StatDtype]], Dict[str, Tuple[StatDtype, StatShape]]]:
    """Derive whichever of `stats_dtypes` / `stats_dtypes_shapes` was left empty.

    Parameters
    ----------
    stats_dtypes
        Deprecated format: a list holding at most one dict of stat name -> dtype.
    sds
        New format: a dict of stat name -> ``(dtype, shape)``.
    stepname
        Name of the step method; only used in the warning and error messages.

    Returns
    -------
    tuple
        ``(stats_dtypes, stats_dtypes_shapes)``. Both are built from copies,
        so the containers passed in are never modified.

    Raises
    ------
    TypeError
        If both inputs are non-empty, or if `stats_dtypes` holds more than one dict.

    Warns
    -----
    DeprecationWarning
        If only the deprecated `stats_dtypes` was given.
    """
    # Work on copies so that the caller's list/dicts stay untouched.
    legacy = [mapping.copy() for mapping in stats_dtypes]
    modern = sds.copy()

    # There must be a single source of truth.
    if legacy and modern:
        raise TypeError(
            "Only one of `stats_dtypes_shapes` or `stats_dtypes` must be specified."
            f" `{stepname}.stats_dtypes` should be removed."
        )

    # New format given: derive the deprecated list-of-one-dict from it.
    if modern:
        legacy.append({name: dtype for name, (dtype, _) in modern.items()})
        return legacy, modern

    # Nothing given at all: both stay empty.
    if not legacy:
        return legacy, modern

    # Only the deprecated format given: warn, validate, then derive the new one.
    warnings.warn(
        f"`{stepname}.stats_dtypes` is deprecated."
        " Please update it to specify `stats_dtypes_shapes` instead.",
        DeprecationWarning,
    )
    if len(legacy) > 1:
        raise TypeError(f"`{stepname}.stats_dtypes` must be a list containing at most one dict.")
    # The old format carries no shape information, so the shape is recorded as None.
    modern.update(
        (name, (dtype, None)) for mapping in legacy for name, dtype in mapping.items()
    )
    return legacy, modern
86+
87+
5188
class BlockedStep(ABC):
5289
stats_dtypes: List[Dict[str, type]] = []
90+
"""A list containing <=1 dictionary that maps stat names to dtypes.
91+
92+
This attribute is deprecated.
93+
Use `stats_dtypes_shapes` instead.
94+
"""
95+
96+
stats_dtypes_shapes: Dict[str, Tuple[StatDtype, StatShape]] = {}
97+
"""Maps stat names to dtypes and shapes.
98+
99+
Shapes are interpreted in the following ways:
100+
- `[]` is a scalar.
101+
- `[3,]` is a length-3 vector.
102+
- `[4, None]` is a matrix with 4 rows and a dynamic number of columns.
103+
- `None` is a sparse stat (i.e. not always present) or a NumPy array with varying `ndim`.
104+
"""
105+
53106
vars: List[Variable] = []
107+
"""Variables that the step method is assigned to."""
54108

55109
def __new__(cls, *args, **kwargs):
56110
blocked = kwargs.get("blocked")
@@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs):
77131
if len(vars) == 0:
78132
raise ValueError("No free random variables to sample.")
79133

134+
# Auto-fill stats metadata attributes from whichever was given.
135+
stats_dtypes, stats_dtypes_shapes = infer_warn_stats_info(
136+
cls.stats_dtypes,
137+
cls.stats_dtypes_shapes,
138+
cls.__name__,
139+
)
140+
80141
if not blocked and len(vars) > 1:
81142
# In this case we create a separate sampler for each var
82143
# and append them to a CompoundStep
83144
steps = []
84145
for var in vars:
85146
step = super().__new__(cls)
147+
step.stats_dtypes = stats_dtypes
148+
step.stats_dtypes_shapes = stats_dtypes_shapes
86149
# If we don't return the instance we have to manually
87150
# call __init__
88151
step.__init__([var], *args, **kwargs)
@@ -93,6 +156,8 @@ def __new__(cls, *args, **kwargs):
93156
return CompoundStep(steps)
94157
else:
95158
step = super().__new__(cls)
159+
step.stats_dtypes = stats_dtypes
160+
step.stats_dtypes_shapes = stats_dtypes_shapes
96161
# Hack for creating the class correctly when unpickling.
97162
step.__newargs = (vars,) + args, kwargs
98163
return step
@@ -126,6 +191,20 @@ def stop_tuning(self):
126191
self.tune = False
127192

128193

194+
def get_stats_dtypes_shapes_from_steps(
    steps: Iterable[BlockedStep],
) -> Dict[str, Tuple[StatDtype, StatShape]]:
    """Merge the `stats_dtypes_shapes` dicts of several step methods into one.

    Every stat name is prefixed with ``sampler_#__``, where ``#`` is the
    position of the step method within `steps`, so that stats of the same
    name coming from different samplers do not collide.
    """
    return {
        f"sampler_{position}__{stat_name}": (dtype, shape)
        for position, method in enumerate(steps)
        for stat_name, (dtype, shape) in method.stats_dtypes_shapes.items()
    }
206+
207+
129208
class CompoundStep:
130209
"""Step method composed of a list of several other step
131210
methods applied in sequence."""
@@ -135,6 +214,7 @@ def __init__(self, methods):
135214
self.stats_dtypes = []
136215
for method in self.methods:
137216
self.stats_dtypes.extend(method.stats_dtypes)
217+
self.stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(methods)
138218
self.name = (
139219
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
140220
)

pymc/tests/step_methods/test_compound.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
Metropolis,
2626
Slice,
2727
)
28-
from pymc.step_methods.compound import StatsBijection, flatten_steps
28+
from pymc.step_methods.compound import (
29+
StatsBijection,
30+
flatten_steps,
31+
get_stats_dtypes_shapes_from_steps,
32+
infer_warn_stats_info,
33+
)
2934
from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode
3035
from pymc.tests.models import simple_2model_continuous
3136

@@ -94,6 +99,52 @@ def test_compound_step(self):
9499
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)
95100

96101

102+
class TestStatsMetadata:
    """Tests for the stats metadata attributes of step methods."""

    def test_infer_warn_stats_info(self):
        """
        Until `BlockedStep.stats_dtypes` is removed, the new `stats_dtypes_shapes`
        attribute is inferred from `stats_dtypes`, or vice versa.
        """
        # Only the deprecated format is given: the new one gets inferred.
        with pytest.warns(DeprecationWarning, match="to specify"):
            legacy, modern = infer_warn_stats_info([{"a": int, "b": object}], {}, "bla")
        assert isinstance(legacy, list)
        assert legacy == [{"a": int, "b": object}]
        assert isinstance(modern, dict)
        assert (modern["a"], modern["b"]) == ((int, None), (object, None))

        # Only the new format is given: the deprecated one gets inferred.
        legacy, modern = infer_warn_stats_info([], {"a": (int, []), "b": (float, [2])}, "bla")
        assert isinstance(legacy, list)
        assert legacy == [{"a": int, "b": float}]
        assert isinstance(modern, dict)
        assert (modern["a"], modern["b"]) == ((int, []), (float, [2]))

        # Giving both is rejected (single source of truth problem).
        with pytest.raises(TypeError, match="Only one of"):
            infer_warn_stats_info([{"a": float}], {"b": (int, [])}, "bla")

    def test_stats_from_steps(self):
        """Stats metadata of a compound step matches the combination of its steps."""
        with pm.Model():
            nuts = pm.NUTS(pm.Normal("n"))
            metropolis = pm.Metropolis(pm.Bernoulli("b", 0.5))
            compound = pm.CompoundStep([nuts, metropolis])

        # Sampler initialization must not modify the class-level
        # default values of the attributes.
        for step_cls in (pm.NUTS, pm.Metropolis):
            assert step_cls.stats_dtypes_shapes == {}

        combined = get_stats_dtypes_shapes_from_steps([nuts, metropolis])
        for stat_name in ("sampler_0__step_size", "sampler_1__accepted"):
            assert stat_name in combined
        assert len(compound.stats_dtypes) == 2
        assert compound.stats_dtypes_shapes == combined
147+
97148
class TestStatsBijection:
98149
def test_flatten_steps(self):
99150
with pm.Model():

0 commit comments

Comments
 (0)