18
18
@author: johnsalvatier
19
19
"""
20
20
21
+ import warnings
22
+
21
23
from abc import ABC , abstractmethod
22
24
from enum import IntEnum , unique
23
- from typing import Any , Dict , List , Mapping , Sequence , Tuple , Union
25
+ from typing import Any , Dict , Iterable , List , Mapping , Sequence , Tuple , Union
24
26
25
27
import numpy as np
26
28
27
29
from pytensor .graph .basic import Variable
28
30
29
- from pymc .blocking import PointType , StatsDict , StatsType
31
+ from pymc .blocking import PointType , StatDtype , StatsDict , StatShape , StatsType
30
32
from pymc .model import modelcontext
31
33
32
34
__all__ = ("Competence" , "CompoundStep" )
@@ -48,9 +50,61 @@ class Competence(IntEnum):
48
50
IDEAL = 3
49
51
50
52
53
+ def infer_warn_stats_info (
54
+ stats_dtypes : List [Dict [str , StatDtype ]],
55
+ sds : Dict [str , Tuple [StatDtype , StatShape ]],
56
+ stepname : str ,
57
+ ) -> Tuple [List [Dict [str , StatDtype ]], Dict [str , Tuple [StatDtype , StatShape ]]]:
58
+ """Helper function to get `stats_dtypes` and `stats_dtypes_shapes` from either of them."""
59
+ # Avoid side-effects on the original lists/dicts
60
+ stats_dtypes = [d .copy () for d in stats_dtypes ]
61
+ sds = sds .copy ()
62
+ # Disallow specification of both attributes
63
+ if stats_dtypes and sds :
64
+ raise TypeError (
65
+ "Only one of `stats_dtypes_shapes` or `stats_dtypes` must be specified."
66
+ f" `{ stepname } .stats_dtypes` should be removed."
67
+ )
68
+
69
+ # Infer one from the other
70
+ if not sds and stats_dtypes :
71
+ warnings .warn (
72
+ f"`{ stepname } .stats_dtypes` is deprecated."
73
+ " Please update it to specify `stats_dtypes_shapes` instead." ,
74
+ DeprecationWarning ,
75
+ )
76
+ if len (stats_dtypes ) > 1 :
77
+ raise TypeError (
78
+ f"`{ stepname } .stats_dtypes` must be a list containing at most one dict."
79
+ )
80
+ for sd in stats_dtypes :
81
+ for sname , dtype in sd .items ():
82
+ sds [sname ] = (dtype , None )
83
+ elif sds :
84
+ stats_dtypes .append ({sname : dtype for sname , (dtype , _ ) in sds .items ()})
85
+ return stats_dtypes , sds
86
+
87
+
51
88
class BlockedStep (ABC ):
52
89
stats_dtypes : List [Dict [str , type ]] = []
90
+ """A list containing <=1 dictionary that maps stat names to dtypes.
91
+
92
+ This attribute is deprecated.
93
+ Use `stats_dtypes_shapes` instead.
94
+ """
95
+
96
+ stats_dtypes_shapes : Dict [str , Tuple [StatDtype , StatShape ]] = {}
97
+ """Maps stat names to dtypes and shapes.
98
+
99
+ Shapes are interpreted in the following ways:
100
+ - `[]` is a scalar.
101
+ - `[3,]` is a length-3 vector.
102
+ - `[4, None]` is a matrix with 4 rows and a dynamic number of columns.
103
+ - `None` is a sparse stat (i.e. not always present) or a NumPy array with varying `ndim`.
104
+ """
105
+
53
106
vars : List [Variable ] = []
107
+ """Variables that the step method is assigned to."""
54
108
55
109
def __new__ (cls , * args , ** kwargs ):
56
110
blocked = kwargs .get ("blocked" )
@@ -77,12 +131,21 @@ def __new__(cls, *args, **kwargs):
77
131
if len (vars ) == 0 :
78
132
raise ValueError ("No free random variables to sample." )
79
133
134
+ # Auto-fill stats metadata attributes from whichever was given.
135
+ stats_dtypes , stats_dtypes_shapes = infer_warn_stats_info (
136
+ cls .stats_dtypes ,
137
+ cls .stats_dtypes_shapes ,
138
+ cls .__name__ ,
139
+ )
140
+
80
141
if not blocked and len (vars ) > 1 :
81
142
# In this case we create a separate sampler for each var
82
143
# and append them to a CompoundStep
83
144
steps = []
84
145
for var in vars :
85
146
step = super ().__new__ (cls )
147
+ step .stats_dtypes = stats_dtypes
148
+ step .stats_dtypes_shapes = stats_dtypes_shapes
86
149
# If we don't return the instance we have to manually
87
150
# call __init__
88
151
step .__init__ ([var ], * args , ** kwargs )
@@ -93,6 +156,8 @@ def __new__(cls, *args, **kwargs):
93
156
return CompoundStep (steps )
94
157
else :
95
158
step = super ().__new__ (cls )
159
+ step .stats_dtypes = stats_dtypes
160
+ step .stats_dtypes_shapes = stats_dtypes_shapes
96
161
# Hack for creating the class correctly when unpickling.
97
162
step .__newargs = (vars ,) + args , kwargs
98
163
return step
@@ -126,6 +191,20 @@ def stop_tuning(self):
126
191
self .tune = False
127
192
128
193
194
+ def get_stats_dtypes_shapes_from_steps (
195
+ steps : Iterable [BlockedStep ],
196
+ ) -> Dict [str , Tuple [StatDtype , StatShape ]]:
197
+ """Combines stats dtype shape dictionaries from multiple step methods.
198
+
199
+ In the resulting stats dict, each sampler stat is prefixed by `sampler_#__`.
200
+ """
201
+ result = {}
202
+ for s , step in enumerate (steps ):
203
+ for sname , (dtype , shape ) in step .stats_dtypes_shapes .items ():
204
+ result [f"sampler_{ s } __{ sname } " ] = (dtype , shape )
205
+ return result
206
+
207
+
129
208
class CompoundStep :
130
209
"""Step method composed of a list of several other step
131
210
methods applied in sequence."""
@@ -135,6 +214,7 @@ def __init__(self, methods):
135
214
self .stats_dtypes = []
136
215
for method in self .methods :
137
216
self .stats_dtypes .extend (method .stats_dtypes )
217
+ self .stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps (methods )
138
218
self .name = (
139
219
f"Compound[{ ', ' .join (getattr (m , 'name' , 'UNNAMED_STEP' ) for m in self .methods )} ]"
140
220
)
0 commit comments