Skip to content

Commit 7c009b9

Browse files
committed
Added the core code to be discussed as draft
1 parent 7fb87b4 commit 7c009b9

File tree

3 files changed

+296
-0
lines changed

3 files changed

+296
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytensor.tensor as pt
16+
17+
from pymc.distributions.continuous import Continuous
18+
from pymc.distributions.distribution import SymbolicRandomVariable
19+
from pymc.distributions.shape_utils import (
20+
rv_size_is_none,
21+
)
22+
from pymc.distributions.transforms import _default_transform
23+
from pymc.pytensorf import normalize_rng_param
24+
from pytensor.tensor import get_underlying_scalar_constant_value
25+
from pytensor.tensor.random.utils import (
26+
normalize_size_param,
27+
)
28+
29+
__all__ = ["NormalSingularValues"]
30+
31+
from pymc.logprob.transforms import Transform
32+
33+
34+
# TODO: this is a lot of work to just get a list normally distributed variables
35+
class NormalSingularValuesRV(SymbolicRandomVariable):
36+
name = "normalsingularvalues"
37+
extended_signature = "[rng],[size],(),(m)->[rng],(m)" # TODO: check if this is correct
38+
_print_name = ("NormalSingularValuesRV", "\\operatorname{NormalSingularValuesRV}")
39+
40+
def make_node(self, rng, size, n, m):
41+
n = pt.as_tensor_variable(n)
42+
m = pt.as_tensor_variable(m)
43+
if not all(n.type.broadcastable) or not all(m.type.broadcastable):
44+
raise ValueError("n and m must be scalars.")
45+
46+
return super().make_node(rng, size, n, m)
47+
48+
@classmethod
49+
def rv_op(cls, n: int, m: int, *, rng=None, size=None):
50+
# We flatten the size to make operations easier, and then rebuild it
51+
n = pt.as_tensor(n, ndim=0, dtype=int)
52+
m = pt.as_tensor(m, ndim=0, dtype=int)
53+
54+
rng = normalize_rng_param(rng)
55+
size = normalize_size_param(size)
56+
57+
# TODO: currently assume size = 1. Fix this once everything is working
58+
D = get_underlying_scalar_constant_value(n)
59+
Q = get_underlying_scalar_constant_value(m)
60+
61+
# Perform a direct computation via SVD of a normal matrix
62+
sz = [] if rv_size_is_none(size) else size
63+
next_rng, z = pt.random.normal(0, 1, size=(*sz, D, Q), rng=rng).owner.outputs
64+
_, samples, _ = pt.linalg.svd(z)
65+
66+
return cls(
67+
inputs=[rng, size, n, m],
68+
outputs=[next_rng, samples],
69+
)(rng, size, n, m)
70+
71+
return samples
72+
73+
74+
# This is adapted from ordered transform.
75+
# Might make sense to just make that transform more generic by
76+
# allowing it to take parameters "positive" and "ascending"
77+
# and then just use that here.
78+
class PosRevOrdered(Transform):
79+
name = "posrevordered"
80+
81+
def __init__(self, ndim_supp=None):
82+
pass
83+
84+
def backward(self, value, *inputs):
85+
return pt.cumsum(pt.exp(value[..., ::-1]), axis=-1)[..., ::-1]
86+
87+
def forward(self, value, *inputs):
88+
y = pt.zeros(value.shape)
89+
y = pt.set_subtensor(y[..., -1], pt.log(value[..., -1]))
90+
y = pt.set_subtensor(y[..., :-1], pt.log(value[..., :-1] - value[..., 1:]))
91+
return y
92+
93+
def log_jac_det(self, value, *inputs):
94+
return pt.sum(value, axis=-1)
95+
96+
97+
class NormalSingularValues(Continuous):
98+
rv_type = NormalSingularValuesRV
99+
rv_op = NormalSingularValuesRV.rv_op
100+
101+
@classmethod
102+
def dist(cls, n, m, **kwargs):
103+
n = pt.as_tensor_variable(n).astype(int)
104+
m = pt.as_tensor_variable(m).astype(int)
105+
return super().dist([n, m], **kwargs)
106+
107+
def support_point(rv, *args):
108+
return pt.linspace(1, 0.5, rv.shape[-1])
109+
110+
def logp(sigma, n, m):
111+
# First term: prod[exp(-0.5*sigma**2)]
112+
log_p = -0.5 * pt.sum(sigma**2)
113+
114+
# Second + Fourth term (ignoring constant factor)
115+
# prod(sigma**(D-Q-1)) + prod(2*sigma)) = prod(2*sigma**(D-Q))
116+
log_p += (n - m) * pt.sum(pt.log(sigma))
117+
118+
# Third term: prod[prod[ |s1**2-s2**2| ]]
119+
# li = pt.triu_indices(m,k=1)
120+
# log_p += pt.log((sigma[:,None]**2 - sigma[None,:]**2)[li]).sum()
121+
log_p += (
122+
pt.log(pt.eye(m) + pt.abs(sigma[:, None] ** 2 - sigma[None, :] ** 2) + 1e-6).sum() / 2.0
123+
)
124+
125+
return log_p
126+
127+
128+
@_default_transform.register(NormalSingularValues)
129+
def lkjcorr_default_transform(op, rv):
130+
return PosRevOrdered()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytensor.tensor as pt
16+
17+
from pytensor.tensor import TensorVariable
18+
19+
from pymc_extras.distributions.multivariate.spherical import Spherical
20+
21+
__all__ = ["SemiOrthogonalMatrix"]
22+
23+
24+
class SemiOrthogonalMatrix:
25+
def __new__(cls, name, D, Q, **kwargs):
26+
dof = D * Q - Q * (Q - 1) // 2 # Total degrees of freedom
27+
28+
vs, pos = pt.zeros(dof), 0
29+
for q in range(Q):
30+
vq = Spherical(f"{name}_v{q}", D - q)
31+
vs = pt.set_subtensor(vs[pos : pos + D - q], vq)
32+
pos += D - q
33+
34+
return cls.orth_from_vs(vs, D, Q)
35+
36+
# Create a householder matrix from a vector
37+
@classmethod
38+
def _householder_matrix(cls, v: TensorVariable, D: int) -> TensorVariable:
39+
Q = v.shape[0]
40+
H = pt.eye(D)
41+
sgn = 1.0 # Original paper recommends sign(v[0]) but that causes divergences
42+
u = pt.inc_subtensor(v[0], sgn * pt.linalg.norm(v))
43+
H = pt.set_subtensor(
44+
H[-Q:, -Q:], -sgn * (pt.eye(Q, Q) - 2 * u[:, None] * u[None, :] / (pt.dot(u, u) + 1e-6))
45+
)
46+
return H
47+
48+
# Construct an orthogonal matrix from a vector of normally distributed values
49+
# as a cumulative product of householder matrices
50+
@classmethod
51+
def orth_from_vs(cls, vs: TensorVariable, D: int, Q: int) -> TensorVariable:
52+
"""Construct an orthogonal matrix from a set of direction vectors v"""
53+
H_p = pt.eye(D)
54+
pos, q = 0, 0
55+
dof = D * Q - Q * (Q - 1) // 2
56+
while pos < dof:
57+
v = vs[pos : pos + D - q]
58+
H = cls._householder_matrix(v, D)
59+
H_p = H @ H_p
60+
pos += D - q
61+
q += 1
62+
return H_p[:q, :]
63+
64+
@classmethod
65+
def vs_from_orth(cls, U: TensorVariable, D: int, Q: int) -> TensorVariable:
66+
"""Get the vs values that would lead to orthogonal matrix U. Inverse of orth_from_vs"""
67+
vs = []
68+
vl = D * Q - Q * (Q - 1) // 2
69+
vs, pos = pt.zeros(vl), 0
70+
for q in range(Q):
71+
v = U[q:, q] # Top row of the remaining submatrix
72+
73+
vs = pt.set_subtensor(vs[pos : pos + D - q], v)
74+
H = cls._householder_matrix(v, D)
75+
U = H.dot(U)
76+
pos += D - q
77+
return vs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pymc as pm
16+
import pytensor.tensor as pt
17+
18+
from pymc.distributions.continuous import Continuous
19+
from pymc.distributions.distribution import SymbolicRandomVariable
20+
from pymc.distributions.shape_utils import (
21+
rv_size_is_none,
22+
)
23+
from pymc.pytensorf import normalize_rng_param
24+
from pytensor.tensor import get_underlying_scalar_constant_value
25+
from pytensor.tensor.random.utils import (
26+
normalize_size_param,
27+
)
28+
29+
__all__ = ["Spherical"]
30+
31+
32+
class SphericalRV(SymbolicRandomVariable):
33+
name = "spherical"
34+
extended_signature = "[rng],[size],(n)->[rng],(n)" # TODO: check if this is correct
35+
_print_name = ("SphericalRV", "\\operatorname{SphericalRV}")
36+
37+
def make_node(self, rng, size, n):
38+
n = pt.as_tensor_variable(n)
39+
return super().make_node(rng, size, n)
40+
41+
@classmethod
42+
def rv_op(cls, n, *, rng=None, size=None):
43+
rng = normalize_rng_param(rng)
44+
size = normalize_size_param(size)
45+
n = pt.as_tensor(n, ndim=0, dtype=int)
46+
nv = get_underlying_scalar_constant_value(n)
47+
48+
# Perform a direct computation via SVD of a normal matrix
49+
sz = [] if rv_size_is_none(size) else size
50+
51+
next_rng, z = pt.random.normal(0, 1, size=(*sz, nv), rng=rng).owner.outputs
52+
samples = z / pt.sqrt(z * z.sum(axis=-1, keepdims=True) + 1e-6)
53+
# TODO: scale by the .dist given
54+
55+
return cls(
56+
inputs=[rng, size, n],
57+
outputs=[next_rng, samples],
58+
)(rng, size, n)
59+
60+
return samples
61+
62+
63+
class Spherical(Continuous):
64+
rv_type = SphericalRV
65+
rv_op = SphericalRV.rv_op
66+
67+
@classmethod
68+
def dist(cls, n, **kwargs):
69+
n = pt.as_tensor_variable(n).astype(int)
70+
return super().dist([n], **kwargs)
71+
72+
def support_point(rv, size, n, *args):
73+
return pt.ones(rv.shape) / pt.sqrt(n)
74+
75+
def logp(value, n):
76+
# TODO: take dist as a parameter instead of hardcoding
77+
dist = pm.Gamma.dist(50, 50)
78+
79+
# Get the radius
80+
r = pt.sqrt(pt.sum(value**2))
81+
82+
# Get the log prior of the radius
83+
log_p = pm.logp(dist, r)
84+
# log_p = pm.logp(pm.TruncatedNormal.dist(1,lower=0),r)
85+
86+
# Add the log det jacobian for radius
87+
log_p += (value.shape[-1] - 1) * pt.log(r)
88+
89+
return log_p

0 commit comments

Comments
 (0)