Skip to content

Commit d8868cc

Browse files
authored
Add inf special cases to gamma.c function (#634)
1 parent 453fb4d commit d8868cc

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

pytensor/scalar/c_code/gamma.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ DEVICE double GammaP (double n, double x)
218218
{ /* --- regularized Gamma function P */
219219
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
220220
if (x <= 0) return 0; /* treat x = 0 as a special case */
221+
if (isinf(n)) {
222+
if (isinf(x)) return NPY_NAN;
223+
return 0;
224+
}
225+
if (isinf(x)) return 1;
221226
if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n));
222227
return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
223228
} /* GammaP() */
@@ -228,6 +233,11 @@ DEVICE double GammaQ (double n, double x)
228233
{ /* --- regularized Gamma function Q */
229234
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
230235
if (x <= 0) return 1; /* treat x = 0 as a special case */
236+
if (isinf(n)) {
237+
if (isinf(x)) return NPY_NAN;
238+
return 1;
239+
}
240+
if (isinf(x)) return 0;
231241
if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n));
232242
return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
233243
} /* GammaQ() */

pytensor/scalar/math.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,13 @@ def __eq__(self, other):
631631
def __hash__(self):
632632
return hash(type(self))
633633

634+
def c_code_cache_version(self):
635+
v = super().c_code_cache_version()
636+
if v:
637+
return (2, *v)
638+
else:
639+
return v
640+
634641

635642
chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf")
636643

@@ -677,6 +684,13 @@ def __eq__(self, other):
677684
def __hash__(self):
678685
return hash(type(self))
679686

687+
def c_code_cache_version(self):
688+
v = super().c_code_cache_version()
689+
if v:
690+
return (2, *v)
691+
else:
692+
return v
693+
680694

681695
gammainc = GammaInc(upgrade_to_float, name="gammainc")
682696

@@ -723,6 +737,13 @@ def __eq__(self, other):
723737
def __hash__(self):
724738
return hash(type(self))
725739

740+
def c_code_cache_version(self):
741+
v = super().c_code_cache_version()
742+
if v:
743+
return (2, *v)
744+
else:
745+
return v
746+
726747

727748
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
728749

tests/scalar/test_math.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ def test_gammainc_nan_c():
4141
assert np.isnan(test_func(-1, -1))
4242

4343

44+
def test_gammainc_inf_c():
45+
x1 = pt.dscalar()
46+
x2 = pt.dscalar()
47+
y = gammainc(x1, x2)
48+
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
49+
assert np.isclose(test_func(np.inf, 1), sp.gammainc(np.inf, 1))
50+
assert np.isclose(test_func(1, np.inf), sp.gammainc(1, np.inf))
51+
assert np.isnan(test_func(np.inf, np.inf))
52+
53+
4454
def test_gammaincc_python():
4555
x1 = pt.dscalar()
4656
x2 = pt.dscalar()
@@ -59,6 +69,16 @@ def test_gammaincc_nan_c():
5969
assert np.isnan(test_func(-1, -1))
6070

6171

72+
def test_gammaincc_inf_c():
73+
x1 = pt.dscalar()
74+
x2 = pt.dscalar()
75+
y = gammaincc(x1, x2)
76+
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
77+
assert np.isclose(test_func(np.inf, 1), sp.gammaincc(np.inf, 1))
78+
assert np.isclose(test_func(1, np.inf), sp.gammaincc(1, np.inf))
79+
assert np.isnan(test_func(np.inf, np.inf))
80+
81+
6282
def test_gammal_nan_c():
6383
x1 = pt.dscalar()
6484
x2 = pt.dscalar()

0 commit comments

Comments
 (0)