Skip to content

Commit 98b2577

Browse files
Add a Theano Polya-Gamma random variable type
1 parent e5dc51b commit 98b2577

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python:
1010
# - "pypy3"
1111

1212
install:
13+
- pip install Cython
1314
- pip install -r requirements.txt
1415

1516
script:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ setuptools>=45.2.0
22
six>=1.14.0
33
-e ./
44
sympy>=1.3
5+
pypolyagamma @ git+https://github.com/slinderman/pypolyagamma.git@b5883e661123862ca07d29ab14369fae85bdbc27#egg=pypolyagamma-1.2.2
56
coveralls
67
pydocstyle>=3.0.0
78
pytest>=5.0.0

symbolic_pymc/theano/random_variables.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from functools import partial
77

8+
from pypolyagamma import PyPolyaGamma
9+
810
from .ops import RandomVariable, param_supp_shape_fn
911

1012

@@ -204,7 +206,7 @@ def __init__(self):
204206
"invgamma",
205207
theano.config.floatX,
206208
0,
207-
[0, 0, 0],
209+
[0, 0],
208210
lambda rng, shape, rate, size: scipy.stats.invgamma.rvs(
209211
shape, scale=rate, size=size, random_state=rng
210212
),
@@ -343,6 +345,46 @@ def make_node(self, pvals, size=None, rng=None, name=None):
343345
CategoricalRV = CategoricalRVType()
344346

345347

348+
class PolyaGammaRVType(RandomVariable):
349+
"""Polya-Gamma random variable.
350+
351+
XXX: This doesn't really use the given RNG, due to the narrowness of the
352+
sampler package's implementation.
353+
"""
354+
355+
print_name = ("PG", "\\operatorname{PG}")
356+
357+
def __init__(self):
358+
super().__init__(
359+
"polya-gamma", theano.config.floatX, 0, [0, 0], self._smpl_fn, inplace=True,
360+
)
361+
362+
def make_node(self, b, c, size=None, rng=None, name=None):
363+
return super().make_node(b, c, size=size, rng=rng, name=name)
364+
365+
@classmethod
366+
def _smpl_fn(cls, rng, b, c, size):
367+
pg = PyPolyaGamma(rng.randint(2 ** 16))
368+
369+
if not size and b.shape == c.shape == ():
370+
return pg.pgdraw(b, c)
371+
else:
372+
b, c = np.broadcast_arrays(b, c)
373+
out_shape = b.shape + tuple(size or ())
374+
smpl_val = np.empty(out_shape, dtype="double")
375+
b = np.tile(b, tuple(size or ()) + (1,))
376+
c = np.tile(c, tuple(size or ()) + (1,))
377+
pg.pgdrawv(
378+
np.asarray(b.flat).astype("double", copy=True),
379+
np.asarray(c.flat).astype("double", copy=True),
380+
np.asarray(smpl_val.flat),
381+
)
382+
return smpl_val
383+
384+
385+
PolyaGammaRV = PolyaGammaRVType()
386+
387+
346388
class Observed(tt.Op):
347389
"""An `Op` that represents an observed random variable.
348390

tests/theano/test_rv.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import theano.tensor as tt
44

5-
from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV
5+
from pytest import importorskip
6+
7+
from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, PolyaGammaRV
68

79

810
def rv_numpy_tester(rv, *params, size=None):
@@ -68,3 +70,31 @@ def test_mvnormalrv():
6870
# Looks like NumPy doesn't support that (and it's probably better off for
6971
# it).
7072
# rv_numpy_tester(MvNormalRV, [[0, 1, 2], [4, 5, 6]], np.diag([1, 1, 1]))
73+
74+
75+
def test_polyagammarv():
76+
77+
_ = importorskip("pypolyagamma")
78+
79+
# Sampled values should be scalars
80+
pg_rv = PolyaGammaRV(1.1, -10.5)
81+
assert pg_rv.eval().shape == ()
82+
83+
pg_rv = PolyaGammaRV(1.1, -10.5, size=[1])
84+
assert pg_rv.eval().shape == (1,)
85+
86+
pg_rv = PolyaGammaRV(1.1, -10.5, size=[2, 3])
87+
bcast_smpl = pg_rv.eval()
88+
assert bcast_smpl.shape == (2, 3)
89+
# Make sure they're not all equal
90+
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
91+
92+
pg_rv = PolyaGammaRV(np.r_[1.1, 3], -10.5)
93+
bcast_smpl = pg_rv.eval()
94+
assert bcast_smpl.shape == (2,)
95+
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
96+
97+
pg_rv = PolyaGammaRV(np.r_[1.1, 3], -10.5, size=(2, 3))
98+
bcast_smpl = pg_rv.eval()
99+
assert bcast_smpl.shape == (2, 2, 3)
100+
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)

0 commit comments

Comments
 (0)