Skip to content

Commit 343869a

Browse files
committed
Do not create temporary SymbolicDistribution just to retrieve number of RNGs needed
Reordered methods for consistency
1 parent 21c2e6c commit 343869a

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

pymc/distributions/censored.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ def dist(cls, dist, lower, upper, **kwargs):
7676
check_dist_not_registered(dist)
7777
return super().dist([dist, lower, upper], **kwargs)
7878

79+
@classmethod
80+
def num_rngs(cls, *args, **kwargs):
81+
return 1
82+
83+
@classmethod
84+
def ndim_supp(cls, *dist_params):
85+
return 0
86+
7987
@classmethod
8088
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
8189

@@ -96,24 +104,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
96104
rv_out.tag.upper = upper
97105

98106
if rngs is not None:
99-
rv_out = cls.change_rngs(rv_out, rngs)
107+
rv_out = cls._change_rngs(rv_out, rngs)
100108

101109
return rv_out
102110

103111
@classmethod
104-
def ndim_supp(cls, *dist_params):
105-
return 0
106-
107-
@classmethod
108-
def change_size(cls, rv, new_size, expand=False):
109-
dist = rv.tag.dist
110-
lower = rv.tag.lower
111-
upper = rv.tag.upper
112-
new_dist = change_rv_size(dist, new_size, expand=expand)
113-
return cls.rv_op(new_dist, lower, upper)
114-
115-
@classmethod
116-
def change_rngs(cls, rv, new_rngs):
112+
def _change_rngs(cls, rv, new_rngs):
117113
(new_rng,) = new_rngs
118114
dist_node = rv.tag.dist.owner
119115
lower = rv.tag.lower
@@ -123,8 +119,12 @@ def change_rngs(cls, rv, new_rngs):
123119
return cls.rv_op(new_dist, lower, upper)
124120

125121
@classmethod
126-
def graph_rvs(cls, rv):
127-
return (rv.tag.dist,)
122+
def change_size(cls, rv, new_size, expand=False):
123+
dist = rv.tag.dist
124+
lower = rv.tag.lower
125+
upper = rv.tag.upper
126+
new_dist = change_rv_size(dist, new_size, expand=expand)
127+
return cls.rv_op(new_dist, lower, upper)
128128

129129

130130
@_moment.register(Clip)

pymc/distributions/distribution.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -396,18 +396,19 @@ def __new__(
396396
to a canonical parametrization. It should call `super().dist()`, passing a
397397
list with the default parameters as the first and only non keyword argument,
398398
followed by other keyword arguments like size and rngs, and return the result
399-
cls.rv_op
400-
Returns a TensorVariable that represents the symbolic distribution
401-
parametrized by a default set of parameters and a size and rngs arguments
399+
cls.num_rngs
400+
Returns the number of rngs given the same arguments passed by the user when
401+
calling the distribution
402402
cls.ndim_supp
403-
Returns the support of the symbolic distribution, given the default
403+
Returns the support of the symbolic distribution, given the default set of
404404
parameters. This may not always be constant, for instance if the symbolic
405405
distribution can be defined based on an arbitrary base distribution.
406+
cls.rv_op
407+
Returns a TensorVariable that represents the symbolic distribution
408+
parametrized by a default set of parameters and a size and rngs arguments
406409
cls.change_size
407410
Returns an equivalent symbolic distribution with a different size. This is
408411
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
409-
cls.graph_rvs
410-
Returns base RVs in a symbolic distribution.
411412
412413
Parameters
413414
----------
@@ -465,9 +466,9 @@ def __new__(
465466
raise TypeError(f"Name needs to be a string but got: {name}")
466467

467468
if rngs is None:
468-
# Create a temporary rv to obtain number of rngs needed
469-
temp_graph = cls.dist(*args, rngs=None, **kwargs)
470-
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
469+
# Instead of passing individual RNG variables we could pass a RandomStream
470+
# and let the classes create as many RNGs as they need
471+
rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))]
471472
elif not isinstance(rngs, (list, tuple)):
472473
rngs = [rngs]
473474

pymc/distributions/mixture.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,18 @@ def dist(cls, w, comp_dists, **kwargs):
205205
w = at.as_tensor_variable(w)
206206
return super().dist([w, *comp_dists], **kwargs)
207207

208+
@classmethod
209+
def num_rngs(cls, w, comp_dists, **kwargs):
210+
if not isinstance(comp_dists, (tuple, list)):
211+
# comp_dists is a single component
212+
comp_dists = [comp_dists]
213+
return len(comp_dists) + 1
214+
215+
@classmethod
216+
def ndim_supp(cls, weights, *components):
217+
# We already checked that all components have the same support dimensionality
218+
return components[0].owner.op.ndim_supp
219+
208220
@classmethod
209221
def rv_op(cls, weights, *components, size=None, rngs=None):
210222
# Update rngs if provided
@@ -329,11 +341,6 @@ def _resize_components(cls, size, *components):
329341

330342
return [change_rv_size(component, size) for component in components]
331343

332-
@classmethod
333-
def ndim_supp(cls, weights, *components):
334-
# We already checked that all components have the same support dimensionality
335-
return components[0].owner.op.ndim_supp
336-
337344
@classmethod
338345
def change_size(cls, rv, new_size, expand=False):
339346
weights = rv.tag.weights
@@ -355,14 +362,6 @@ def change_size(cls, rv, new_size, expand=False):
355362

356363
return cls.rv_op(weights, *components, rngs=rngs, size=None)
357364

358-
@classmethod
359-
def graph_rvs(cls, rv):
360-
# We return rv, which is itself a pseudo RandomVariable, that contains a
361-
# mix_indexes_ RV in its inner graph. We want super().dist() to generate
362-
# (components + 1) rngs for us, and it will do so based on how many elements
363-
# we return here
364-
return (*rv.tag.components, rv)
365-
366365

367366
@_get_measurable_outputs.register(MarginalMixtureRV)
368367
def _get_measurable_outputs_MarginalMixtureRV(op, node):

0 commit comments

Comments
 (0)