Skip to content

Commit c40d692

Browse files
Dhruvanshu-Joshitwiecki
authored andcommitted
Implement Rayleigh distribution in Pytensor
1 parent 82a5757 commit c40d692

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

pytensor/tensor/random/basic.py

+38
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytensor
88
from pytensor.tensor.basic import arange, as_tensor_variable
9+
from pytensor.tensor.math import sqrt
910
from pytensor.tensor.random.op import RandomVariable
1011
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
1112
from pytensor.tensor.random.utils import (
@@ -526,6 +527,43 @@ def chisquare(df, size=None, **kwargs):
526527
return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)
527528

528529

530+
def rayleigh(scale=1.0, *, size=None, **kwargs):
531+
r"""Draw samples from a Rayleigh distribution.
532+
533+
The probability density function for `rayleigh` with parameter `scale` is given by:
534+
535+
.. math::
536+
f(x; s) = \frac{x}{s^2} e^{-x^2/(2 s^2)}
537+
538+
where :math:`s` is the scale parameter.
539+
540+
This variable is obtained by taking the square root of the sum of the squares of
541+
two independent, standard normally distributed random variables.
542+
543+
Signature
544+
---------
545+
`() -> ()`
546+
547+
Parameters
548+
----------
549+
scale : float or array_like of floats, optional
550+
Scale parameter of the distribution (positive). Default is 1.0.
551+
size : int or tuple of ints, optional
552+
Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k` samples
553+
are drawn. Default is None, in which case the output shape is determined by the
554+
shape of `scale`.
555+
556+
Notes
557+
-----
558+
`Rayleigh` is a special case of `chisquare` with ``df=2``.
559+
"""
560+
561+
scale = as_tensor_variable(scale)
562+
if size is None:
563+
size = scale.shape
564+
return sqrt(chisquare(df=2, size=size, **kwargs)) * scale
565+
566+
529567
class ParetoRV(ScipyRandomVariable):
530568
r"""A pareto continuous random variable.
531569

tests/tensor/random/test_basic.py

+15
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
permutation,
5151
poisson,
5252
randint,
53+
rayleigh,
5354
standard_normal,
5455
t,
5556
triangular,
@@ -390,6 +391,20 @@ def test_chisquare_samples(df, size):
390391
compare_sample_values(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2"))
391392

392393

394+
@pytest.mark.parametrize(
395+
"scale, size",
396+
[
397+
(1, None),
398+
(2, []),
399+
(4, 100),
400+
],
401+
)
402+
def test_rayleigh_samples(scale, size):
403+
compare_sample_values(
404+
rayleigh, scale=scale, size=size, test_fn=fixed_scipy_rvs("rayleigh")
405+
)
406+
407+
393408
@pytest.mark.parametrize(
394409
"mu, beta, size",
395410
[

0 commit comments

Comments
 (0)