Skip to content

Commit 0757e57

Browse files
authored
Add type hints for rng_fn return value (#5296)
* Add type hints to rng_fn * Change all return types to numpy arrays
1 parent 69d3114 commit 0757e57

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

pymc/distributions/continuous.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -906,8 +906,8 @@ class WaldRV(RandomVariable):
906906
_print_name = ("Wald", "\\operatorname{Wald}")
907907

908908
@classmethod
909-
def rng_fn(cls, rng, mu, lam, alpha, size):
910-
return rng.wald(mu, lam, size=size) + alpha
909+
def rng_fn(cls, rng, mu, lam, alpha, size) -> np.ndarray:
910+
return np.asarray(rng.wald(mu, lam, size=size) + alpha)
911911

912912

913913
wald = WaldRV()
@@ -1137,8 +1137,8 @@ def logcdf(
11371137

11381138
class BetaClippedRV(BetaRV):
11391139
@classmethod
1140-
def rng_fn(cls, rng, alpha, beta, size):
1141-
return clipped_beta_rvs(alpha, beta, size=size, random_state=rng)
1140+
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
1141+
return np.asarray(clipped_beta_rvs(alpha, beta, size=size, random_state=rng))
11421142

11431143

11441144
beta = BetaClippedRV()
@@ -1288,9 +1288,9 @@ class KumaraswamyRV(RandomVariable):
12881288
_print_name = ("Kumaraswamy", "\\operatorname{Kumaraswamy}")
12891289

12901290
@classmethod
1291-
def rng_fn(cls, rng, a, b, size):
1291+
def rng_fn(cls, rng, a, b, size) -> np.ndarray:
12921292
u = rng.uniform(size=size)
1293-
return (1 - (1 - u) ** (1 / b)) ** (1 / a)
1293+
return np.asarray((1 - (1 - u) ** (1 / b)) ** (1 / a))
12941294

12951295

12961296
kumaraswamy = KumaraswamyRV()
@@ -1596,13 +1596,13 @@ class AsymmetricLaplaceRV(RandomVariable):
15961596
_print_name = ("AsymmetricLaplace", "\\operatorname{AsymmetricLaplace}")
15971597

15981598
@classmethod
1599-
def rng_fn(cls, rng, b, kappa, mu, size=None):
1599+
def rng_fn(cls, rng, b, kappa, mu, size=None) -> np.ndarray:
16001600
u = rng.uniform(size=size)
16011601
switch = kappa ** 2 / (1 + kappa ** 2)
16021602
non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
16031603
positive_x = mu - np.log((1 - u) * (1 + kappa ** 2)) / (kappa * b)
16041604
draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
1605-
return draws
1605+
return np.asarray(draws)
16061606

16071607

16081608
asymmetriclaplace = AsymmetricLaplaceRV()
@@ -1811,8 +1811,8 @@ class StudentTRV(RandomVariable):
18111811
_print_name = ("StudentT", "\\operatorname{StudentT}")
18121812

18131813
@classmethod
1814-
def rng_fn(cls, rng, nu, mu, sigma, size=None):
1815-
return stats.t.rvs(nu, mu, sigma, size=size, random_state=rng)
1814+
def rng_fn(cls, rng, nu, mu, sigma, size=None) -> np.ndarray:
1815+
return np.asarray(stats.t.rvs(nu, mu, sigma, size=size, random_state=rng))
18161816

18171817

18181818
studentt = StudentTRV()
@@ -2538,8 +2538,8 @@ class WeibullBetaRV(WeibullRV):
25382538
ndims_params = [0, 0]
25392539

25402540
@classmethod
2541-
def rng_fn(cls, rng, alpha, beta, size):
2542-
return beta * rng.weibull(alpha, size=size)
2541+
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
2542+
return np.asarray(beta * rng.weibull(alpha, size=size))
25432543

25442544

25452545
weibull_beta = WeibullBetaRV()
@@ -2642,8 +2642,8 @@ class HalfStudentTRV(RandomVariable):
26422642
_print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}")
26432643

26442644
@classmethod
2645-
def rng_fn(cls, rng, nu, sigma, size=None):
2646-
return np.abs(stats.t.rvs(nu, sigma, size=size, random_state=rng))
2645+
def rng_fn(cls, rng, nu, sigma, size=None) -> np.ndarray:
2646+
return np.asarray(np.abs(stats.t.rvs(nu, sigma, size=size, random_state=rng)))
26472647

26482648

26492649
halfstudentt = HalfStudentTRV()
@@ -2770,8 +2770,8 @@ class ExGaussianRV(RandomVariable):
27702770
_print_name = ("ExGaussian", "\\operatorname{ExGaussian}")
27712771

27722772
@classmethod
2773-
def rng_fn(cls, rng, mu, sigma, nu, size=None):
2774-
return rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size)
2773+
def rng_fn(cls, rng, mu, sigma, nu, size=None) -> np.ndarray:
2774+
return np.asarray(rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size))
27752775

