@@ -86,6 +86,15 @@ def check_stat(self, check, idata, name):
86
86
s = stat (group [var ].sel (chain = 0 ), axis = 0 )
87
87
close_to (s , value , bound , name )
88
88
89
+ def check_stat_dtype (self , step , idata ):
90
+ # TODO: This check does not confirm the announced dtypes are correct as the
91
+ # sampling machinery will convert them automatically.
92
+ for stats_dtypes in getattr (step , "stats_dtypes" , []):
93
+ for stat , dtype in stats_dtypes .items ():
94
+ if stat == "tune" :
95
+ continue
96
+ assert idata .sample_stats [stat ].dtype == np .dtype (dtype )
97
+
89
98
@pytest .mark .parametrize (
90
99
"step_fn, draws" ,
91
100
[
@@ -139,6 +148,7 @@ def test_step_continuous(self, step_fn, draws):
139
148
random_seed = 1 ,
140
149
)
141
150
self .check_stat (check , idata , step .__class__ .__name__ )
151
+ self .check_stat_dtype (idata , step )
142
152
143
153
def test_step_discrete (self ):
144
154
start , model , (mu , C ) = mv_simple_discrete ()
@@ -156,6 +166,7 @@ def test_step_discrete(self):
156
166
random_seed = 1 ,
157
167
)
158
168
self .check_stat (check , idata , step .__class__ .__name__ )
169
+ self .check_stat_dtype (idata , step )
159
170
160
171
@pytest .mark .parametrize ("proposal" , ["uniform" , "proportional" ])
161
172
def test_step_categorical (self , proposal ):
@@ -174,6 +185,7 @@ def test_step_categorical(self, proposal):
174
185
random_seed = 1 ,
175
186
)
176
187
self .check_stat (check , idata , step .__class__ .__name__ )
188
+ self .check_stat_dtype (idata , step )
177
189
178
190
179
191
class TestCompoundStep :
0 commit comments