Skip to content

Commit 6f61f1c

Browse files
committed
Test steps stats_dtypes
Related to #5883
1 parent 09d9d73 commit 6f61f1c

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

pymc/tests/test_step.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ def check_stat(self, check, idata, name):
8686
s = stat(group[var].sel(chain=0), axis=0)
8787
close_to(s, value, bound, name)
8888

89+
def check_stat_dtype(self, step, idata):
90+
# TODO: This check does not confirm the announced dtypes are correct as the
91+
# sampling machinery will convert them automatically.
92+
for stats_dtypes in getattr(step, "stats_dtypes", []):
93+
for stat, dtype in stats_dtypes.items():
94+
if stat == "tune":
95+
continue
96+
assert idata.sample_stats[stat].dtype == np.dtype(dtype)
97+
8998
@pytest.mark.parametrize(
9099
"step_fn, draws",
91100
[
@@ -139,6 +148,7 @@ def test_step_continuous(self, step_fn, draws):
139148
random_seed=1,
140149
)
141150
self.check_stat(check, idata, step.__class__.__name__)
151+
self.check_stat_dtype(idata, step)
142152

143153
def test_step_discrete(self):
144154
start, model, (mu, C) = mv_simple_discrete()
@@ -156,6 +166,7 @@ def test_step_discrete(self):
156166
random_seed=1,
157167
)
158168
self.check_stat(check, idata, step.__class__.__name__)
169+
self.check_stat_dtype(idata, step)
159170

160171
@pytest.mark.parametrize("proposal", ["uniform", "proportional"])
161172
def test_step_categorical(self, proposal):
@@ -174,6 +185,7 @@ def test_step_categorical(self, proposal):
174185
random_seed=1,
175186
)
176187
self.check_stat(check, idata, step.__class__.__name__)
188+
self.check_stat_dtype(idata, step)
177189

178190

179191
class TestCompoundStep:

0 commit comments

Comments
 (0)