27762776

27772777
exgaussian = ExGaussianRV()
@@ -3007,8 +3007,10 @@ class SkewNormalRV(RandomVariable):
30073007
_print_name = ("SkewNormal", "\\operatorname{SkewNormal}")
30083008

30093009
@classmethod
3010-
def rng_fn(cls, rng, mu, sigma, alpha, size=None):
3011-
return stats.skewnorm.rvs(a=alpha, loc=mu, scale=sigma, size=size, random_state=rng)
3010+
def rng_fn(cls, rng, mu, sigma, alpha, size=None) -> np.ndarray:
3011+
return np.asarray(
3012+
stats.skewnorm.rvs(a=alpha, loc=mu, scale=sigma, size=size, random_state=rng)
3013+
)
30123014

30133015

30143016
skewnormal = SkewNormalRV()
@@ -3333,8 +3335,8 @@ class RiceRV(RandomVariable):
33333335
_print_name = ("Rice", "\\operatorname{Rice}")
33343336

33353337
@classmethod
3336-
def rng_fn(cls, rng, b, sigma, size=None):
3337-
return stats.rice.rvs(b=b, scale=sigma, size=size, random_state=rng)
3338+
def rng_fn(cls, rng, b, sigma, size=None) -> np.ndarray:
3339+
return np.asarray(stats.rice.rvs(b=b, scale=sigma, size=size, random_state=rng))
33383340

33393341

33403342
rice = RiceRV()
@@ -3560,8 +3562,8 @@ class LogitNormalRV(RandomVariable):
35603562
_print_name = ("logitNormal", "\\operatorname{logitNormal}")
35613563

35623564
@classmethod
3563-
def rng_fn(cls, rng, mu, sigma, size=None):
3564-
return expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng))
3565+
def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray:
3566+
return np.asarray(expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)))
35653567

35663568

35673569
logit_normal = LogitNormalRV()
@@ -3684,9 +3686,9 @@ class InterpolatedRV(RandomVariable):
36843686
_print_name = ("Interpolated", "\\operatorname{Interpolated}")
36853687

36863688
@classmethod
3687-
def rng_fn(cls, rng, x, pdf, cdf, size=None):
3689+
def rng_fn(cls, rng, x, pdf, cdf, size=None) -> np.ndarray:
36883690
p = rng.uniform(size=size)
3689-
return _interpolated_argcdf(p, pdf, cdf, x)
3691+
return np.asarray(_interpolated_argcdf(p, pdf, cdf, x))
36903692

36913693

36923694
interpolated = InterpolatedRV()
@@ -3821,8 +3823,8 @@ class MoyalRV(RandomVariable):
38213823
_print_name = ("Moyal", "\\operatorname{Moyal}")
38223824

38233825
@classmethod
3824-
def rng_fn(cls, rng, mu, sigma, size=None):
3825-
return stats.moyal.rvs(mu, sigma, size=size, random_state=rng)
3826+
def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray:
3827+
return np.asarray(stats.moyal.rvs(mu, sigma, size=size, random_state=rng))
38263828

38273829

38283830
moyal = MoyalRV()
@@ -3948,7 +3950,7 @@ def __call__(self, h=1.0, z=0.0, size=None, **kwargs):
39483950
return super().__call__(h, z, size=size, **kwargs)
39493951

39503952
@classmethod
3951-
def rng_fn(cls, rng, h, z, size=None):
3953+
def rng_fn(cls, rng, h, z, size=None) -> np.ndarray:
39523954
"""
39533955
Generate a random sample from the distribution with the given parameters
39543956
@@ -3976,7 +3978,9 @@ def rng_fn(cls, rng, h, z, size=None):
39763978
"""
39773979
# handle the kind of rng passed to the sampler
39783980
bg = rng._bit_generator if isinstance(rng, np.random.RandomState) else rng
3979-
return random_polyagamma(h, z, size=size, random_state=bg).astype(aesara.config.floatX)
3981+
return np.asarray(
3982+
random_polyagamma(h, z, size=size, random_state=bg).astype(aesara.config.floatX)
3983+
)
39803984

39813985

39823986
polyagamma = PolyaGammaRV()

0 commit comments

Comments
 (0)