Skip to content

Commit 02095c2

Browse files
Fix pymc3 to work with latest theano-pymc master (#4382)
* Update Theano-PyMC version * Update JAX imports * Update RandomStream imports and arguments * Update use of theano.config.change_flags in pymc3.tests.test_variational_inference Co-authored-by: Brandon T. Willard <[email protected]>
1 parent 91993d8 commit 02095c2

File tree

6 files changed

+37
-31
lines changed

6 files changed

+37
-31
lines changed

pymc3/sampling_jax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import numpy as np
1313
import pandas as pd
1414
import theano
15-
import theano.sandbox.jax_linker
16-
import theano.sandbox.jaxify
15+
16+
from theano.link.jax.jax_dispatch import jax_funcify
1717

1818
import pymc3 as pm
1919

@@ -46,7 +46,7 @@ def sample_tfp_nuts(
4646
seed = jax.random.PRNGKey(random_seed)
4747

4848
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
49-
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
49+
fns = jax_funcify(fgraph)
5050
logp_fn_jax = fns[0]
5151

5252
rv_names = [rv.name for rv in model.free_RVs]
@@ -131,7 +131,7 @@ def sample_numpyro_nuts(
131131
seed = jax.random.PRNGKey(random_seed)
132132

133133
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
134-
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
134+
fns = jax_funcify(fgraph)
135135
logp_fn_jax = fns[0]
136136

137137
rv_names = [rv.name for rv in model.free_RVs]

pymc3/tests/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import theano
2121

2222
from theano.gradient import verify_grad as tt_verify_grad
23-
from theano.sandbox.rng_mrg import MRG_RandomStreams
23+
from theano.sandbox.rng_mrg import MRG_RandomStream as RandomStream
2424

2525
from pymc3.theanof import set_tt_rng, tt_rng
2626

@@ -35,7 +35,7 @@ def setup_class(cls):
3535
def setup_method(self):
3636
nr.seed(self.random_seed)
3737
self.old_tt_rng = tt_rng()
38-
set_tt_rng(MRG_RandomStreams(self.random_seed))
38+
set_tt_rng(RandomStream(self.random_seed))
3939

4040
def teardown_method(self):
4141
set_tt_rng(self.old_tt_rng)

pymc3/tests/test_variational_inference.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pymc3.tests import models
2929
from pymc3.tests.helpers import not_raises
30-
from pymc3.theanof import change_flags, intX
30+
from pymc3.theanof import intX
3131
from pymc3.variational import flows, opvi
3232
from pymc3.variational.approximations import (
3333
Empirical,
@@ -533,17 +533,20 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
533533
sigma = 1.0
534534
y_obs = np.array([1.6, 1.4])
535535
beta = len(y_obs) / float(aux_total_size)
536-
post_mu = np.array([1.88], dtype=theano.config.floatX)
537-
post_sigma = np.array([1], dtype=theano.config.floatX)
538536

539537
# TODO: theano_config
540538
# with pm.Model(theano_config=dict(floatX='float64')):
541539
# did not not work as expected
542540
# there were some numeric problems, so float64 is forced
543-
with pm.theanof.change_flags(floatX="float64", warn_float64="ignore"):
541+
with theano.config.change_flags(floatX="float64", warn_float64="ignore"):
542+
543+
assert theano.config.floatX == "float64"
544+
assert theano.config.warn_float64 == "ignore"
545+
546+
post_mu = np.array([1.88], dtype=theano.config.floatX)
547+
post_sigma = np.array([1], dtype=theano.config.floatX)
548+
544549
with pm.Model():
545-
assert theano.config.floatX == "float64"
546-
assert theano.config.warn_float64 == "ignore"
547550
mu = pm.Normal("mu", mu=mu0, sigma=sigma)
548551
pm.Normal("y", mu=mu, sigma=1, observed=y_obs, total_size=aux_total_size)
549552
# Create variational gradient tensor
@@ -552,7 +555,7 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
552555
mean_field_1.shared_params["mu"].set_value(post_mu)
553556
mean_field_1.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
554557

555-
with pm.theanof.change_flags(compute_test_value="off"):
558+
with theano.config.change_flags(compute_test_value="off"):
556559
elbo_via_total_size_scaled = -pm.operators.KL(mean_field_1)()(10000)
557560

558561
with pm.Model():
@@ -566,7 +569,7 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
566569
mean_field_2.shared_params["mu"].set_value(post_mu)
567570
mean_field_2.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
568571

569-
with pm.theanof.change_flags(compute_test_value="off"):
572+
with theano.config.change_flags(compute_test_value="off"):
570573
elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000)
571574

572575
np.testing.assert_allclose(
@@ -583,9 +586,12 @@ def test_elbo_beta_kl(aux_total_size):
583586
sigma = 1.0
584587
y_obs = np.array([1.6, 1.4])
585588
beta = len(y_obs) / float(aux_total_size)
586-
post_mu = np.array([1.88], dtype=theano.config.floatX)
587-
post_sigma = np.array([1], dtype=theano.config.floatX)
588-
with pm.theanof.change_flags(floatX="float64", warn_float64="ignore"):
589+
590+
with theano.config.change_flags(floatX="float64", warn_float64="ignore"):
591+
592+
post_mu = np.array([1.88], dtype=theano.config.floatX)
593+
post_sigma = np.array([1], dtype=theano.config.floatX)
594+
589595
with pm.Model():
590596
mu = pm.Normal("mu", mu=mu0, sigma=sigma)
591597
pm.Normal("y", mu=mu, sigma=1, observed=y_obs, total_size=aux_total_size)
@@ -595,7 +601,7 @@ def test_elbo_beta_kl(aux_total_size):
595601
mean_field_1.shared_params["mu"].set_value(post_mu)
596602
mean_field_1.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
597603

598-
with pm.theanof.change_flags(compute_test_value="off"):
604+
with theano.config.change_flags(compute_test_value="off"):
599605
elbo_via_total_size_scaled = -pm.operators.KL(mean_field_1)()(10000)
600606

601607
with pm.Model():
@@ -606,7 +612,7 @@ def test_elbo_beta_kl(aux_total_size):
606612
mean_field_3.shared_params["mu"].set_value(post_mu)
607613
mean_field_3.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
608614

609-
with pm.theanof.change_flags(compute_test_value="off"):
615+
with theano.config.change_flags(compute_test_value="off"):
610616
elbo_via_beta_kl = -pm.operators.KL(mean_field_3, beta=beta)()(10000)
611617

612618
np.testing.assert_allclose(
@@ -1014,7 +1020,7 @@ def init_(**kw):
10141020
def test_flow_det(flow_spec):
10151021
z0 = tt.arange(0, 20).astype("float32")
10161022
flow = flow_spec(dim=20, z0=z0.dimshuffle("x", 0))
1017-
with change_flags(compute_test_value="off"):
1023+
with theano.config.change_flags(compute_test_value="off"):
10181024
z1 = flow.forward.flatten()
10191025
J = tt.jacobian(z1, z0)
10201026
logJdet = tt.log(tt.abs_(tt.nlinalg.det(J)))
@@ -1030,7 +1036,7 @@ def test_flow_det_local(flow_spec):
10301036
params[k] = np.random.randn(1, *shp).astype("float32")
10311037
flow = flow_spec(dim=12, z0=z0.reshape((1, 1, 12)), **params)
10321038
assert flow.batched
1033-
with change_flags(compute_test_value="off"):
1039+
with theano.config.change_flags(compute_test_value="off"):
10341040
z1 = flow.forward.flatten()
10351041
J = tt.jacobian(z1, z0)
10361042
logJdet = tt.log(tt.abs_(tt.nlinalg.det(J)))

pymc3/theanof.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from theano import tensor as tt
2020
from theano.gof import Op
2121
from theano.gof.graph import inputs
22-
from theano.sandbox.rng_mrg import MRG_RandomStreams
22+
from theano.sandbox.rng_mrg import MRG_RandomStream as RandomStream
2323

2424
from pymc3.blocking import ArrayOrdering
2525
from pymc3.data import GeneratorAdapter
@@ -394,7 +394,7 @@ def generator(gen, default=None):
394394
return GeneratorOp(gen, default)()
395395

396396

397-
_tt_rng = MRG_RandomStreams()
397+
_tt_rng = RandomStream()
398398

399399

400400
def tt_rng(random_seed=None):
@@ -409,14 +409,14 @@ def tt_rng(random_seed=None):
409409
410410
Returns
411411
-------
412-
`theano.sandbox.rng_mrg.MRG_RandomStreams` instance
413-
`theano.sandbox.rng_mrg.MRG_RandomStreams`
412+
`theano.tensor.random.utils.RandomStream` instance
413+
`theano.tensor.random.utils.RandomStream`
414414
instance passed to the most recent call of `set_tt_rng`
415415
"""
416416
if random_seed is None:
417417
return _tt_rng
418418
else:
419-
ret = MRG_RandomStreams(random_seed)
419+
ret = RandomStream(random_seed)
420420
return ret
421421

422422

@@ -426,14 +426,14 @@ def set_tt_rng(new_rng):
426426
427427
Parameters
428428
----------
429-
new_rng: `theano.sandbox.rng_mrg.MRG_RandomStreams` instance
429+
new_rng: `theano.tensor.random.utils.RandomStream` instance
430430
The random number generator to use.
431431
"""
432432
# pylint: disable=global-statement
433433
global _tt_rng
434434
# pylint: enable=global-statement
435435
if isinstance(new_rng, int):
436-
new_rng = MRG_RandomStreams(new_rng)
436+
new_rng = RandomStream(new_rng)
437437
_tt_rng = new_rng
438438

439439

pymc3/variational/opvi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,9 +1078,9 @@ def _new_initial(self, size, deterministic, more_replacements=None):
10781078
if deterministic:
10791079
return tt.ones(shape, dtype) * dist_map
10801080
else:
1081-
return getattr(self._rng, dist_name)(shape)
1081+
return getattr(self._rng, dist_name)(size=shape)
10821082
else:
1083-
sample = getattr(self._rng, dist_name)(shape)
1083+
sample = getattr(self._rng, dist_name)(size=shape)
10841084
initial = tt.switch(deterministic, tt.ones(shape, dtype) * dist_map, sample)
10851085
return initial
10861086

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ numpy>=1.13.0
55
pandas>=0.18.0
66
patsy>=0.5.1
77
scipy>=0.18.1
8-
theano-pymc==1.0.12
8+
theano-pymc==1.0.14
99
typing-extensions>=3.7.4

0 commit comments

Comments
 (0)