Skip to content

Commit 3f2a1da

Browse files
committed
Improve collect_default_updates
* It works with nested RNGs
* It raises an error if an RNG used in a SymbolicRandomVariable is not given an update
* It raises a warning if the same RNG is used in multiple nodes
1 parent a75af50 commit 3f2a1da

File tree

3 files changed

+181
-51
lines changed

3 files changed

+181
-51
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,9 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
593593

594594
def update(self, node: Node):
595595
op = node.op
596-
inner_updates = collect_default_updates(op.inner_inputs, op.inner_outputs)
596+
inner_updates = collect_default_updates(
597+
op.inner_inputs, op.inner_outputs, must_be_shared=False
598+
)
597599

598600
# Map inner updates to outer inputs/outputs
599601
updates = {}

pymc/pytensorf.py

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
Variable,
4343
clone_get_equiv,
4444
graph_inputs,
45-
vars_between,
4645
walk,
4746
)
4847
from pytensor.graph.fg import FunctionGraph
@@ -51,6 +50,7 @@
5150
from pytensor.tensor.basic import _as_tensor_variable
5251
from pytensor.tensor.elemwise import Elemwise
5352
from pytensor.tensor.random.op import RandomVariable
53+
from pytensor.tensor.random.type import RandomType
5454
from pytensor.tensor.random.var import (
5555
RandomGeneratorSharedVariable,
5656
RandomStateSharedVariable,
@@ -1000,42 +1000,85 @@ def reseed_rngs(
10001000

10011001

10021002
def collect_default_updates(
1003-
inputs: Sequence[Variable], outputs: Sequence[Variable]
1003+
inputs: Sequence[Variable],
1004+
outputs: Sequence[Variable],
1005+
must_be_shared: bool = True,
10041006
) -> Dict[Variable, Variable]:
1005-
"""Collect default update expression of RVs between inputs and outputs"""
1007+
"""Collect default update expression for shared-variable RNGs used by RVs between inputs and outputs.
1008+
1009+
If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs.
1010+
This can be useful to obtain the symbolic update expressions from inner graphs.
1011+
"""
10061012

10071013
# Avoid circular import
10081014
from pymc.distributions.distribution import SymbolicRandomVariable
10091015

1016+
def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
1017+
rng_clients = clients.get(rng, None)
1018+
1019+
# Root case, RNG is not used elsewhere
1020+
if not rng_clients:
1021+
return rng
1022+
1023+
if len(rng_clients) > 1:
1024+
warnings.warn(
1025+
f"RNG Variable {rng} has multiple clients. This is likely an inconsistent random graph.",
1026+
UserWarning,
1027+
)
1028+
return None
1029+
1030+
[client, _] = rng_clients[0]
1031+
1032+
# RNG is an output of the function, this is not a problem
1033+
if client == "output":
1034+
return rng
1035+
1036+
# RNG is used by another operator, which should output an update for the RNG
1037+
if isinstance(client.op, RandomVariable):
1038+
# RandomVariable first output is always the update of the input RNG
1039+
next_rng = client.outputs[0]
1040+
1041+
elif isinstance(client.op, SymbolicRandomVariable):
1042+
# SymbolicRandomVariable have an explicit method that returns an
1043+
# update mapping for their RNG(s)
1044+
next_rng = client.op.update(client).get(rng)
1045+
if next_rng is None:
1046+
raise ValueError(
1047+
f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
1048+
)
1049+
else:
1050+
# We don't know how this RNG should be updated (e.g., Scan).
1051+
# The user should provide an update manually
1052+
return None
1053+
1054+
# Recurse until we find final update for RNG
1055+
return find_default_update(clients, next_rng)
1056+
1057+
outputs = makeiter(outputs)
1058+
fg = FunctionGraph(outputs=outputs, clone=False)
1059+
clients = fg.clients
1060+
10101061
rng_updates = {}
1011-
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
1012-
for random_var in (
1013-
var
1014-
for var in vars_between(inputs, output_to_list)
1015-
if var.owner
1016-
and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable))
1017-
and var not in inputs
1062+
# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
1063+
for input_rng in (
1064+
inp
1065+
for inp in graph_inputs(outputs, blockers=inputs)
1066+
if (
1067+
(not must_be_shared or isinstance(inp, SharedVariable))
1068+
and isinstance(inp.type, RandomType)
1069+
)
10181070
):
1019-
# All nodes in `vars_between(inputs, outputs)` have owners.
1020-
# But mypy doesn't know, so we just assert it:
1021-
assert random_var.owner.op is not None
1022-
if isinstance(random_var.owner.op, RandomVariable):
1023-
rng = random_var.owner.inputs[0]
1024-
if getattr(rng, "default_update", None) is not None:
1025-
update_map = {rng: rng.default_update}
1026-
else:
1027-
update_map = {rng: random_var.owner.outputs[0]}
1071+
# Even if an explicit default update is provided, we call it to
1072+
# issue any warnings about invalid random graphs.
1073+
default_update = find_default_update(clients, input_rng)
1074+
1075+
# Respect default update if provided
1076+
if getattr(input_rng, "default_update", None):
1077+
rng_updates[input_rng] = input_rng.default_update
10281078
else:
1029-
update_map = random_var.owner.op.update(random_var.owner)
1030-
# Check that we are not setting different update expressions for the same variables
1031-
for rng, update in update_map.items():
1032-
if rng not in rng_updates:
1033-
rng_updates[rng] = update
1034-
# When a variable has multiple outputs, it will be called twice with the same
1035-
# update expression. We don't want to raise in that case, only if the update
1036-
# expression in different from the one already registered
1037-
elif rng_updates[rng] is not update:
1038-
raise ValueError(f"Multiple update expressions found for the variable {rng}")
1079+
if default_update is not None:
1080+
rng_updates[input_rng] = default_update
1081+
10391082
return rng_updates
10401083

10411084

tests/test_pytensorf.py

Lines changed: 106 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from unittest import mock
1517

1618
import numpy as np
@@ -38,6 +40,7 @@
3840
from pymc.exceptions import NotConstantValueError
3941
from pymc.logprob.utils import ParameterValueError
4042
from pymc.pytensorf import (
43+
collect_default_updates,
4144
compile_pymc,
4245
constant_fold,
4346
convert_observed_data,
@@ -406,28 +409,63 @@ def test_compile_pymc_updates_inputs(self):
406409
# Each RV adds a shared output for its rng
407410
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
408411

409-
# Disable `reseed_rngs` so that we can test with simpler update rule
410-
@mock.patch("pymc.pytensorf.reseed_rngs")
411-
def test_compile_pymc_custom_update_op(self, _):
412-
"""Test that custom MeasurableVariable Op updates are used by compile_pymc"""
412+
def test_compile_pymc_symbolic_rv_update(self):
413+
"""Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""
413414

414415
class NonSymbolicRV(OpFromGraph):
415416
def update(self, node):
416-
return {node.inputs[0]: node.inputs[0] + 1}
417+
return {node.inputs[0]: node.outputs[0]}
417418

418-
dummy_inputs = [pt.scalar(), pt.scalar()]
419-
dummy_outputs = [pt.add(*dummy_inputs)]
420-
dummy_x = NonSymbolicRV(dummy_inputs, dummy_outputs)(pytensor.shared(1.0), 1.0)
419+
rng = pytensor.shared(np.random.default_rng())
420+
dummy_rng = rng.type()
421+
dummy_next_rng, dummy_x = NonSymbolicRV(
422+
[dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
423+
)(rng)
421424

422425
# Check that there are no updates at first
423426
fn = compile_pymc(inputs=[], outputs=dummy_x)
424-
assert fn() == fn() == 2.0
427+
assert fn() == fn()
425428

426429
# And they are enabled once the Op is registered as a SymbolicRV
427430
SymbolicRandomVariable.register(NonSymbolicRV)
428-
fn = compile_pymc(inputs=[], outputs=dummy_x)
429-
assert fn() == 2.0
430-
assert fn() == 3.0
431+
fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
432+
assert fn() != fn()
433+
434+
def test_compile_pymc_symbolic_rv_missing_update(self):
435+
"""Test that error is raised if SymbolicRandomVariable Op does not
436+
provide rule for updating RNG"""
437+
438+
class SymbolicRV(OpFromGraph):
439+
def update(self, node):
440+
# Update is provided for rng1 but not rng2
441+
return {node.inputs[0]: node.outputs[0]}
442+
443+
SymbolicRandomVariable.register(SymbolicRV)
444+
445+
# No problems at first, as the one RNG is given the update rule
446+
rng1 = pytensor.shared(np.random.default_rng())
447+
dummy_rng1 = rng1.type()
448+
dummy_next_rng1, dummy_x1 = SymbolicRV(
449+
[dummy_rng1],
450+
pt.random.normal(rng=dummy_rng1).owner.outputs,
451+
)(rng1)
452+
fn = compile_pymc(inputs=[], outputs=dummy_x1, random_seed=433)
453+
assert fn() != fn()
454+
455+
# Now there's a problem as there is no update rule for rng2
456+
rng2 = pytensor.shared(np.random.default_rng())
457+
dummy_rng2 = rng2.type()
458+
dummy_next_rng1, dummy_x1, dummy_next_rng2, dummy_x2 = SymbolicRV(
459+
[dummy_rng1, dummy_rng2],
460+
[
461+
*pt.random.normal(rng=dummy_rng1).owner.outputs,
462+
*pt.random.normal(rng=dummy_rng2).owner.outputs,
463+
],
464+
)(rng1, rng2)
465+
with pytest.raises(
466+
ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
467+
):
468+
compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])
431469

432470
def test_random_seed(self):
433471
seedx = pytensor.shared(np.random.default_rng(1))
@@ -457,15 +495,62 @@ def test_random_seed(self):
457495
assert y3_eval == y2_eval
458496

459497
def test_multiple_updates_same_variable(self):
460-
rng = pytensor.shared(np.random.default_rng(), name="rng")
461-
x = pt.random.normal(rng=rng)
462-
y = pt.random.normal(rng=rng)
463-
464-
assert compile_pymc([], [x])
465-
assert compile_pymc([], [y])
466-
msg = "Multiple update expressions found for the variable rng"
467-
with pytest.raises(ValueError, match=msg):
468-
compile_pymc([], [x, y])
498+
# Raise if unexpected warning is issued
499+
with warnings.catch_warnings():
500+
warnings.simplefilter("error")
501+
502+
rng = pytensor.shared(np.random.default_rng(), name="rng")
503+
x = pt.random.normal(rng=rng)
504+
y = pt.random.normal(rng=rng)
505+
506+
# No warnings if only one variable is used
507+
assert compile_pymc([], [x])
508+
assert compile_pymc([], [y])
509+
510+
user_warn_msg = "RNG Variable rng has multiple clients"
511+
with pytest.warns(UserWarning, match=user_warn_msg):
512+
f = compile_pymc([], [x, y], random_seed=456)
513+
assert f() == f()
514+
515+
# The user can provide an explicit update, but we will still issue a warning
516+
with pytest.warns(UserWarning, match=user_warn_msg):
517+
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
518+
assert f() != f()
519+
520+
# Same with default update
521+
rng.default_update = x.owner.outputs[0]
522+
with pytest.warns(UserWarning, match=user_warn_msg):
523+
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
524+
assert f() != f()
525+
526+
def test_nested_updates(self):
527+
rng = pytensor.shared(np.random.default_rng())
528+
next_rng1, x = pt.random.normal(rng=rng).owner.outputs
529+
next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
530+
next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs
531+
532+
collect_default_updates([], [x, y, z]) == {rng: next_rng3}
533+
534+
fn = compile_pymc([], [x, y, z], random_seed=514)
535+
assert not set(list(np.array(fn()))) & set(list(np.array(fn())))
536+
537+
# A local myopic rule (as PyMC used before, would not work properly)
538+
fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
539+
assert set(list(np.array(fn()))) & set(list(np.array(fn())))
540+
541+
542+
def test_collect_default_updates_must_be_shared():
543+
shared_rng = pytensor.shared(np.random.default_rng())
544+
nonshared_rng = shared_rng.type()
545+
546+
next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
547+
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs
548+
549+
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
550+
assert res == {shared_rng: next_rng_of_shared}
551+
552+
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
553+
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}
469554

470555

471556
def test_replace_rng_nodes():

0 commit comments

Comments
 (0)