Skip to content

Commit d9aa0da

Browse files
committed
Automatically round proposal values for discrete variables in SMC
1 parent 1818943 commit d9aa0da

File tree

2 files changed

+76
-37
lines changed

2 files changed

+76
-37
lines changed

pymc3/smc/smc.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import warnings
16-
1715
from collections import OrderedDict
1816

1917
import aesara.tensor as at
2018
import numpy as np
2119

22-
from aesara import config
2320
from aesara import function as aesara_function
21+
from aesara.graph.basic import clone_replace
2422
from scipy.special import logsumexp
2523
from scipy.stats import multivariate_normal
2624

@@ -34,6 +32,7 @@
3432
from pymc3.blocking import DictToArrayBijection
3533
from pymc3.model import Point, modelcontext
3634
from pymc3.sampling import sample_prior_predictive
35+
from pymc3.vartypes import discrete_types
3736

3837

3938
class SMC:
@@ -273,9 +272,15 @@ def posterior_to_trace(self):
273272
for i in range(lenght_pos):
274273
value = []
275274
size = 0
276-
for var in varnames:
277-
shape, new_size = self.var_info[var]
278-
value.append(self.posterior[i][size : size + new_size].reshape(shape))
275+
for varname in varnames:
276+
shape, new_size = self.var_info[varname]
277+
var_samples = self.posterior[i][size : size + new_size]
278+
# Round discrete variable samples. The rounded values were the ones
279+
# actually used in the logp evaluations (see logp_forw)
280+
var = self.model[varname]
281+
if var.dtype in discrete_types:
282+
var_samples = np.round(var_samples).astype(var.dtype)
283+
value.append(var_samples.reshape(shape))
279284
size += new_size
280285
strace.record(point={k: v for k, v in zip(varnames, value)})
281286
return strace
@@ -294,20 +299,32 @@ def logp_forw(point, out_vars, vars, shared):
294299
containing :class:`aesara.tensor.Tensor` for depended shared data
295300
"""
296301

302+
# Convert expected input of discrete variables to (rounded) floats
303+
if any(var.dtype in discrete_types for var in vars):
304+
replace_int_to_float = {}
305+
replace_float_to_round = {}
306+
new_vars = []
307+
for var in vars:
308+
if var.dtype in discrete_types:
309+
float_var = at.TensorType("floatX", var.broadcastable)(var.name)
310+
replace_int_to_float[var] = float_var
311+
new_vars.append(float_var)
312+
313+
round_float_var = at.round(float_var)
314+
round_float_var.name = var.name
315+
replace_float_to_round[float_var] = round_float_var
316+
else:
317+
new_vars.append(var)
318+
319+
replace_int_to_float.update(shared)
320+
replace_float_to_round.update(shared)
321+
out_vars = clone_replace(out_vars, replace_int_to_float, strict=False)
322+
out_vars = clone_replace(out_vars, replace_float_to_round)
323+
vars = new_vars
324+
297325
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
298-
# TODO: Figure out how to safely accept float32 (floatX) input when there are
299-
# discrete variables of int64 dtype in `vars`.
300-
# See https://github.com/pymc-devs/pymc3/pull/4769#issuecomment-861494080
301-
if config.floatX == "float32" and any(var.dtype == "int64" for var in vars):
302-
warnings.warn(
303-
"SMC sampling may run slower due to the presence of discrete variables "
304-
"together with aesara.config.floatX == `float32`",
305-
UserWarning,
306-
)
307-
f = aesara_function([inarray0], out_list[0], allow_input_downcast=True)
308-
else:
309-
f = aesara_function([inarray0], out_list[0])
310-
f.trust_input = True
326+
f = aesara_function([inarray0], out_list[0])
327+
f.trust_input = True
311328
return f
312329

313330

pymc3/tests/test_smc.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
import aesara.tensor as at
1717
import numpy as np
1818
import pytest
19+
import scipy.stats as st
1920

2021
from arviz.data.inference_data import InferenceData
2122

2223
import pymc3 as pm
2324

25+
from pymc3.aesaraf import floatX
2426
from pymc3.backends.base import MultiTrace
27+
from pymc3.smc.smc import SMC
2528
from pymc3.tests.helpers import SeededTest
2629

2730

@@ -64,10 +67,6 @@ def two_gaussians(x):
6467
x = pm.Normal("x", 0, 1)
6568
y = pm.Normal("y", x, 1, observed=0)
6669

67-
with pm.Model() as self.slow_model:
68-
x = pm.Normal("x", 0, 1)
69-
y = pm.Normal("y", x, 1, observed=100)
70-
7170
def test_sample(self):
7271
with self.SMC_test:
7372
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
@@ -76,12 +75,43 @@ def test_sample(self):
7675
mu1d = np.abs(x).mean(axis=0)
7776
np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03)
7877

79-
def test_discrete_continuous(self):
80-
with pm.Model() as model:
81-
a = pm.Poisson("a", 5)
82-
b = pm.HalfNormal("b", 10)
83-
y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
84-
trace = pm.sample_smc(draws=10)
78+
def test_discrete_rounding_proposal(self):
79+
"""
80+
Test that discrete variable values are automatically rounded
81+
in SMC logp functions
82+
"""
83+
84+
with pm.Model() as m:
85+
z = pm.Bernoulli("z", p=0.7)
86+
like = pm.Potential("like", z * 1.0)
87+
88+
smc = SMC(model=m)
89+
smc.initialize_population()
90+
smc.setup_kernel()
91+
smc.initialize_logp()
92+
93+
assert smc.prior_logp_func(floatX(np.array([-0.51]))) == -np.inf
94+
assert np.isclose(smc.prior_logp_func(floatX(np.array([-0.49]))), np.log(0.3))
95+
assert np.isclose(smc.prior_logp_func(floatX(np.array([0.49]))), np.log(0.3))
96+
assert np.isclose(smc.prior_logp_func(floatX(np.array([0.51]))), np.log(0.7))
97+
assert smc.prior_logp_func(floatX(np.array([1.51]))) == -np.inf
98+
99+
def test_unobserved_discrete(self):
100+
n = 10
101+
rng = self.get_random_state()
102+
103+
z_true = np.zeros(n, dtype=int)
104+
z_true[int(n / 2) :] = 1
105+
y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng)
106+
107+
with pm.Model() as m:
108+
z = pm.Bernoulli("z", p=0.5, size=n)
109+
mu = pm.math.switch(z, 1.0, -1.0)
110+
like = pm.Normal("like", mu=mu, sigma=0.25, observed=y)
111+
112+
trace = pm.sample_smc(chains=1, return_inferencedata=False)
113+
114+
assert np.all(np.median(trace["z"], axis=0) == z_true)
85115

86116
def test_ml(self):
87117
data = np.repeat([1, 0], [50, 50])
@@ -109,14 +139,6 @@ def test_start(self):
109139
}
110140
trace = pm.sample_smc(500, chains=1, start=start)
111141

112-
def test_slowdown_warning(self):
113-
with aesara.config.change_flags(floatX="float32"):
114-
with pytest.warns(UserWarning, match="SMC sampling may run slower due to"):
115-
with pm.Model() as model:
116-
a = pm.Poisson("a", 5)
117-
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
118-
trace = pm.sample_smc(draws=100, chains=2, cores=1)
119-
120142
@pytest.mark.parametrize("chains", (1, 2))
121143
def test_return_datatype(self, chains):
122144
draws = 10

0 commit comments

Comments
 (0)