Skip to content

Commit 24a2234

Browse files
ricardoV94jessegrabowski
authored andcommitted
Numba dispatch of StudentT
1 parent 7b0a392 commit 24a2234

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ def random(rng, p):
102102
return random
103103

104104

105+
@numba_core_rv_funcify.register(ptr.StudentTRV)
106+
def numba_core_StudentTRV(op, node):
107+
@numba_basic.numba_njit
108+
def random_fn(rng, df, loc, scale):
109+
return loc + scale * rng.standard_t(df)
110+
111+
return random_fn
112+
113+
105114
@numba_core_rv_funcify.register(ptr.HalfNormalRV)
106115
def numba_core_HalfNormalRV(op, node):
107116
@numba_basic.numba_njit

tests/link/numba/test_random.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,23 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
592592
"gumbel_r",
593593
lambda *args: args,
594594
),
595+
(
596+
ptr.t,
597+
[
598+
(pt.scalar(), np.array(np.e, dtype=np.float64)),
599+
(
600+
pt.dvector(),
601+
np.array([1.0, 2.0], dtype=np.float64),
602+
),
603+
(
604+
pt.dscalar(),
605+
np.array(np.pi, dtype=np.float64),
606+
),
607+
],
608+
(2,),
609+
"t",
610+
lambda *args: args,
611+
),
595612
],
596613
)
597614
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):

0 commit comments

Comments
 (0)