Skip to content

Commit 11f08db

Browse files
committed
Group compile_pymc tests in own class
1 parent c434469 commit 11f08db

File tree

1 file changed

+85
-88
lines changed

1 file changed

+85
-88
lines changed

pymc/tests/test_aesaraf.py

Lines changed: 85 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -596,91 +596,88 @@ def test_rvs_to_value_vars_nested():
596596
assert equal_computations(before, after)
597597

598598

599-
def test_check_bounds_flag():
600-
"""Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
601-
logp = at.ones(3)
602-
cond = np.array([1, 0, 1])
603-
bound = check_parameters(logp, cond)
604-
605-
with pm.Model() as m:
606-
pass
607-
608-
with pytest.raises(ParameterValueError):
609-
aesara.function([], bound)()
610-
611-
m.check_bounds = False
612-
with m:
613-
assert np.all(compile_pymc([], bound)() == 1)
614-
615-
m.check_bounds = True
616-
with m:
617-
assert np.all(compile_pymc([], bound)() == -np.inf)
618-
619-
620-
def test_compile_pymc_sets_rng_updates():
621-
rng = aesara.shared(np.random.default_rng(0))
622-
x = pm.Normal.dist(rng=rng)
623-
assert x.owner.inputs[0] is rng
624-
f = compile_pymc([], x)
625-
assert not np.isclose(f(), f())
626-
627-
# Check that update was not done inplace
628-
assert not hasattr(rng, "default_update")
629-
f = aesara.function([], x)
630-
assert f() == f()
631-
632-
633-
def test_compile_pymc_with_updates():
634-
x = aesara.shared(0)
635-
f = compile_pymc([], x, updates={x: x + 1})
636-
assert f() == 0
637-
assert f() == 1
638-
639-
640-
def test_compile_pymc_missing_default_explicit_updates():
641-
rng = aesara.shared(np.random.default_rng(0))
642-
x = pm.Normal.dist(rng=rng)
643-
644-
# By default, compile_pymc should update the rng of x
645-
f = compile_pymc([], x)
646-
assert f() != f()
647-
648-
# An explicit update should override the default_update, like aesara.function does
649-
# For testing purposes, we use an update that leaves the rng unchanged
650-
f = compile_pymc([], x, updates={rng: rng})
651-
assert f() == f()
652-
653-
# If we specify a custom default_update directly it should use that instead.
654-
rng.default_update = rng
655-
f = compile_pymc([], x)
656-
assert f() == f()
657-
658-
# And again, it should be overridden by an explicit update
659-
f = compile_pymc([], x, updates={rng: x.owner.outputs[0]})
660-
assert f() != f()
661-
662-
663-
def test_compile_pymc_updates_inputs():
664-
"""Test that compile_pymc does not include rngs updates of variables that are inputs
665-
or ancestors to inputs
666-
"""
667-
x = at.random.normal()
668-
y = at.random.normal(x)
669-
z = at.random.normal(y)
670-
671-
for inputs, rvs_in_graph in (
672-
([], 3),
673-
([x], 2),
674-
([y], 1),
675-
([z], 0),
676-
([x, y], 1),
677-
([x, y, z], 0),
678-
):
679-
fn = compile_pymc(inputs, z, on_unused_input="ignore")
680-
fn_fgraph = fn.maker.fgraph
681-
# Each RV adds a shared input for its rng
682-
assert len(fn_fgraph.inputs) == len(inputs) + rvs_in_graph
683-
# If the output is an input, the graph has a DeepCopyOp
684-
assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
685-
# Each RV adds a shared output for its rng
686-
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
599+
class TestCompilePyMC:
600+
def test_check_bounds_flag(self):
601+
"""Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
602+
logp = at.ones(3)
603+
cond = np.array([1, 0, 1])
604+
bound = check_parameters(logp, cond)
605+
606+
with pm.Model() as m:
607+
pass
608+
609+
with pytest.raises(ParameterValueError):
610+
aesara.function([], bound)()
611+
612+
m.check_bounds = False
613+
with m:
614+
assert np.all(compile_pymc([], bound)() == 1)
615+
616+
m.check_bounds = True
617+
with m:
618+
assert np.all(compile_pymc([], bound)() == -np.inf)
619+
620+
def test_compile_pymc_sets_rng_updates(self):
621+
rng = aesara.shared(np.random.default_rng(0))
622+
x = pm.Normal.dist(rng=rng)
623+
assert x.owner.inputs[0] is rng
624+
f = compile_pymc([], x)
625+
assert not np.isclose(f(), f())
626+
627+
# Check that update was not done inplace
628+
assert not hasattr(rng, "default_update")
629+
f = aesara.function([], x)
630+
assert f() == f()
631+
632+
def test_compile_pymc_with_updates(self):
633+
x = aesara.shared(0)
634+
f = compile_pymc([], x, updates={x: x + 1})
635+
assert f() == 0
636+
assert f() == 1
637+
638+
def test_compile_pymc_missing_default_explicit_updates(self):
639+
rng = aesara.shared(np.random.default_rng(0))
640+
x = pm.Normal.dist(rng=rng)
641+
642+
# By default, compile_pymc should update the rng of x
643+
f = compile_pymc([], x)
644+
assert f() != f()
645+
646+
# An explicit update should override the default_update, like aesara.function does
647+
# For testing purposes, we use an update that leaves the rng unchanged
648+
f = compile_pymc([], x, updates={rng: rng})
649+
assert f() == f()
650+
651+
# If we specify a custom default_update directly it should use that instead.
652+
rng.default_update = rng
653+
f = compile_pymc([], x)
654+
assert f() == f()
655+
656+
# And again, it should be overridden by an explicit update
657+
f = compile_pymc([], x, updates={rng: x.owner.outputs[0]})
658+
assert f() != f()
659+
660+
def test_compile_pymc_updates_inputs(self):
661+
"""Test that compile_pymc does not include rngs updates of variables that are inputs
662+
or ancestors to inputs
663+
"""
664+
x = at.random.normal()
665+
y = at.random.normal(x)
666+
z = at.random.normal(y)
667+
668+
for inputs, rvs_in_graph in (
669+
([], 3),
670+
([x], 2),
671+
([y], 1),
672+
([z], 0),
673+
([x, y], 1),
674+
([x, y, z], 0),
675+
):
676+
fn = compile_pymc(inputs, z, on_unused_input="ignore")
677+
fn_fgraph = fn.maker.fgraph
678+
# Each RV adds a shared input for its rng
679+
assert len(fn_fgraph.inputs) == len(inputs) + rvs_in_graph
680+
# If the output is an input, the graph has a DeepCopyOp
681+
assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
682+
# Each RV adds a shared output for its rng
683+
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

0 commit comments

Comments
 (0)