Skip to content

Commit 4626712

Browse files
committed
fix extract_Q_estimate function
1 parent b903e57 commit 4626712

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

pymc/step_methods/mlda.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -946,28 +946,22 @@ def extract_Q_estimate(trace, levels):
946946
MLDA with variance reduction has been used for sampling.
947947
"""
948948

949-
Q_0_raw = trace.get_sampler_stats("Q_0")
950-
# total number of base level samples from all iterations
951-
total_base_level_samples = sum(it.shape[0] for it in Q_0_raw)
952-
Q_0 = np.concatenate(Q_0_raw).reshape((1, total_base_level_samples))
949+
Q_0_raw = trace.get_sampler_stats("Q_0").squeeze()
950+
Q_0 = np.concatenate(Q_0_raw)[None, ::]
953951
ess_Q_0 = az.ess(np.array(Q_0, np.float64))
954952
Q_0_var = Q_0.var() / ess_Q_0
955953

956954
Q_diff_means = []
957955
Q_diff_vars = []
958956
for l in range(1, levels):
959-
Q_diff_raw = trace.get_sampler_stats(f"Q_{l}_{l-1}")
960-
# total number of samples from all iterations
961-
total_level_samples = sum(it.shape[0] for it in Q_diff_raw)
962-
Q_diff = np.concatenate(Q_diff_raw).reshape((1, total_level_samples))
957+
Q_diff_raw = trace.get_sampler_stats(f"Q_{l}_{l-1}").squeeze()
958+
Q_diff = np.hstack(Q_diff_raw)[None, ::]
963959
ess_diff = az.ess(np.array(Q_diff, np.float64))
964-
965960
Q_diff_means.append(Q_diff.mean())
966961
Q_diff_vars.append(Q_diff.var() / ess_diff)
967962

968963
Q_mean = Q_0.mean() + sum(Q_diff_means)
969964
Q_se = np.sqrt(Q_0_var + sum(Q_diff_vars))
970-
971965
return Q_mean, Q_se
972966

973967

pymc/tests/test_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,13 +1750,13 @@ def perform(self, node, inputs, outputs):
17501750
Q_mean_vr, Q_se_vr = extract_Q_estimate(trace, 3)
17511751

17521752
# check that returned values are floats and finite.
1753-
assert isinstance(Q_mean_standard, float)
1753+
assert isinstance(Q_mean_standard, np.floating)
17541754
assert np.isfinite(Q_mean_standard)
1755-
assert isinstance(Q_mean_vr, float)
1755+
assert isinstance(Q_mean_vr, np.floating)
17561756
assert np.isfinite(Q_mean_vr)
1757-
assert isinstance(Q_se_standard, float)
1757+
assert isinstance(Q_se_standard, np.floating)
17581758
assert np.isfinite(Q_se_standard)
1759-
assert isinstance(Q_se_vr, float)
1759+
assert isinstance(Q_se_vr, np.floating)
17601760
assert np.isfinite(Q_se_vr)
17611761

17621762
# check consistency of QoI across levels.

0 commit comments

Comments
 (0)