Skip to content

Commit 4a381f1

Browse files
committed
Use temporary subclasses to test Interpolated distribution
1 parent c520245 commit 4a381f1

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

pymc3/tests/test_distributions.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -793,14 +793,29 @@ def test_multidimensional_beta_construction(self):
793793
with Model():
794794
Beta('beta', alpha=1., beta=1., shape=(10, 20))
795795

796-
@pytest.mark.skip('Interpolated logp testing is not implemented yet')
797796
def test_interpolated(self):
798797
for mu in R.vals:
799798
for sd in Rplus.vals:
800-
with Model() as model:
801-
x_points = np.linspace(mu - 5 * sd, mu + 5 * sd, 100)
802-
pdf_points = sp.norm.pdf(x_points, loc=mu, scale=sd)
803-
dist = Interpolated('dist', x_points=x_points, pdf_points=pdf_points)
804-
# TODO evalute logp at x_points somehow
805-
# pdf_output_points = <evaluated logp>
806-
# assert_almost_equal(pdf_output_points, pdf_points, decimal=3)
799+
#pylint: disable=cell-var-from-loop
800+
xmin = mu - 5 * sd
801+
xmax = mu + 5 * sd
802+
803+
class TestedInterpolated (Interpolated):
804+
805+
def __init__(self, **kwargs):
806+
x_points = np.linspace(xmin, xmax, 100000)
807+
pdf_points = sp.norm.pdf(x_points, loc=mu, scale=sd)
808+
super(TestedInterpolated, self).__init__(
809+
x_points=x_points,
810+
pdf_points=pdf_points,
811+
**kwargs
812+
)
813+
814+
def ref_pdf(value):
815+
return np.where(
816+
np.logical_and(value >= xmin, value <= xmax),
817+
sp.norm.logpdf(value, mu, sd),
818+
-np.inf * np.ones(value.shape)
819+
)
820+
821+
self.pymc3_matches_scipy(TestedInterpolated, R, {}, ref_pdf)

pymc3/tests/test_distributions_random.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -578,27 +578,24 @@ def ref_rand(size, mu, beta):
578578
pymc3_random(pm.Gumbel, {'mu': R, 'beta': Rplus}, ref_rand=ref_rand)
579579

580580
def test_interpolated(self):
581-
alpha = 0.05
582-
size = 10000
583-
tries = 10
584-
585-
# Interpolated model doesn't support variables as inputs, so it is necessary
586-
# to use custom code instead of pymc3_random for it
587581
for mu in R.vals:
588582
for sd in Rplus.vals:
589-
with pm.Model():
590-
x_points = np.linspace(mu - 5 * sd, mu + 5 * sd, 100)
591-
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sd)
592-
dist = pm.Interpolated('dist', x_points=x_points, pdf_points=pdf_points)
593-
594-
p = alpha
595-
f = tries
596-
while p <= alpha and f > 0:
597-
s0 = dist.random(size=size)
598-
s1 = st.norm.rvs(size=size, loc=mu, scale=sd)
599-
_, p = st.ks_2samp(s0, s1)
600-
f -= 1
601-
assert p > alpha, 'Failed KS test for mu = %s, sd = %s' % (mu, sd)
583+
#pylint: disable=cell-var-from-loop
584+
def ref_rand(size):
585+
return st.norm.rvs(loc=mu, scale=sd, size=size)
586+
587+
class TestedInterpolated (pm.Interpolated):
588+
589+
def __init__(self, **kwargs):
590+
x_points = np.linspace(mu - 5 * sd, mu + 5 * sd, 100)
591+
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sd)
592+
super(TestedInterpolated, self).__init__(
593+
x_points=x_points,
594+
pdf_points=pdf_points,
595+
**kwargs
596+
)
597+
598+
pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand)
602599

603600
@pytest.mark.skip('Wishart random sampling not implemented.\n'
604601
'See https://github.com/pymc-devs/pymc3/issues/538')

0 commit comments

Comments
 (0)