Skip to content

Commit a4ea9fc

Browse files
committed
compile_pymc handles default updates for duplicated nodes
Fixes bug in VI with multiple Minibatch variables, which occurred due to separate calls to model.logp (from model.datalogp and model.varlogp) that create distinct clones of the RandomIntegersRV underlying minibatch slicing. `compile_pymc` would not set any updates in this case
1 parent 8fe3833 commit a4ea9fc

File tree

3 files changed

+77
-29
lines changed

3 files changed

+77
-29
lines changed

pymc/pytensorf.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Constant,
3333
Variable,
3434
clone_get_equiv,
35+
equal_computations,
3536
graph_inputs,
3637
walk,
3738
)
@@ -872,8 +873,23 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
872873
return rng
873874

874875
if len(rng_clients) > 1:
876+
# Multiple clients are techincally fine if they are used in identical operations
877+
# We check if the default_update of each client would be the same
878+
update, *other_updates = (
879+
find_default_update(
880+
# Pass version of clients that includes only one the RNG clients at a time
881+
clients | {rng: [rng_client]},
882+
rng,
883+
)
884+
for rng_client in rng_clients
885+
)
886+
if all(equal_computations([update], [other_update]) for other_update in other_updates):
887+
return update
888+
875889
warnings.warn(
876-
f"RNG Variable {rng} has multiple clients. This is likely an inconsistent random graph.",
890+
f"RNG Variable {rng} has multiple distinct clients {rng_clients}, "
891+
f"likely due to an inconsistent random graph. "
892+
f"No default update will be returned.",
877893
UserWarning,
878894
)
879895
return None

tests/test_pytensorf.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
1514

1615
import numpy as np
1716
import numpy.ma as ma
@@ -486,34 +485,45 @@ def test_random_seed(self):
486485
assert x3_eval == x2_eval
487486
assert y3_eval == y2_eval
488487

488+
@pytest.mark.filterwarnings("error") # This is part of the test
489489
def test_multiple_updates_same_variable(self):
490-
# Raise if unexpected warning is issued
491-
with warnings.catch_warnings():
492-
warnings.simplefilter("error")
493-
494-
rng = pytensor.shared(np.random.default_rng(), name="rng")
495-
x = pt.random.normal(rng=rng)
496-
y = pt.random.normal(rng=rng)
497-
498-
# No warnings if only one variable is used
499-
assert compile_pymc([], [x])
500-
assert compile_pymc([], [y])
501-
502-
user_warn_msg = "RNG Variable rng has multiple clients"
503-
with pytest.warns(UserWarning, match=user_warn_msg):
504-
f = compile_pymc([], [x, y], random_seed=456)
505-
assert f() == f()
506-
507-
# The user can provide an explicit update, but we will still issue a warning
508-
with pytest.warns(UserWarning, match=user_warn_msg):
509-
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
510-
assert f() != f()
511-
512-
# Same with default update
513-
rng.default_update = x.owner.outputs[0]
514-
with pytest.warns(UserWarning, match=user_warn_msg):
515-
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
516-
assert f() != f()
490+
rng = pytensor.shared(np.random.default_rng(), name="rng")
491+
x = pt.random.normal(0, rng=rng)
492+
y = pt.random.normal(1, rng=rng)
493+
494+
# No warnings if only one variable is used
495+
assert compile_pymc([], [x])
496+
assert compile_pymc([], [y])
497+
498+
user_warn_msg = "RNG Variable rng has multiple distinct clients"
499+
with pytest.warns(UserWarning, match=user_warn_msg):
500+
f = compile_pymc([], [x, y], random_seed=456)
501+
assert f() == f()
502+
503+
# The user can provide an explicit update, but we will still issue a warning
504+
with pytest.warns(UserWarning, match=user_warn_msg):
505+
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
506+
assert f() != f()
507+
508+
# Same with default update
509+
rng.default_update = x.owner.outputs[0]
510+
with pytest.warns(UserWarning, match=user_warn_msg):
511+
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
512+
assert f() != f()
513+
514+
@pytest.mark.filterwarnings("error") # This is part of the test
515+
def test_duplicated_client_nodes(self):
516+
"""Test compile_pymc can handle duplicated (mergeable) RV updates."""
517+
rng = pytensor.shared(np.random.default_rng(1))
518+
x = pt.random.normal(rng=rng)
519+
y = x.owner.clone().default_output()
520+
521+
fn = compile_pymc([], [x, y], random_seed=1)
522+
res_x1, res_y1 = fn()
523+
assert res_x1 == res_y1
524+
res_x2, res_y2 = fn()
525+
assert res_x2 == res_y2
526+
assert res_x1 != res_x2
517527

518528
def test_nested_updates(self):
519529
rng = pytensor.shared(np.random.default_rng())

tests/variational/test_inference.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,25 @@ def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
472472
hierarchical_model_data["group_coords"].keys()
473473
)
474474
assert data["mu"].shape == tuple()
475+
476+
477+
def test_multiple_minibatch_variables():
478+
"""Regression test for bug reported in
479+
https://discourse.pymc.io/t/verifying-that-minibatch-is-actually-randomly-sampling/14308
480+
"""
481+
true_weights = np.array([-5, 5] * 5)
482+
feature = np.repeat(np.eye(10), 10_000, axis=0)
483+
y = feature @ true_weights
484+
485+
with pm.Model() as model:
486+
minibatch_feature, minibatch_y = pm.Minibatch(feature, y, batch_size=1)
487+
weights = pm.Normal("weights", 0, 10, shape=10)
488+
pm.Normal(
489+
"y",
490+
mu=minibatch_feature @ weights,
491+
sigma=0.01,
492+
observed=minibatch_y,
493+
total_size=len(y),
494+
)
495+
mean_field = pm.fit(10_000, obj_optimizer=pm.adam(learning_rate=0.01), progressbar=False)
496+
np.testing.assert_allclose(mean_field.mean.get_value(), true_weights, rtol=1e-1)

0 commit comments

Comments
 (0)