Commit 6b43938 (parent: b0e863e)

Reimplement several RandomVariables as SymbolicRandomVariables

This allows sampling from multiple backends without having to dispatch a
separate implementation for each one.
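As an illustration of that claim (this sketch is not part of the commit): a SymbolicRandomVariable is an ordinary PyTensor graph built from base random variables, so any backend that can compile those base variables can sample the derived distribution with no Op-specific rng_fn dispatch. `pm.draw` and the `mode` compile kwarg are existing PyMC/PyTensor API; the parameter values are arbitrary, and the last line assumes jax is installed.

    import pymc as pm

    x = pm.Kumaraswamy.dist(a=2.0, b=3.0)

    # After this commit, the graph behind `x` is a uniform RV plus pointwise
    # math, so a non-default backend needs no Kumaraswamy-specific rng_fn:
    print(pm.draw(x, draws=4))              # default C backend
    print(pm.draw(x, draws=4, mode="JAX"))  # JAX backend (requires jax)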

File tree: 3 files changed, +131 −119 lines


pymc/distributions/continuous.py

Lines changed: 90 additions & 68 deletions
@@ -52,10 +52,12 @@
     vonmises,
 )
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.variable import TensorConstant

 from pymc.logprob.abstract import _logprob_helper
 from pymc.logprob.basic import icdf
+from pymc.pytensorf import normalize_rng_param

 try:
     from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -73,7 +75,6 @@ def polyagamma_cdf(*args, **kwargs):

 from scipy import stats
 from scipy.interpolate import InterpolatedUnivariateSpline
-from scipy.special import expit

 from pymc.distributions import transforms
 from pymc.distributions.dist_math import (
@@ -90,8 +91,8 @@ def polyagamma_cdf(*args, **kwargs):
     normal_lcdf,
     zvalue,
 )
-from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
-from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
+from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
 from pymc.distributions.transforms import _default_transform
 from pymc.math import invlogit, logdiffexp, logit

@@ -1236,20 +1237,28 @@ def icdf(value, alpha, beta):
     )


-class KumaraswamyRV(RandomVariable):
+class KumaraswamyRV(SymbolicRandomVariable):
     name = "kumaraswamy"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("Kumaraswamy", "\\operatorname{Kumaraswamy}")

     @classmethod
-    def rng_fn(cls, rng, a, b, size) -> np.ndarray:
-        u = rng.uniform(size=size)
-        return np.asarray((1 - (1 - u) ** (1 / b)) ** (1 / a))
+    def rv_op(cls, a, b, *, size=None, rng=None):
+        a = pt.as_tensor(a)
+        b = pt.as_tensor(b)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        if rv_size_is_none(size):
+            size = implicit_size_from_params(a, b, ndims_params=cls.ndims_params)

-kumaraswamy = KumaraswamyRV()
+        next_rng, u = uniform(size=size, rng=rng).owner.outputs
+        draws = (1 - (1 - u) ** (1 / b)) ** (1 / a)
+
+        return cls(
+            inputs=[rng, size, a, b],
+            outputs=[next_rng, draws],
+        )(rng, size, a, b)


 class Kumaraswamy(UnitContinuous):
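A note on the `uniform(...).owner.outputs` idiom used in `rv_op` above and in the hunks below: calling a PyTensor random variable returns only the draws, but the node that produced them also outputs the updated RNG state, which `rv_op` must return alongside the draws. A standalone sketch of the idiom (illustrative values, not from the diff):

    import numpy as np
    import pytensor
    from pytensor.tensor.random.basic import uniform

    rng = pytensor.shared(np.random.default_rng(123))
    u = uniform(size=(3,), rng=rng)      # calling the RV returns only the draws
    next_rng, draws = u.owner.outputs    # the node also exposes the updated RNG
    assert draws is u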
@@ -1296,13 +1305,11 @@ class Kumaraswamy(UnitContinuous):
     b > 0.
     """

-    rv_op = kumaraswamy
+    rv_type = KumaraswamyRV
+    rv_op = KumaraswamyRV.rv_op

     @classmethod
     def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs):
-        a = pt.as_tensor_variable(a)
-        b = pt.as_tensor_variable(b)
-
         return super().dist([a, b], *args, **kwargs)

     def support_point(rv, size, a, b):
@@ -1533,24 +1540,32 @@ def icdf(value, mu, b):
     return check_icdf_parameters(res, b > 0, msg="b > 0")


-class AsymmetricLaplaceRV(RandomVariable):
+class AsymmetricLaplaceRV(SymbolicRandomVariable):
     name = "asymmetriclaplace"
-    ndim_supp = 0
-    ndims_params = [0, 0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),(),()->[rng],()"
     _print_name = ("AsymmetricLaplace", "\\operatorname{AsymmetricLaplace}")

     @classmethod
-    def rng_fn(cls, rng, b, kappa, mu, size=None) -> np.ndarray:
-        u = rng.uniform(size=size)
+    def rv_op(cls, b, kappa, mu, *, size=None, rng=None):
+        b = pt.as_tensor(b)
+        kappa = pt.as_tensor(kappa)
+        mu = pt.as_tensor(mu)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(b, kappa, mu, ndims_params=cls.ndims_params)
+
+        next_rng, u = uniform(size=size, rng=rng).owner.outputs
         switch = kappa**2 / (1 + kappa**2)
-        non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
-        positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b)
+        non_positive_x = mu + kappa * pt.log(u * (1 / switch)) / b
+        positive_x = mu - pt.log((1 - u) * (1 + kappa**2)) / (kappa * b)
         draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
-        return np.asarray(draws)
-

-asymmetriclaplace = AsymmetricLaplaceRV()
+        return cls(
+            inputs=[rng, size, b, kappa, mu],
+            outputs=[next_rng, draws],
+        )(rng, size, b, kappa, mu)


 class AsymmetricLaplace(Continuous):
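The piecewise expression above is continuous at the switch point: when u equals kappa**2 / (1 + kappa**2), both log terms vanish and each branch reduces to mu. A quick numpy check of that identity (illustrative values only, not part of the commit):

    import numpy as np

    b, kappa, mu = 2.0, 1.5, 0.7
    switch = kappa**2 / (1 + kappa**2)
    u = switch
    non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b        # log(1) == 0
    positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b)  # (1 - u) * (1 + kappa**2) == 1
    assert np.allclose([non_positive_x, positive_x], mu)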
@@ -1599,15 +1614,12 @@
     of interest.
     """

-    rv_op = asymmetriclaplace
+    rv_type = AsymmetricLaplaceRV
+    rv_op = AsymmetricLaplaceRV.rv_op

     @classmethod
     def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
         kappa = cls.get_kappa(kappa, q)
-        b = pt.as_tensor_variable(b)
-        kappa = pt.as_tensor_variable(kappa)
-        mu = pt.as_tensor_variable(mu)
-
         return super().dist([b, kappa, mu], *args, **kwargs)

     @classmethod
@@ -2475,7 +2487,6 @@ def dist(cls, nu, **kwargs):
         return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


-# TODO: Remove this once logp for multiplication is working!
 class WeibullBetaRV(RandomVariable):
     name = "weibull"
     ndim_supp = 0
@@ -2597,19 +2608,22 @@ def icdf(value, alpha, beta):
     )


-class HalfStudentTRV(RandomVariable):
+class HalfStudentTRV(SymbolicRandomVariable):
     name = "halfstudentt"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}")

     @classmethod
-    def rng_fn(cls, rng, nu, sigma, size=None) -> np.ndarray:
-        return np.asarray(np.abs(stats.t.rvs(nu, scale=sigma, size=size, random_state=rng)))
+    def rv_op(cls, nu, sigma, *, size=None, rng=None):
+        nu = pt.as_tensor(nu)
+        sigma = pt.as_tensor(sigma)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        next_rng, t_draws = t(df=nu, scale=sigma, size=size, rng=rng).owner.outputs
+        draws = pt.abs(t_draws)

-halfstudentt = HalfStudentTRV()
+        return cls(inputs=[rng, size, nu, sigma], outputs=[next_rng, draws])(rng, size, nu, sigma)


 class HalfStudentT(PositiveContinuous):
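Here the half distribution is obtained by folding: `pt.abs` of StudentT draws. A minimal smoke test of the folded graph (assumes a PyMC build containing this commit; values are arbitrary):

    import pymc as pm

    x = pm.HalfStudentT.dist(nu=5, sigma=2.0)
    assert (pm.draw(x, draws=1000) >= 0).all()  # non-negative by construction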
@@ -2671,14 +2685,12 @@ class HalfStudentT(PositiveContinuous):
         x = pm.HalfStudentT('x', lam=4, nu=10)
     """

-    rv_op = halfstudentt
+    rv_type = HalfStudentTRV
+    rv_op = HalfStudentTRV.rv_op

     @classmethod
     def dist(cls, nu, sigma=None, lam=None, *args, **kwargs):
-        nu = pt.as_tensor_variable(nu)
         lam, sigma = get_tau_sigma(lam, sigma)
-        sigma = pt.as_tensor_variable(sigma)
-
         return super().dist([nu, sigma], *args, **kwargs)

     def support_point(rv, size, nu, sigma):
@@ -2710,19 +2722,29 @@ def logp(value, nu, sigma):
     )


-class ExGaussianRV(RandomVariable):
+class ExGaussianRV(SymbolicRandomVariable):
     name = "exgaussian"
-    ndim_supp = 0
-    ndims_params = [0, 0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),(),()->[rng],()"
     _print_name = ("ExGaussian", "\\operatorname{ExGaussian}")

     @classmethod
-    def rng_fn(cls, rng, mu, sigma, nu, size=None) -> np.ndarray:
-        return np.asarray(rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size))
+    def rv_op(cls, mu, sigma, nu, *, size=None, rng=None):
+        mu = pt.as_tensor(mu)
+        sigma = pt.as_tensor(sigma)
+        nu = pt.as_tensor(nu)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        if rv_size_is_none(size):
+            size = implicit_size_from_params(mu, sigma, nu, ndims_params=cls.ndims_params)

-exgaussian = ExGaussianRV()
+        next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
+        final_rng, exponential_draws = exponential(scale=nu, size=size, rng=next_rng).owner.outputs
+        draws = normal_draws + exponential_draws
+
+        return cls(inputs=[rng, size, mu, sigma, nu], outputs=[final_rng, draws])(
+            rng, size, mu, sigma, nu
+        )


 class ExGaussian(Continuous):
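Note how the RNG state is threaded through this hunk: the exponential draws are seeded with `next_rng`, the state left behind by the normal draws, and the Op returns `final_rng`. Reusing `rng` for the second call instead would correlate the two sets of draws. A standalone sketch of the same chaining (values are arbitrary, not from the diff):

    import numpy as np
    import pytensor
    from pytensor.tensor.random.basic import exponential, normal

    rng = pytensor.shared(np.random.default_rng(0))
    next_rng, normal_draws = normal(loc=0.0, scale=1.0, size=(2,), rng=rng).owner.outputs
    final_rng, exponential_draws = exponential(scale=1.0, size=(2,), rng=next_rng).owner.outputs

    # Both base draws come from one logical stream, split at next_rng:
    fn = pytensor.function([], normal_draws + exponential_draws)
    print(fn())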
@@ -2792,14 +2814,11 @@
         Vol. 4, No. 1, pp 35-45.
     """

-    rv_op = exgaussian
+    rv_type = ExGaussianRV
+    rv_op = ExGaussianRV.rv_op

     @classmethod
     def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs):
-        mu = pt.as_tensor_variable(mu)
-        sigma = pt.as_tensor_variable(sigma)
-        nu = pt.as_tensor_variable(nu)
-
         return super().dist([mu, sigma, nu], *args, **kwargs)

     def support_point(rv, size, mu, sigma, nu):
@@ -3477,19 +3496,25 @@ def icdf(value, mu, s):
     )


-class LogitNormalRV(RandomVariable):
+class LogitNormalRV(SymbolicRandomVariable):
     name = "logit_normal"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("logitNormal", "\\operatorname{logitNormal}")

     @classmethod
-    def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray:
-        return np.asarray(expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)))
+    def rv_op(cls, mu, sigma, *, size=None, rng=None):
+        mu = pt.as_tensor(mu)
+        sigma = pt.as_tensor(sigma)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
+        draws = pt.expit(normal_draws)

-logit_normal = LogitNormalRV()
+        return cls(
+            inputs=[rng, size, mu, sigma],
+            outputs=[next_rng, draws],
+        )(rng, size, mu, sigma)


 class LogitNormal(UnitContinuous):
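This is also why `from scipy.special import expit` is dropped from the imports earlier in the file: the transform now stays inside the graph as `pt.expit`. A quick equivalence check (illustrative, not part of the commit):

    import numpy as np
    import pytensor.tensor as pt
    from scipy.special import expit

    x = np.linspace(-4.0, 4.0, 9)
    assert np.allclose(pt.expit(x).eval(), expit(x))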
@@ -3540,15 +3565,12 @@
         Defaults to 1.
     """

-    rv_op = logit_normal
+    rv_type = LogitNormalRV
+    rv_op = LogitNormalRV.rv_op

     @classmethod
     def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
-        mu = pt.as_tensor_variable(mu)
-        tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
-        sigma = pt.as_tensor_variable(sigma)
-        tau = pt.as_tensor_variable(tau)
-
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
         return super().dist([mu, sigma], **kwargs)

     def support_point(rv, size, mu, sigma):

pymc/distributions/discrete.py

Lines changed: 22 additions & 14 deletions
@@ -18,7 +18,6 @@

 from pytensor.tensor import TensorConstant
 from pytensor.tensor.random.basic import (
-    RandomVariable,
     ScipyRandomVariable,
     bernoulli,
     betabinom,
@@ -28,7 +27,9 @@
     hypergeometric,
     nbinom,
     poisson,
+    uniform,
 )
+from pytensor.tensor.random.utils import normalize_size_param
 from scipy import stats

 import pymc as pm
@@ -45,8 +46,8 @@
     normal_lccdf,
     normal_lcdf,
 )
-from pymc.distributions.distribution import Discrete
-from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.distributions.distribution import Discrete, SymbolicRandomVariable
+from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
 from pymc.logprob.basic import logcdf, logp
 from pymc.math import sigmoid

@@ -65,6 +66,8 @@
     "OrderedProbit",
 ]

+from pymc.pytensorf import normalize_rng_param
+

 class Binomial(Discrete):
     R"""
@@ -387,20 +390,26 @@ def logcdf(value, p):
     )


-class DiscreteWeibullRV(RandomVariable):
+class DiscreteWeibullRV(SymbolicRandomVariable):
     name = "discrete_weibull"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "int64"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("dWeibull", "\\operatorname{dWeibull}")

     @classmethod
-    def rng_fn(cls, rng, q, beta, size):
-        p = rng.uniform(size=size)
-        return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
+    def rv_op(cls, q, beta, *, size=None, rng=None):
+        q = pt.as_tensor(q)
+        beta = pt.as_tensor(beta)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(q, beta, ndims_params=cls.ndims_params)

+        next_rng, p = uniform(size=size, rng=rng).owner.outputs
+        draws = pt.ceil(pt.power(pt.log(1 - p) / pt.log(q), 1.0 / beta)) - 1
+        draws = draws.astype("int64")

-discrete_weibull = DiscreteWeibullRV()
+        return cls(inputs=[rng, size, q, beta], outputs=[next_rng, draws])(rng, size, q, beta)


 class DiscreteWeibull(Discrete):
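The draw expression above inverts the discrete Weibull CDF, F(x) = 1 - q**((x + 1)**beta): it returns the smallest integer x with F(x) >= p. A numpy spot check of that property (values are arbitrary, not part of the commit):

    import numpy as np

    q, beta = 0.8, 1.3
    cdf = lambda x: 1 - q ** ((x + 1) ** beta)

    p = np.random.default_rng(42).uniform(size=1000)
    draws = np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1

    assert np.all(cdf(draws) >= p)     # F(draw) >= p ...
    assert np.all(cdf(draws - 1) < p)  # ... and no smaller integer qualifies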
@@ -452,12 +461,11 @@ def DiscreteWeibull(q, b, x):

     """

-    rv_op = discrete_weibull
+    rv_type = DiscreteWeibullRV
+    rv_op = DiscreteWeibullRV.rv_op

     @classmethod
     def dist(cls, q, beta, *args, **kwargs):
-        q = pt.as_tensor_variable(q)
-        beta = pt.as_tensor_variable(beta)
         return super().dist([q, beta], **kwargs)

     def support_point(rv, size, q, beta):
