diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index d706a0e442..6f48673dc9 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -61,6 +61,7 @@ from pymc.printing import str_for_dist
 from pymc.pytensorf import (
     collect_default_updates,
+    collect_default_updates_inner_fgraph,
     constant_fold,
     convert_observed_data,
     floatX,
@@ -298,16 +299,17 @@ def __init__(
             raise ValueError("ndim_supp or gufunc_signature must be provided")

         kwargs.setdefault("inline", True)
+        kwargs.setdefault("strict", True)
         super().__init__(*args, **kwargs)

-    def update(self, node: Node):
+    def update(self, node: Node) -> dict[Variable, Variable]:
         """Symbolic update expression for input random state variables

         Returns a dictionary with the symbolic expressions required for correct
         updating of random state input variables repeated function evaluations.
         This is used by `pytensorf.compile_pymc`.
         """
-        return {}
+        return collect_default_updates_inner_fgraph(node)

     def batch_ndim(self, node: Node) -> int:
         """Number of dimensions of the distribution's batch shape."""
@@ -701,24 +703,10 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
     symbolic random methods.
     """

-    default_output = -1
+    default_output = 0

     _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")

-    def update(self, node: Node):
-        op = node.op
-        inner_updates = collect_default_updates(
-            inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
-        )
-
-        # Map inner updates to outer inputs/outputs
-        updates = {}
-        for rng, update in inner_updates.items():
-            inp_idx = op.inner_inputs.index(rng)
-            out_idx = op.inner_outputs.index(update)
-            updates[node.inputs[inp_idx]] = node.outputs[out_idx]
-        return updates
-

 @_support_point.register(CustomSymbolicDistRV)
 def dist_support_point(op, rv, *args):
@@ -818,14 +806,17 @@ def rv_op(
         if logp is not None:

             @_logprob.register(rv_type)
-            def custom_dist_logp(op, values, size, *params, **kwargs):
-                return logp(values[0], *params[: len(dist_params)])
+            def custom_dist_logp(op, values, size, *inputs, **kwargs):
+                [value] = values
+                rv_params = inputs[: len(dist_params)]
+                return logp(value, *rv_params)

         if logcdf is not None:

             @_logcdf.register(rv_type)
-            def custom_dist_logcdf(op, value, size, *params, **kwargs):
-                return logcdf(value, *params[: len(dist_params)])
+            def custom_dist_logcdf(op, value, size, *inputs, **kwargs):
+                rv_params = inputs[: len(dist_params)]
+                return logcdf(value, *rv_params)

         if support_point is not None:

@@ -858,22 +849,29 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
             dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
             dummy_rv = dist(*dummy_dist_params, dummy_size_param)
             dummy_params = [dummy_size_param, *dummy_dist_params]
-            dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+            updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+            rngs = updates_dict.keys()
+            rngs_updates = updates_dict.values()

             new_rv_op = rv_type(
-                inputs=dummy_params,
-                outputs=[*dummy_updates_dict.values(), dummy_rv],
+                inputs=[*dummy_params, *rngs],
+                outputs=[dummy_rv, *rngs_updates],
                 signature=signature,
             )
-            new_rv = new_rv_op(new_size, *dist_params)
+            new_rv = new_rv_op(new_size, *dist_params, *rngs)

             return new_rv

+        # RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
+        # We retrieve them here
+        updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+        rngs = updates_dict.keys()
+        rngs_updates = updates_dict.values()

         rv_op = rv_type(
-            inputs=dummy_params,
-            outputs=[*dummy_updates_dict.values(), dummy_rv],
+            inputs=[*dummy_params, *rngs],
+            outputs=[dummy_rv, *rngs_updates],
             signature=signature,
         )
-        return rv_op(size, *dist_params)
+        return rv_op(size, *dist_params, *rngs)

     @staticmethod
     def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
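
Note: with `update` now defaulting to `collect_default_updates_inner_fgraph`, a `CustomDist` whose `dist` function is built from other random variables needs no hand-written update rule; the RNGs found on the inner graph are threaded through the resulting `CustomSymbolicDistRV` as explicit inputs and outputs. A minimal sketch of the user-facing pattern this supports (the Maxwell construction mirrors the new `test_truncated_maxwell_dist` test further down; names and values are illustrative, not part of the diff):

    import pymc as pm
    import pytensor.tensor as pt

    def maxwell_dist(scale, size):
        # Any graph built from existing random variables works here; the RNGs it
        # uses are discovered on the inner graph and wired through the resulting
        # CustomSymbolicDistRV as explicit inputs/outputs, so no update() override
        # is needed.
        return pt.sqrt(pm.ChiSquared.dist(nu=3, size=size)) * scale

    with pm.Model():
        x = pm.CustomDist("x", 5.0, dist=maxwell_dist)
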
diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 1412d3e446..4103fa556f 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -436,7 +436,6 @@ def __init__(self, *args, ar_order, constant_term, **kwargs):

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -658,13 +657,13 @@ def step(*args):
         ar_ = pt.concatenate([init_, innov_.T], axis=-1)

         ar_op = AutoRegressiveRV(
-            inputs=[rhos_, sigma_, init_, steps_],
+            inputs=[rhos_, sigma_, init_, steps_, noise_rng],
             outputs=[noise_next_rng, ar_],
             ar_order=ar_order,
             constant_term=constant_term,
         )

-        ar = ar_op(rhos, sigma, init_dist, steps)
+        ar = ar_op(rhos, sigma, init_dist, steps, noise_rng)
         return ar


@@ -731,7 +730,6 @@ class GARCH11RV(SymbolicRandomVariable):

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -797,7 +795,6 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
             # In this case the size of the init_dist depends on the parameters shape
             batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
         init_dist = change_dist_size(init_dist, batch_size)
-        # initial_vol = initial_vol * pt.ones(batch_size)

         # Create OpFromGraph representing random draws from GARCH11 process
         # Variables with underscore suffix are dummy inputs into the OpFromGraph
@@ -819,7 +816,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
         (y_t, _), innov_updates_ = pytensor.scan(
             fn=step,
-            outputs_info=[init_, initial_vol_ * pt.ones(batch_size)],
+            outputs_info=[init_, pt.broadcast_to(initial_vol_.astype("floatX"), init_.shape)],
             non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
             n_steps=steps_,
             strict=True,
         )
@@ -831,11 +828,11 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
         )

         garch11_op = GARCH11RV(
-            inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
+            inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_, noise_rng],
             outputs=[noise_next_rng, garch11_],
         )

-        garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
+        garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
         return garch11


@@ -891,14 +888,13 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
     ndim_supp = 1
     _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

-    def __init__(self, *args, dt, sde_fn, **kwargs):
+    def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs):
         self.dt = dt
         self.sde_fn = sde_fn
         super().__init__(*args, **kwargs)

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -1010,14 +1006,14 @@ def step(*prev_args):
         )

         eulermaruyama_op = EulerMaruyamaRV(
-            inputs=[init_, steps_, *sde_pars_],
+            inputs=[init_, steps_, *sde_pars_, noise_rng],
             outputs=[noise_next_rng, sde_out_],
             dt=dt,
             sde_fn=sde_fn,
             signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
         )

-        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
+        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars, noise_rng)
         return eulermaruyama
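
Note: the time-series ops above now take their noise RNG as an explicit input and return the advanced RNG state as an output, which is what keeps the simple `{node.inputs[-1]: node.outputs[0]}` update rule valid. The same pattern on a bare `SymbolicRandomVariable`, mirroring the new `test_default_update` / `test_recreate_with_different_rng_inputs` tests further down (illustrative sketch, not part of the diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pymc.distributions.distribution import SymbolicRandomVariable

    rng = pytensor.shared(np.random.default_rng())
    dummy_rng = rng.type()
    dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs

    # RNG in, next RNG state out: the default update() can then pair them up.
    op = SymbolicRandomVariable([dummy_rng], [dummy_next_rng, dummy_x], ndim_supp=0)
    next_rng, x = op(rng)
    assert op.update(x.owner) == {rng: next_rng}
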
diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py
index 2a1618348a..263e76f2e5 100644
--- a/pymc/distributions/truncated.py
+++ b/pymc/distributions/truncated.py
@@ -17,7 +17,7 @@ import pytensor
 import pytensor.tensor as pt

-from pytensor import scan
+from pytensor import config, graph_replace, scan
 from pytensor.graph import Op
 from pytensor.graph.basic import Node
 from pytensor.raise_op import CheckAndRaise
@@ -25,10 +25,12 @@ from pytensor.tensor import TensorConstant, TensorVariable
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import RandomType

 from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
+    CustomSymbolicDistRV,
     Distribution,
     SymbolicRandomVariable,
     _support_point,
@@ -38,8 +40,9 @@ from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import TruncationError
 from pymc.logprob.abstract import _logcdf, _logprob
-from pymc.logprob.basic import icdf, logcdf
+from pymc.logprob.basic import icdf, logcdf, logp
 from pymc.math import logdiffexp
+from pymc.pytensorf import collect_default_updates
 from pymc.util import check_dist_not_registered


@@ -49,11 +52,17 @@ class TruncatedRV(SymbolicRandomVariable):
     that represents a truncated univariate random variable.
     """

-    default_output = 1
-    base_rv_op = None
-    max_n_steps = None
-
-    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+    default_output: int = 0
+    base_rv_op: Op
+    max_n_steps: int
+
+    def __init__(
+        self,
+        *args,
+        base_rv_op: Op,
+        max_n_steps: int,
+        **kwargs,
+    ):
         self.base_rv_op = base_rv_op
         self.max_n_steps = max_n_steps
         self._print_name = (
@@ -63,9 +72,13 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
         super().__init__(*args, **kwargs)

     def update(self, node: Node):
-        """Return the update mapping for the noise RV."""
-        # Since RNG is a shared variable it shows up as the last node input
-        return {node.inputs[-1]: node.outputs[0]}
+        """Return the update mapping for the internal RNGs.
+
+        TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
+        """
+        rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
+        next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
+        return dict(zip(rngs, next_rngs))


 @singledispatch
@@ -142,10 +155,14 @@ class Truncated(Distribution):

     @classmethod
     def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
-        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+        if not (
+            isinstance(dist, TensorVariable)
+            and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
+        ):
             if isinstance(dist.owner.op, SymbolicRandomVariable):
                 raise NotImplementedError(
-                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
+                    f"You can try wrapping the distribution inside a CustomDist instead."
                )
             raise ValueError(
                 f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
             )

@@ -175,37 +192,40 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
         if size is None:
             size = pt.broadcast_shape(dist, lower, upper)
         dist = change_dist_size(dist, new_size=size)

+        rv_inputs = [
+            inp
+            if not isinstance(inp.type, RandomType)
+            else pytensor.shared(np.random.default_rng())
+            for inp in dist.owner.inputs
+        ]
+        graph_inputs = [*rv_inputs, lower, upper]
+
         # Variables with `_` suffix identify dummy inputs for the OpFromGraph
-        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
-        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        graph_inputs_ = [
+            inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
+        ]
         *rv_inputs_, lower_, upper_ = graph_inputs_

-        # We will use a Shared RNG variable because Scan demands it, even though it
-        # would not be necessary for the OpFromGraph inverse cdf.
-        rng = pytensor.shared(np.random.default_rng())
-        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+        rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

         # Try to use inverted cdf sampling
+        # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
         try:
-            # For left truncated discrete RVs, we need to include the whole lower bound.
-            # This may result in draws below the truncation range, if any uniform == 0
-            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
-            cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
-            cdf_upper_ = pt.exp(logcdf(rv_, upper_))
-            # It's okay to reuse the same rng here, because the rng in rv_ will not be
-            # used by either the logcdf of icdf functions
-            uniform_ = pt.random.uniform(
-                cdf_lower_,
-                cdf_upper_,
-                rng=rng,
-                size=rv_inputs_[0],
-            )
-            truncated_rv_ = icdf(rv_, uniform_)
+            logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
+            # We use the first RNG from the base RV, so we don't have to introduce a new one
+            # This is not problematic because the RNG won't be used in the RV logcdf graph
+            uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
+            uniform_next_rng_, uniform_ = pt.random.uniform(
+                pt.exp(logcdf_lower_),
+                pt.exp(logcdf_upper_),
+                rng=uniform_rng_,
+                size=rv_.shape,
+            ).owner.outputs
+            truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
             return TruncatedRV(
                 base_rv_op=dist.owner.op,
                 inputs=graph_inputs_,
-                outputs=[uniform_.owner.outputs[0], truncated_rv_],
+                outputs=[truncated_rv_, uniform_next_rng_],
                 ndim_supp=0,
                 max_n_steps=max_n_steps,
             )(*graph_inputs)
@@ -213,8 +233,13 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
             pass

         # Fallback to rejection sampling
+        # truncated_rv = zeros(rv.shape)
+        # reject_draws = ones(rv.shape, dtype=bool)
+        # while any(reject_draws):
+        #    truncated_rv[reject_draws] = draw(rv)[reject_draws]
+        #    reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
         def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
-            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+            new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
             # Avoid scalar boolean indexing
             if truncated_rv.type.ndim == 0:
                 truncated_rv = new_truncated_rv
@@ -227,7 +252,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):

             return (
                 (truncated_rv, reject_draws),
-                [(rng, next_rng)],
+                collect_default_updates(new_truncated_rv),
                 until(~pt.any(reject_draws)),
             )

@@ -237,7 +262,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
                 pt.zeros_like(rv_),
                 pt.ones_like(rv_, dtype=bool),
             ],
-            non_sequences=[lower_, upper_, rng, *rv_inputs_],
+            non_sequences=[lower_, upper_, *rv_inputs_],
             n_steps=max_n_steps,
             strict=True,
         )
@@ -247,23 +272,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
         truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
             truncated_rv_, convergence_
         )

+        # Sort updates of each RNG so that they show in the same order as the input RNGs
+
+        def sort_updates(update):
+            rng, next_rng = update
+            return graph_inputs.index(rng)
+
+        next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]

         return TruncatedRV(
             base_rv_op=dist.owner.op,
             inputs=graph_inputs_,
-            outputs=[next(iter(updates.values())), truncated_rv_],
+            outputs=[truncated_rv_, *next_rngs],
             ndim_supp=0,
             max_n_steps=max_n_steps,
         )(*graph_inputs)

+    @staticmethod
+    def _create_logcdf_exprs(
+        base_rv: TensorVariable,
+        value: TensorVariable,
+        lower: TensorVariable,
+        upper: TensorVariable,
+    ) -> tuple[TensorVariable, TensorVariable]:
+        """Create lower and upper logcdf expressions for base_rv.
+
+        Uses `value` as a template for broadcasting.
+        """
+        # For left truncated discrete RVs, we need to include the whole lower bound.
+        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
+        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
+        upper_value = pt.full_like(value, upper, dtype=config.floatX)
+        lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
+        upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
+        return lower_logcdf, upper_logcdf
+

 @_change_dist_size.register(TruncatedRV)
-def change_truncated_size(op, dist, new_size, expand):
-    *rv_inputs, lower, upper, rng = dist.owner.inputs
-    # Recreate the original untruncated RV
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
+    *rv_inputs, lower, upper = truncated_rv.owner.inputs
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+
     if expand:
-        new_size = to_tuple(new_size) + tuple(dist.shape)
+        new_size = to_tuple(new_size) + tuple(truncated_rv.shape)

     return Truncated.rv_op(
         untruncated_rv,
@@ -275,11 +326,11 @@

 @_support_point.register(TruncatedRV)
-def truncated_support_point(op, rv, *inputs):
-    *rv_inputs, lower, upper, rng = inputs
+def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
+    *rv_inputs, lower, upper = inputs

     # recreate untruncated rv and respective support_point
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
     untruncated_support_point = support_point(untruncated_rv)

     fallback_support_point = pt.switch(
@@ -300,31 +351,25 @@

 @_default_transform.register(TruncatedRV)
-def truncated_default_transform(op, rv):
+def truncated_default_transform(op, truncated_rv):
     # Don't transform discrete truncated distributions
-    if op.base_rv_op.dtype.startswith("int"):
+    if truncated_rv.type.dtype.startswith("int"):
         return None

-    # Lower and Upper are the arguments -3 and -2
-    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+    # Lower and Upper are the arguments -2 and -1
+    return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
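
Note: `Truncated.rv_op` above tries inverse-CDF sampling first and falls back to rejection sampling inside a Scan. Both strategies can be sketched in plain NumPy/SciPy for a truncated normal (illustrative values; not the PyTensor implementation):

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(1)
    loc, scale, lower, upper = 0.15, 10.0, -1.0, 1.5

    # Inverse-CDF path: truncated_rv = icdf(rv, uniform(cdf(lower), cdf(upper)))
    u = rng.uniform(stats.norm.cdf(lower, loc, scale), stats.norm.cdf(upper, loc, scale), size=1000)
    icdf_draws = stats.norm.ppf(u, loc, scale)

    # Rejection path: keep redrawing the rejected entries until all fall inside the bounds
    draws = np.zeros(1000)
    reject = np.ones(1000, dtype=bool)
    for _ in range(10_000):  # max_n_steps
        proposal = rng.normal(loc, scale, size=1000)
        draws[reject] = proposal[reject]
        reject = (draws < lower) | (draws > upper)
        if not reject.any():
            break
    else:
        raise RuntimeError("Truncation did not converge in 10000 steps")

    assert icdf_draws.min() >= lower and icdf_draws.max() <= upper
    assert draws.min() >= lower and draws.max() <= upper
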


 @_logprob.register(TruncatedRV)
 def truncated_logprob(op, values, *inputs, **kwargs):
     (value,) = values
-
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
+    *rv_inputs, lower, upper = inputs

     base_rv_op = op.base_rv_op
-    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
-    # For left truncated RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
-
+    base_rv = base_rv_op.make_node(*rv_inputs).default_output()
+    base_logp = logp(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
     if base_rv_op.name:
-        logp.name = f"{base_rv_op}_logprob"
+        base_logp.name = f"{base_rv_op}_logprob"
         lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
         upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

@@ -339,37 +384,31 @@ def truncated_logprob(op, values, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf

-    logp = logp - lognorm
+    truncated_logp = base_logp - lognorm

     if is_lower_bounded:
-        logp = pt.switch(value < lower, -np.inf, logp)
+        truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)

     if is_upper_bounded:
-        logp = pt.switch(value <= upper, logp, -np.inf)
+        truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)

     if is_lower_bounded and is_upper_bounded:
-        logp = check_parameters(
-            logp,
+        truncated_logp = check_parameters(
+            truncated_logp,
             pt.le(lower, upper),
             msg="lower_bound <= upper_bound",
         )

-    return logp
+    return truncated_logp


 @_logcdf.register(TruncatedRV)
-def truncated_logcdf(op, value, *inputs, **kwargs):
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
-
-    base_rv_op = op.base_rv_op
-    logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
+def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
+    *rv_inputs, lower, upper = inputs

-    # For left truncated discrete RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+    base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+    base_logcdf = logcdf(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)

     is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
     is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
@@ -382,7 +421,7 @@
     elif is_upper_bounded:
         lognorm = upper_logcdf

-    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
+    logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf

     logcdf_trunc = logcdf_numerator - lognorm

     if is_lower_bounded:
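
Note: `truncated_logprob` above computes `logp(x) - log(cdf(upper) - cdf(lower))`, with `-inf` outside the bounds. A quick SciPy cross-check of that normalization against `truncnorm` (illustrative values, not part of the diff):

    import numpy as np
    from scipy import stats

    loc, scale, lower, upper = 0.15, 10.0, -1.0, 1.5
    x = 0.5

    # manual: base logpdf minus the log of the probability mass inside the bounds
    lognorm = np.log(stats.norm.cdf(upper, loc, scale) - stats.norm.cdf(lower, loc, scale))
    manual = stats.norm.logpdf(x, loc, scale) - lognorm

    # reference: scipy's truncated normal with standardized bounds
    reference = stats.truncnorm.logpdf(
        x, (lower - loc) / scale, (upper - loc) / scale, loc=loc, scale=scale
    )
    np.testing.assert_allclose(manual, reference)
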
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index a68cd64d8a..f357e56348 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -23,11 +23,13 @@
 from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
+from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
 from pytensor.graph import Type, rewrite_graph
 from pytensor.graph.basic import (
     Apply,
     Constant,
+    Node,
     Variable,
     clone_get_equiv,
     graph_inputs,
@@ -781,8 +783,25 @@ def reseed_rngs(
             rng.set_value(new_rng, borrow=True)


+def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
+    """Collect default updates from node with inner fgraph."""
+    op = node.op
+    inner_updates = collect_default_updates(
+        inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
+    )
+
+    # Map inner updates to outer inputs/outputs
+    updates = {}
+    for rng, update in inner_updates.items():
+        inp_idx = op.inner_inputs.index(rng)
+        out_idx = op.inner_outputs.index(update)
+        updates[node.inputs[inp_idx]] = node.outputs[out_idx]
+
+    return updates
+
+
 def collect_default_updates(
-    outputs: Sequence[Variable],
+    outputs: Variable | Sequence[Variable],
     *,
     inputs: Sequence[Variable] | None = None,
     must_be_shared: bool = True,
@@ -874,9 +893,16 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
                     f"No update found for at least one RNG used in Scan Op {client.op}.\n"
                     "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
                 )
+        elif isinstance(client.op, OpFromGraph):
+            try:
+                next_rng = collect_default_updates_inner_fgraph(client)[rng]
+            except (ValueError, KeyError):
+                raise ValueError(
+                    f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
+                    "You can use `pytensorf.collect_default_updates` and include those updates as outputs."
+                )
         else:
-            # We don't know how this RNG should be updated (e.g., OpFromGraph).
-            # The user should provide an update manually
+            # We don't know how this RNG should be updated. The user should provide an update manually
            return None

         # Recurse until we find final update for RNG
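
Note: `collect_default_updates` is the primitive everything above relies on: it maps each RNG shared variable feeding the requested outputs to its advanced state, and `compile_pymc` installs those updates so repeated calls return fresh draws. A minimal sketch with a plain random variable (no OpFromGraph involved; illustrative only):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pymc.pytensorf import collect_default_updates, compile_pymc

    rng = pytensor.shared(np.random.default_rng())
    next_rng, x = pt.random.normal(rng=rng).owner.outputs

    # The default update pairs the shared RNG with its advanced state
    assert collect_default_updates([x]) == {rng: next_rng}

    # compile_pymc installs that update, so the RNG state advances between calls
    fn = compile_pymc([], x, random_seed=1)
    assert fn() != fn()
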
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index bb43063be9..2607a3278a 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -41,18 +41,19 @@
     CustomDist,
     CustomDistRV,
     CustomSymbolicDistRV,
+    DiracDelta,
     PartialObservedRV,
     SymbolicRandomVariable,
     _support_point,
     create_partial_observed_rv,
     support_point,
 )
-from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
+from pymc.distributions.shape_utils import change_dist_size, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.basic import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
-from pymc.pytensorf import collect_default_updates
+from pymc.pytensorf import collect_default_updates, compile_pymc
 from pymc.sampling import draw, sample
 from pymc.testing import (
     BaseTestDistributionRandom,
@@ -584,9 +585,7 @@ def custom_dist(p, sigma, size):

     def test_custom_methods(self):
         def custom_dist(mu, size):
-            if rv_size_is_none(size):
-                return mu
-            return pt.full(size, mu)
+            return DiracDelta.dist(mu, size=size)

         def custom_support_point(rv, size, mu):
             return pt.full_like(rv, mu + 1)
@@ -778,7 +777,8 @@ def test_inline(self):
         class TestSymbolicRV(SymbolicRandomVariable):
             pass

-        x = TestSymbolicRV([], [Flat.dist()], ndim_supp=0)()
+        rng = pytensor.shared(np.random.default_rng())
+        x = TestSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)

         # By default, the SymbolicRandomVariable will not be inlined. Because we did not
         # dispatch a custom logprob function it will raise next
@@ -788,9 +788,70 @@
         class TestInlinedSymbolicRV(SymbolicRandomVariable):
             inline_logprob = True

-        x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
+        x_inline = TestInlinedSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)
         assert np.isclose(logp(x_inline, 0).eval(), 0)

+    def test_default_update(self):
+        """Test SymbolicRandomVariable Op default to updates from inner graph."""
+
+        class SymbolicRVDefaultUpdates(SymbolicRandomVariable):
+            pass
+
+        class SymbolicRVCustomUpdates(SymbolicRandomVariable):
+            def update(self, node):
+                return {}
+
+        rng = pytensor.shared(np.random.default_rng())
+        dummy_rng = rng.type()
+        dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs
+
+        # Check that default updates work
+        next_rng, x = SymbolicRVDefaultUpdates(
+            inputs=[dummy_rng],
+            outputs=[dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )(rng)
+        fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
+        assert fn() != fn()
+
+        # Check that custom updates are respected, by using one that's broken
+        next_rng, x = SymbolicRVCustomUpdates(
+            inputs=[dummy_rng],
+            outputs=[dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )(rng)
+        with pytest.raises(
+            ValueError,
+            match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
+        ):
+            compile_pymc(inputs=[], outputs=x, random_seed=431)
+
+    def test_recreate_with_different_rng_inputs(self):
+        """Test that we can recreate a SymbolicRandomVariable with new RNG inputs.
+
+        Related to https://github.com/pymc-devs/pytensor/issues/473
+        """
+        rng = pytensor.shared(np.random.default_rng())
+
+        dummy_rng = rng.type()
+        dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs
+
+        op = SymbolicRandomVariable(
+            [dummy_rng],
+            [dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )
+
+        next_rng, x = op(rng)
+        assert op.update(x.owner) == {rng: next_rng}
+
+        new_rng = pytensor.shared(np.random.default_rng())
+        inputs = x.owner.inputs.copy()
+        inputs[0] = new_rng
+        # This would fail with the default OpFromGraph.__call__()
+        new_next_rng, new_x = x.owner.op(*inputs)
+        assert op.update(new_x.owner) == {new_rng: new_next_rng}
+

 def test_tag_future_warning_dist():
     # Test no unexpected warnings
diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py
index 7ce6084d8d..a7293c0ed3 100644
--- a/tests/distributions/test_mixture.py
+++ b/tests/distributions/test_mixture.py
@@ -1588,8 +1588,8 @@ def test_hurdle_negativebinomial_graph(self):
         _, nonzero_dist = self.check_hurdle_mixture_graph(dist)

         assert isinstance(nonzero_dist.owner.op.base_rv_op, NegativeBinomial)
-        assert nonzero_dist.owner.inputs[2].data == n
-        assert nonzero_dist.owner.inputs[3].data == p
+        assert nonzero_dist.owner.inputs[-4].data == n
+        assert nonzero_dist.owner.inputs[-3].data == p

     def test_hurdle_gamma_graph(self):
         psi, alpha, beta = 0.25, 3, 4
@@ -1599,8 +1599,8 @@ def test_hurdle_gamma_graph(self):
         # Under the hood it uses the shape-scale parametrization of the Gamma distribution.
         # So the second value is the reciprocal of the rate (i.e. 1 / beta)
         assert isinstance(nonzero_dist.owner.op.base_rv_op, Gamma)
-        assert nonzero_dist.owner.inputs[2].data == alpha
-        assert nonzero_dist.owner.inputs[3].eval() == 1 / beta
+        assert nonzero_dist.owner.inputs[-4].data == alpha
+        assert nonzero_dist.owner.inputs[-3].eval() == 1 / beta

     def test_hurdle_lognormal_graph(self):
         psi, mu, sigma = 0.1, 2, 2.5
@@ -1608,8 +1608,8 @@ def test_hurdle_lognormal_graph(self):
         _, nonzero_dist = self.check_hurdle_mixture_graph(dist)

         assert isinstance(nonzero_dist.owner.op.base_rv_op, LogNormal)
-        assert nonzero_dist.owner.inputs[2].data == mu
-        assert nonzero_dist.owner.inputs[3].data == sigma
+        assert nonzero_dist.owner.inputs[-4].data == mu
+        assert nonzero_dist.owner.inputs[-3].data == sigma

     @pytest.mark.parametrize(
         "dist, psi, non_psi_args",
diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py
index ccd34d173f..a451051210 100644
--- a/tests/distributions/test_truncated.py
+++ b/tests/distributions/test_truncated.py
@@ -18,13 +18,19 @@
 import scipy

 from pytensor.tensor.random.basic import GeometricRV, NormalRV
+from pytensor.tensor.random.type import RandomType

-from pymc import Censored, Model, draw, find_MAP
-from pymc.distributions.continuous import (
+from pymc import Model, draw, find_MAP
+from pymc.distributions import (
+    Censored,
+    ChiSquared,
+    CustomDist,
     Exponential,
     Gamma,
+    HalfNormal,
+    LogNormal,
+    Mixture,
     TruncatedNormal,
-    TruncatedNormalRV,
 )
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.distributions.transforms import _default_transform
@@ -59,6 +65,24 @@ class RejectionGeometricRV(GeometricRV):
 rejection_geometric = RejectionGeometricRV()


+def icdf_normal_customdist(loc, scale, name=None, size=None):
+    def dist(loc, scale, size):
+        return loc + icdf_normal(size=size) * scale
+
+    x = CustomDist.dist(loc, scale, dist=dist, size=size)
+    x.name = name
+    return x
+
+
+def rejection_normal_customdist(loc, scale, name=None, size=None):
+    def dist(loc, scale, size):
+        return loc + rejection_normal(size=size) * scale
+
+    x = CustomDist.dist(loc, scale, dist=dist, size=size)
+    x.name = name
+    return x
+
+
 @_truncated.register(IcdfNormalRV)
 @_truncated.register(RejectionNormalRV)
 @_truncated.register(IcdfGeometricRV)
@@ -94,7 +118,7 @@ def test_truncation_specialized_op(shape_info):
     else:
         raise ValueError(f"Not a valid shape_info parametrization: {shape_info}")

-    assert isinstance(xt.owner.op, TruncatedNormalRV)
+    assert isinstance(xt.owner.op, TruncatedNormal.rv_type)
     assert xt.shape.eval() == (100,)

     # Test RNG is not reused
@@ -107,10 +131,14 @@
 @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
 @pytest.mark.parametrize("op_type", ["icdf", "rejection"])
 @pytest.mark.parametrize("scalar", [True, False])
-def test_truncation_continuous_random(op_type, lower, upper, scalar):
+@pytest.mark.parametrize("custom_dist", [False, True])
+def test_truncation_continuous_random(op_type, lower, upper, scalar, custom_dist):
     loc = 0.15
     scale = 10
-    normal_op = icdf_normal if op_type == "icdf" else rejection_normal
+    if custom_dist:
+        normal_op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist
+    else:
+        normal_op = icdf_normal if op_type == "icdf" else rejection_normal
     x = normal_op(loc, scale, name="x", size=() if scalar else (100,))
     xt = Truncated.dist(x, lower=lower, upper=upper)
@@ -145,10 +173,14 @@
@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) -def test_truncation_continuous_logp(op_type, lower, upper): +@pytest.mark.parametrize("custom_dist", [False, True]) +def test_truncation_continuous_logp(op_type, lower, upper, custom_dist): loc = 0.15 scale = 10 - op = icdf_normal if op_type == "icdf" else rejection_normal + if custom_dist: + op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist + else: + op = icdf_normal if op_type == "icdf" else rejection_normal x = op(loc, scale, name="x") xt = Truncated.dist(x, lower=lower, upper=upper) @@ -173,10 +205,14 @@ def test_truncation_continuous_logp(op_type, lower, upper): @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) -def test_truncation_continuous_logcdf(op_type, lower, upper): +@pytest.mark.parametrize("custom_dist", [False, True]) +def test_truncation_continuous_logcdf(op_type, lower, upper, custom_dist): loc = 0.15 scale = 10 - op = icdf_normal if op_type == "icdf" else rejection_normal + if custom_dist: + op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist + else: + op = icdf_normal if op_type == "icdf" else rejection_normal x = op(loc, scale, name="x") xt = Truncated.dist(x, lower=lower, upper=upper) @@ -481,3 +517,59 @@ def test_vectorized_bounds(): xs_logp, xs_sym_logp, ) + + +def test_truncated_multiple_rngs(): + def mix_dist_fn(size): + return Mixture.dist( + w=[0.3, 0.7], comp_dists=[HalfNormal.dist(), LogNormal.dist()], shape=size + ) + + upper = 0.1 + x = CustomDist.dist(dist=mix_dist_fn) + x_trunc = Truncated.dist(x, lower=0, upper=upper, shape=(5,)) + + # Mixture doesn't have an icdf method, so TruncatedRV uses a RejectionSampling representation + # Check that RNGs updates are correct + # TODO: Find out way of testing updates were not mixed + rngs = [inp for inp in x_trunc.owner.inputs if isinstance(inp.type, RandomType)] + next_rngs = [out for out in x_trunc.owner.outputs if isinstance(out.type, RandomType)] + assert len(set(rngs)) == len(set(next_rngs)) == 3 + + draws1 = draw(x_trunc, random_seed=1) + draws2 = draw(x_trunc, random_seed=1) + draws3 = draw(x_trunc, random_seed=2) + assert np.unique(draws1).size == 5 + assert np.unique(draws3).size == 5 + assert np.all(draws1 == draws2) + assert np.all(draws1 != draws3) + + test_x = np.array([-1, 0, 1, 2, 3]) + mix_rv = mix_dist_fn((5,)) + expected_logp = logp(mix_rv, test_x) - logcdf(mix_rv, upper) + expected_logp = pt.where(test_x <= upper, expected_logp, -np.inf) + np.testing.assert_allclose( + logp(x_trunc, test_x).eval(), + expected_logp.eval(), + ) + + +def test_truncated_maxwell_dist(): + def maxwell_dist(scale, size): + return pt.sqrt(ChiSquared.dist(nu=3, size=size)) * scale + + scale = 5.0 + upper = 2.0 + x = CustomDist.dist(scale, dist=maxwell_dist) + trunc_x = Truncated.dist(x, lower=None, upper=upper, size=(5,)) + assert np.all(draw(trunc_x, draws=20) < 2) + + test_value = np.array([-0.5, 0.0, 0.5, 1.5, 2.5]) + expected_logp = scipy.stats.maxwell.logpdf( + test_value, scale=scale + ) - scipy.stats.maxwell.logcdf(upper, scale=scale) + expected_logp[(test_value <= 0) | (test_value > upper)] = -np.inf + np.testing.assert_allclose( + logp(trunc_x, test_value).eval(), + expected_logp, + ) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 9dcaaf94c3..bb294668eb 100644 --- a/tests/test_pytensorf.py +++ 
+++ b/tests/test_pytensorf.py
@@ -408,28 +408,6 @@ def test_compile_pymc_updates_inputs(self):
         # Each RV adds a shared output for its rng
         assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

-    def test_compile_pymc_symbolic_rv_update(self):
-        """Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""
-
-        class NonSymbolicRV(OpFromGraph):
-            def update(self, node):
-                return {node.inputs[0]: node.outputs[0]}
-
-        rng = pytensor.shared(np.random.default_rng())
-        dummy_rng = rng.type()
-        dummy_next_rng, dummy_x = NonSymbolicRV(
-            [dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
-        )(rng)
-
-        # Check that there are no updates at first
-        fn = compile_pymc(inputs=[], outputs=dummy_x)
-        assert fn() == fn()
-
-        # And they are enabled once the Op is registered as a SymbolicRV
-        SymbolicRandomVariable.register(NonSymbolicRV)
-        fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
-        assert fn() != fn()
-
     def test_compile_pymc_symbolic_rv_missing_update(self):
         """Test that error is raised if SymbolicRandomVariable Op does not provide rule for updating RNG"""
@@ -588,6 +566,22 @@ def step_wo_update(x, rng):
         fn = compile_pymc([], ys, random_seed=1)
         assert not (set(fn()) & set(fn()))

+    def test_op_from_graph_updates(self):
+        rng = pytensor.shared(np.random.default_rng())
+        next_rng_, x_ = pt.random.normal(size=(10,), rng=rng).owner.outputs
+
+        x = OpFromGraph([], [x_])()
+        with pytest.raises(
+            ValueError,
+            match="No update found for at least one RNG used in OpFromGraph Op",
+        ):
+            collect_default_updates([x])
+
+        next_rng, x = OpFromGraph([], [next_rng_, x_])()
+        assert collect_default_updates([x]) == {rng: next_rng}
+        fn = compile_pymc([], x, random_seed=1)
+        assert not (set(fn()) & set(fn()))
+

 def test_replace_rng_nodes():
     rng = pytensor.shared(np.random.default_rng())