From 7c009b9f63a278e3bd330c453835033d39aa29ba Mon Sep 17 00:00:00 2001
From: Margus Niitsoo
Date: Mon, 21 Apr 2025 10:24:30 +0300
Subject: [PATCH] Added the core code to be discussed as draft

---
 .../multivariate/normal_singular_values.py    | 130 ++++++++++++++++++
 .../multivariate/orthogonal_matrix.py         |  77 +++++++++++
 .../distributions/multivariate/spherical.py   |  89 ++++++++++++
 3 files changed, 296 insertions(+)
 create mode 100644 pymc_extras/distributions/multivariate/normal_singular_values.py
 create mode 100644 pymc_extras/distributions/multivariate/orthogonal_matrix.py
 create mode 100644 pymc_extras/distributions/multivariate/spherical.py

diff --git a/pymc_extras/distributions/multivariate/normal_singular_values.py b/pymc_extras/distributions/multivariate/normal_singular_values.py
new file mode 100644
index 00000000..3971ce58
--- /dev/null
+++ b/pymc_extras/distributions/multivariate/normal_singular_values.py
@@ -0,0 +1,130 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytensor.tensor as pt
+
+from pymc.distributions.continuous import Continuous
+from pymc.distributions.distribution import SymbolicRandomVariable
+from pymc.distributions.shape_utils import (
+    rv_size_is_none,
+)
+from pymc.distributions.transforms import _default_transform
+from pymc.logprob.transforms import Transform
+from pymc.pytensorf import normalize_rng_param
+from pytensor.tensor import get_underlying_scalar_constant_value
+from pytensor.tensor.random.utils import (
+    normalize_size_param,
+)
+
+__all__ = ["NormalSingularValues"]
+
+
+# TODO: this is a lot of work to just get a list of normally distributed variables
+class NormalSingularValuesRV(SymbolicRandomVariable):
+    """Symbolic random variable for the singular values of a standard-normal matrix.
+
+    A draw is produced by sampling an ``n x m`` matrix of independent standard
+    normals and returning the singular values reported by ``pt.linalg.svd``.
+    """
+
+    name = "normalsingularvalues"
+    # NOTE(review): ``m`` is passed as a scalar (see ``make_node``), yet the
+    # signature lists it as a core vector ``(m)`` -- confirm before relying on it.
+    extended_signature = "[rng],[size],(),(m)->[rng],(m)"  # TODO: check if this is correct
+    _print_name = ("NormalSingularValuesRV", "\\operatorname{NormalSingularValuesRV}")
+
+    def make_node(self, rng, size, n, m):
+        """Build the Apply node after checking that ``n`` and ``m`` are scalars.
+
+        Raises
+        ------
+        ValueError
+            If ``n`` or ``m`` has a non-broadcastable dimension.
+        """
+        n = pt.as_tensor_variable(n)
+        m = pt.as_tensor_variable(m)
+        if not all(n.type.broadcastable) or not all(m.type.broadcastable):
+            raise ValueError("n and m must be scalars.")
+
+        return super().make_node(rng, size, n, m)
+
+    @classmethod
+    def rv_op(cls, n: int, m: int, *, rng=None, size=None):
+        """Return a random draw graph for the singular values of an ``n x m`` normal matrix.
+
+        Parameters
+        ----------
+        n, m
+            Matrix dimensions. Both are resolved with
+            ``get_underlying_scalar_constant_value`` so that they can be used in
+            the ``size`` of the underlying normal draw.
+        rng
+            Optional random generator variable, normalised with ``normalize_rng_param``.
+        size
+            Optional batch size, normalised with ``normalize_size_param``.
+        """
+        n = pt.as_tensor(n, ndim=0, dtype=int)
+        m = pt.as_tensor(m, ndim=0, dtype=int)
+
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        # TODO: currently assume size = 1. Fix this once everything is working
+        D = get_underlying_scalar_constant_value(n)
+        Q = get_underlying_scalar_constant_value(m)
+
+        # Perform a direct computation via SVD of a normal matrix
+        sz = [] if rv_size_is_none(size) else size
+        next_rng, z = pt.random.normal(0, 1, size=(*sz, D, Q), rng=rng).owner.outputs
+        _, samples, _ = pt.linalg.svd(z)
+
+        return cls(
+            inputs=[rng, size, n, m],
+            outputs=[next_rng, samples],
+        )(rng, size, n, m)
+
+
+# This is adapted from ordered transform.
+# Might make sense to just make that transform more generic by
+# allowing it to take parameters "positive" and "ascending"
+# and then just use that here.
+class PosRevOrdered(Transform):
+    """Transform between unconstrained vectors and positive, decreasing vectors.
+
+    ``backward`` maps ``y`` to ``x_i = sum_{j >= i} exp(y_j)`` along the last
+    axis, so ``x`` is positive and strictly decreasing. ``forward`` is the
+    inverse map.
+    """
+
+    name = "posrevordered"
+
+    def __init__(self, ndim_supp=None):
+        # ``ndim_supp`` is accepted but not used.
+        pass
+
+    def backward(self, value, *inputs):
+        """Map an unconstrained ``value`` to a positive, decreasing vector."""
+        # Reverse, accumulate exp() and reverse back: a cumulative sum from the right.
+        return pt.cumsum(pt.exp(value[..., ::-1]), axis=-1)[..., ::-1]
+
+    def forward(self, value, *inputs):
+        """Map a positive, decreasing ``value`` to the unconstrained space."""
+        y = pt.zeros(value.shape)
+        # The last entry is the smallest value; the others are gaps to the next entry.
+        y = pt.set_subtensor(y[..., -1], pt.log(value[..., -1]))
+        y = pt.set_subtensor(y[..., :-1], pt.log(value[..., :-1] - value[..., 1:]))
+        return y
+
+    def log_jac_det(self, value, *inputs):
+        """Return the log Jacobian determinant of ``backward`` at ``value``."""
+        # The Jacobian of ``backward`` is triangular with diagonal exp(value).
+        return pt.sum(value, axis=-1)
+
+
+class NormalSingularValues(Continuous):
+    """Distribution of the singular values of an ``n x m`` standard-normal matrix.
+
+    Parameters
+    ----------
+    n, m
+        Dimensions of the underlying normal matrix; both are cast to integers.
+    """
+
+    rv_type = NormalSingularValuesRV
+    rv_op = NormalSingularValuesRV.rv_op
+
+    @classmethod
+    def dist(cls, n, m, **kwargs):
+        """Create the distribution after casting ``n`` and ``m`` to integer tensors."""
+        n = pt.as_tensor_variable(n).astype(int)
+        m = pt.as_tensor_variable(m).astype(int)
+        return super().dist([n, m], **kwargs)
+
+    def support_point(rv, *args):
+        """Return a positive, strictly decreasing starting point (from 1 down to 0.5)."""
+        return pt.linspace(1, 0.5, rv.shape[-1])
+
+    def logp(sigma, n, m):
+        """Unnormalised log-density of the singular values ``sigma``.
+
+        All reductions run over the core (last) axis only, so a batch of value
+        vectors yields one log-probability per vector. For a single vector the
+        result is the same scalar as a reduction over every axis.
+        """
+        # First term: prod_i exp(-0.5 * sigma_i**2)
+        log_p = -0.5 * pt.sum(sigma**2, axis=-1)
+
+        # Second and fourth terms, ignoring the constant factor:
+        # prod_i sigma_i**(n - m - 1) * prod_i (2 * sigma_i) = prod_i 2 * sigma_i**(n - m)
+        log_p += (n - m) * pt.sum(pt.log(sigma), axis=-1)
+
+        # Third term: prod_{i < j} |sigma_i**2 - sigma_j**2|.
+        # The full symmetric matrix of differences is summed and halved, which
+        # counts each pair once. The identity on the diagonal keeps the log of
+        # the zero self-differences finite, and 1e-6 guards against log(0) for ties.
+        squared = sigma**2
+        gaps = pt.abs(squared[..., :, None] - squared[..., None, :])
+        log_p += pt.log(pt.eye(m) + gaps + 1e-6).sum(axis=(-2, -1)) / 2.0
+
+        return log_p
+
+
+@_default_transform.register(NormalSingularValues)
+def normal_singular_values_default_transform(op, rv):
+    """Use ``PosRevOrdered`` as the default transform for ``NormalSingularValues``."""
+    return PosRevOrdered()
diff --git a/pymc_extras/distributions/multivariate/orthogonal_matrix.py b/pymc_extras/distributions/multivariate/orthogonal_matrix.py
new file mode 100644
index 00000000..5964a9d4
--- /dev/null
+++ b/pymc_extras/distributions/multivariate/orthogonal_matrix.py
@@ -0,0 +1,77 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytensor.tensor as pt
+
+from pytensor.tensor import TensorVariable
+
+from pymc_extras.distributions.multivariate.spherical import Spherical
+
+__all__ = ["SemiOrthogonalMatrix"]
+
+
+class SemiOrthogonalMatrix:
+    """Build a ``Q x D`` matrix as a product of Householder-style reflections.
+
+    Calling ``SemiOrthogonalMatrix(name, D, Q)`` does not return an instance.
+    It registers ``Q`` ``Spherical`` variables named ``f"{name}_v{q}"`` with
+    sizes ``D, D - 1, ..., D - Q + 1`` and returns the tensor produced by
+    ``orth_from_vs`` from their concatenation.
+    """
+
+    def __new__(cls, name, D, Q, **kwargs):
+        """Create the direction variables and return the resulting matrix tensor.
+
+        Parameters
+        ----------
+        name
+            Prefix for the names of the ``Spherical`` variables.
+        D, Q
+            Python integers with ``0 <= Q <= D``.
+        **kwargs
+            Accepted for signature compatibility; currently unused.
+
+        Raises
+        ------
+        ValueError
+            If ``Q`` is negative or larger than ``D``.
+        """
+        cls._check_dims(D, Q)
+        dof = D * Q - Q * (Q - 1) // 2  # Total degrees of freedom
+
+        vs, pos = pt.zeros(dof), 0
+        for q in range(Q):
+            vq = Spherical(f"{name}_v{q}", D - q)
+            vs = pt.set_subtensor(vs[pos : pos + D - q], vq)
+            pos += D - q
+
+        return cls.orth_from_vs(vs, D, Q)
+
+    @staticmethod
+    def _check_dims(D: int, Q: int) -> None:
+        """Raise ``ValueError`` unless ``0 <= Q <= D``."""
+        if not 0 <= Q <= D:
+            raise ValueError(f"Q must satisfy 0 <= Q <= D, got D={D}, Q={Q}.")
+
+    # Create a householder matrix from a vector
+    @classmethod
+    def _householder_matrix(cls, v: TensorVariable, D: int) -> TensorVariable:
+        """Embed a reflection built from ``v`` in the lower-right block of ``eye(D)``.
+
+        With ``u = v + norm(v) * e_0``, the trailing ``len(v) x len(v)`` block is
+        set to ``-(I - 2 u u^T / (u.u + 1e-6))``. The 1e-6 keeps the division
+        finite when ``u`` is zero.
+        """
+        Q = v.shape[0]
+        H = pt.eye(D)
+        sgn = 1.0  # Original paper recommends sign(v[0]) but that causes divergences
+        u = pt.inc_subtensor(v[0], sgn * pt.linalg.norm(v))
+        H = pt.set_subtensor(
+            H[-Q:, -Q:], -sgn * (pt.eye(Q, Q) - 2 * u[:, None] * u[None, :] / (pt.dot(u, u) + 1e-6))
+        )
+        return H
+
+    # Construct an orthogonal matrix from a vector of normally distributed values
+    # as a cumulative product of householder matrices
+    @classmethod
+    def orth_from_vs(cls, vs: TensorVariable, D: int, Q: int) -> TensorVariable:
+        """Construct a ``Q x D`` matrix from the packed direction vectors ``vs``.
+
+        ``vs`` holds ``Q`` consecutive segments of lengths ``D, D - 1, ...,
+        D - Q + 1``. Each segment yields one Householder-style matrix. The
+        matrices are multiplied from the left in order, and the first ``Q``
+        rows of the product are returned.
+
+        Raises
+        ------
+        ValueError
+            If ``Q`` is negative or larger than ``D``.
+        """
+        cls._check_dims(D, Q)
+        H_p = pt.eye(D)
+        pos = 0
+        for q in range(Q):
+            width = D - q
+            H_p = cls._householder_matrix(vs[pos : pos + width], D) @ H_p
+            pos += width
+        return H_p[:Q, :]
+
+    @classmethod
+    def vs_from_orth(cls, U: TensorVariable, D: int, Q: int) -> TensorVariable:
+        """Get the vs values that would lead to orthogonal matrix U.
+
+        Intended as the inverse of ``orth_from_vs``.
+
+        NOTE(review): ``U`` is indexed as ``U[q:, q]`` and left-multiplied by a
+        ``D x D`` matrix, so it must have ``D`` rows, whereas ``orth_from_vs``
+        returns ``Q`` rows. Presumably the transpose is expected -- confirm.
+
+        Raises
+        ------
+        ValueError
+            If ``Q`` is negative or larger than ``D``.
+        """
+        cls._check_dims(D, Q)
+        vl = D * Q - Q * (Q - 1) // 2
+        vs, pos = pt.zeros(vl), 0
+        for q in range(Q):
+            v = U[q:, q]  # Leading column of the remaining submatrix
+
+            vs = pt.set_subtensor(vs[pos : pos + D - q], v)
+            H = cls._householder_matrix(v, D)
+            U = H.dot(U)
+            pos += D - q
+        return vs
diff --git a/pymc_extras/distributions/multivariate/spherical.py b/pymc_extras/distributions/multivariate/spherical.py
new file mode 100644
index 00000000..24943eff
--- /dev/null
+++ b/pymc_extras/distributions/multivariate/spherical.py
@@ -0,0 +1,89 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pymc as pm
+import pytensor.tensor as pt
+
+from pymc.distributions.continuous import Continuous
+from pymc.distributions.distribution import SymbolicRandomVariable
+from pymc.distributions.shape_utils import (
+    rv_size_is_none,
+)
+from pymc.pytensorf import normalize_rng_param
+from pytensor.tensor import get_underlying_scalar_constant_value
+from pytensor.tensor.random.utils import (
+    normalize_size_param,
+)
+
+__all__ = ["Spherical"]
+
+
+class SphericalRV(SymbolicRandomVariable):
+    """Symbolic random variable producing length-``n`` vectors of (almost) unit norm.
+
+    A draw is a vector of independent standard normals divided by its Euclidean
+    norm.
+    """
+
+    name = "spherical"
+    # NOTE(review): ``n`` is passed as a scalar (``ndim=0`` in ``rv_op``), yet the
+    # signature lists it as a core vector ``(n)`` -- confirm before relying on it.
+    extended_signature = "[rng],[size],(n)->[rng],(n)"  # TODO: check if this is correct
+    _print_name = ("SphericalRV", "\\operatorname{SphericalRV}")
+
+    def make_node(self, rng, size, n):
+        """Build the Apply node after converting ``n`` to a tensor variable."""
+        n = pt.as_tensor_variable(n)
+        return super().make_node(rng, size, n)
+
+    @classmethod
+    def rv_op(cls, n, *, rng=None, size=None):
+        """Return a random draw graph for a normalised length-``n`` normal vector.
+
+        Parameters
+        ----------
+        n
+            Length of the vector. It is resolved with
+            ``get_underlying_scalar_constant_value`` so that it can be used in
+            the ``size`` of the underlying normal draw.
+        rng
+            Optional random generator variable, normalised with ``normalize_rng_param``.
+        size
+            Optional batch size, normalised with ``normalize_size_param``.
+        """
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+        n = pt.as_tensor(n, ndim=0, dtype=int)
+        nv = get_underlying_scalar_constant_value(n)
+
+        sz = [] if rv_size_is_none(size) else size
+
+        # Draw a standard-normal vector and divide it by its Euclidean norm.
+        # The square must be taken before the sum: ``z * z.sum(...)`` would
+        # compute ``z * sum(z)`` instead of ``sum(z**2)``.
+        # The 1e-6 keeps the division finite for an all-zero draw.
+        next_rng, z = pt.random.normal(0, 1, size=(*sz, nv), rng=rng).owner.outputs
+        samples = z / pt.sqrt((z * z).sum(axis=-1, keepdims=True) + 1e-6)
+        # TODO: scale by the .dist given
+
+        return cls(
+            inputs=[rng, size, n],
+            outputs=[next_rng, samples],
+        )(rng, size, n)
+
+
+class Spherical(Continuous):
+    """Distribution over length-``n`` vectors whose radius has a ``Gamma(50, 50)`` prior.
+
+    Parameters
+    ----------
+    n
+        Length of the vector; cast to an integer.
+    """
+
+    rv_type = SphericalRV
+    rv_op = SphericalRV.rv_op
+
+    @classmethod
+    def dist(cls, n, **kwargs):
+        """Create the distribution after casting ``n`` to an integer tensor."""
+        n = pt.as_tensor_variable(n).astype(int)
+        return super().dist([n], **kwargs)
+
+    def support_point(rv, size, n, *args):
+        """Return the constant vector with every entry equal to ``1 / sqrt(n)``."""
+        return pt.ones(rv.shape) / pt.sqrt(n)
+
+    def logp(value, n):
+        """Log-density of ``value``: radius prior plus the radial Jacobian term.
+
+        The radius is taken over the last axis only, so a batch of vectors
+        yields one log-probability per vector. For a single vector the result
+        is the same scalar as a reduction over every axis.
+        """
+        # TODO: take dist as a parameter instead of hardcoding
+        dist = pm.Gamma.dist(50, 50)
+
+        # Get the radius
+        r = pt.sqrt(pt.sum(value**2, axis=-1))
+
+        # Get the log prior of the radius
+        log_p = pm.logp(dist, r)
+
+        # Add the log det jacobian for radius
+        log_p += (value.shape[-1] - 1) * pt.log(r)
+
+        return log_p