Skip to content

Commit 43c5a8e

Browse files
committed
Remove duplicate call to collect_default_updates
1 parent a197b19 commit 43c5a8e

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,8 @@ def rv_op(
930930
):
931931
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
932932
dummy_params = [dummy_size_param, *dummy_dist_params]
933+
# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
934+
# We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op
933935
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
934936

935937
rv_type = type(
@@ -1001,12 +1003,7 @@ def change_custom_dist_size(op, rv, new_size, expand):
10011003

10021004
return new_rv
10031005

1004-
# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
1005-
# We retrieve them here
1006-
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
1007-
rngs = updates_dict.keys()
1008-
rngs_updates = updates_dict.values()
1009-
1006+
rngs, rngs_updates = zip(*dummy_updates_dict.items())
10101007
inputs = [*dummy_params, *rngs]
10111008
outputs = [dummy_rv, *rngs_updates]
10121009
signature = cls._infer_final_signature(

0 commit comments

Comments
 (0)