# Copyright 2025 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytensor.tensor as pt

from pymc.distributions.continuous import Continuous
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.transforms import _default_transform
from pymc.logprob.transforms import Transform
from pymc.pytensorf import normalize_rng_param
from pytensor.tensor import get_underlying_scalar_constant_value
from pytensor.tensor.random.utils import normalize_size_param

__all__ = ["NormalSingularValues"]


# TODO: this is a lot of work just to get a list of normally distributed singular values
class NormalSingularValuesRV(SymbolicRandomVariable):
    name = "normalsingularvalues"
    extended_signature = "[rng],[size],(),(m)->[rng],(m)"  # TODO: check if this is correct
    _print_name = ("NormalSingularValuesRV", "\\operatorname{NormalSingularValuesRV}")

    def make_node(self, rng, size, n, m):
        n = pt.as_tensor_variable(n)
        m = pt.as_tensor_variable(m)
        # Scalars are 0-d tensors, whose broadcastable pattern is all-True
        if not all(n.type.broadcastable) or not all(m.type.broadcastable):
            raise ValueError("n and m must be scalars.")

        return super().make_node(rng, size, n, m)

    @classmethod
    def rv_op(cls, n: int, m: int, *, rng=None, size=None):
        # Normalize the parameters into scalar integer tensors
        n = pt.as_tensor(n, ndim=0, dtype=int)
        m = pt.as_tensor(m, ndim=0, dtype=int)

        rng = normalize_rng_param(rng)
        size = normalize_size_param(size)

        # n and m must be constants so the draw shape is known at graph-build time.
        # TODO: currently assumes size = 1; fix this once everything is working
        D = get_underlying_scalar_constant_value(n)
        Q = get_underlying_scalar_constant_value(m)

        # Direct computation: draw a (batch of) D x Q standard normal matrix
        # and take its singular values
        sz = [] if rv_size_is_none(size) else size
        next_rng, z = pt.random.normal(0, 1, size=(*sz, D, Q), rng=rng).owner.outputs
        _, samples, _ = pt.linalg.svd(z)

        return cls(
            inputs=[rng, size, n, m],
            outputs=[next_rng, samples],
        )(rng, size, n, m)
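

# Sanity-check sketch (not part of the module): the draw above is equivalent
# to taking the singular values of an n x m standard normal matrix, e.g. with
# NumPy:
#
#     z = np.random.default_rng(0).standard_normal((n, m))
#     sigma = np.linalg.svd(z, compute_uv=False)  # positive, sorted descending
#
# This positive, descending support is what the PosRevOrdered transform below
# encodes.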


# This is adapted from the ordered transform.
# It might make sense to make that transform more generic by
# allowing it to take parameters "positive" and "ascending",
# and then just use it here.
class PosRevOrdered(Transform):
    name = "posrevordered"

    def __init__(self, ndim_supp=None):
        pass

    def backward(self, value, *inputs):
        return pt.cumsum(pt.exp(value[..., ::-1]), axis=-1)[..., ::-1]

    def forward(self, value, *inputs):
        y = pt.zeros(value.shape)
        y = pt.set_subtensor(y[..., -1], pt.log(value[..., -1]))
        y = pt.set_subtensor(y[..., :-1], pt.log(value[..., :-1] - value[..., 1:]))
        return y

    def log_jac_det(self, value, *inputs):
        return pt.sum(value, axis=-1)
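

# Worked example (a sketch): for v = [a, b], `backward` gives
#     x = [exp(a) + exp(b), exp(b)],
# which is positive and strictly decreasing; `forward` recovers [a, b] since
#     log(x[0] - x[1]) = a and log(x[1]) = b.
# The Jacobian of `backward` is triangular with diagonal exp(v), so
#     log|det J| = a + b = sum(v),
# which matches `log_jac_det` above.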


class NormalSingularValues(Continuous):
    rv_type = NormalSingularValuesRV
    rv_op = NormalSingularValuesRV.rv_op

    @classmethod
    def dist(cls, n, m, **kwargs):
        n = pt.as_tensor_variable(n).astype(int)
        m = pt.as_tensor_variable(m).astype(int)
        return super().dist([n, m], **kwargs)

    def support_point(rv, *args):
        # A positive, strictly decreasing vector inside the support
        return pt.linspace(1, 0.5, rv.shape[-1])

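    # The density implemented below is, up to a normalizing constant, the joint
    # density of the singular values of an n x m (n >= m) matrix with iid
    # standard normal entries (real case):
    #
    #     p(sigma) = C * prod_i[ exp(-sigma_i**2 / 2) * sigma_i**(n - m) ]
    #                  * prod_{i<j}[ |sigma_i**2 - sigma_j**2| ]
    #
    # for sigma_1 > sigma_2 > ... > sigma_m > 0.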
    def logp(sigma, n, m):
        # First term: prod[exp(-0.5*sigma**2)]
        log_p = -0.5 * pt.sum(sigma**2)

        # Second and fourth terms (ignoring constant factors), with D = n, Q = m:
        # prod(sigma**(D-Q-1)) * prod(2*sigma) = prod(2*sigma**(D-Q))
        log_p += (n - m) * pt.sum(pt.log(sigma))

        # Third term: prod_{i<j}[ |sigma_i**2 - sigma_j**2| ]. Adding the identity
        # makes each diagonal entry contribute log(1 + 1e-6) ~ 0; every off-diagonal
        # pair is counted twice, hence the division by 2. The 1e-6 jitter guards
        # against log(0) for numerically tied singular values.
        # Equivalent pairwise formulation:
        # li = pt.triu_indices(m, k=1)
        # log_p += pt.log((sigma[:, None] ** 2 - sigma[None, :] ** 2)[li]).sum()
        log_p += (
            pt.log(pt.eye(m) + pt.abs(sigma[:, None] ** 2 - sigma[None, :] ** 2) + 1e-6).sum() / 2.0
        )

        return log_p


@_default_transform.register(NormalSingularValues)
def normal_singular_values_default_transform(op, rv):
    return PosRevOrdered()
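

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. The (n, m) values and
    # sampler settings are illustrative assumptions only.
    import pymc as pm

    with pm.Model():
        NormalSingularValues("sv", n=5, m=3)
        idata = pm.sample(draws=200, tune=200, chains=1)

    print(idata.posterior["sv"].mean(dim=("chain", "draw")))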