Skip to content

Commit 9eb69fc

Browse files
Spaak and MarcoGorelli authored
adding meaningful str representations to PyMC3 objects (#4076)
* adding unit tests for new __str__ functionality
* use get_var_name instead of str to get variable name
* adding semantically meaningful __str__
* correcting import syntax error
* patch type of Deterministic to ensure proper handling of str()
* adding test for tuning.starting.allinmodel
* more precise exception checks in unit test

Co-authored-by: Marco Gorelli <[email protected]>
Co-authored-by: Marco Gorelli <[email protected]>
1 parent 07b584a commit 9eb69fc

14 files changed

104 additions (+104) and 26 deletions (-26)

pymc3/backends/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from ..model import modelcontext, Model
3030
from .report import SamplerReport, merge_reports
31+
from ..util import get_var_name
3132

3233
logger = logging.getLogger('pymc3')
3334

@@ -109,7 +110,7 @@ def _set_sampler_vars(self, sampler_vars):
109110
self.sampler_vars = sampler_vars
110111

111112
# pylint: disable=unused-argument
112-
def setup(self, draws, chain, sampler_vars=None) -> None:
113+
def setup(self, draws, chain, sampler_vars=None) -> None:
113114
"""Perform chain-specific setup.
114115
115116
Parameters
@@ -335,7 +336,7 @@ def __getitem__(self, idx):
335336
var = idx
336337
burn, thin = 0, 1
337338

338-
var = str(var)
339+
var = get_var_name(var)
339340
if var in self.varnames:
340341
if var in self.stat_names:
341342
warnings.warn("Attribute access on a trace object is ambigous. "
@@ -355,7 +356,7 @@ def __getattr__(self, name):
355356
if name in self._attrs:
356357
raise AttributeError
357358

358-
name = str(name)
359+
name = get_var_name(name)
359360
if name in self.varnames:
360361
if name in self.stat_names:
361362
warnings.warn("Attribute access on a trace object is ambigous. "
@@ -482,7 +483,7 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
482483
"""
483484
if chains is None:
484485
chains = self.chains
485-
varname = str(varname)
486+
varname = get_var_name(varname)
486487
try:
487488
results = [self._straces[chain].get_values(varname, burn, thin)
488489
for chain in chains]

pymc3/backends/sqlite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from ..backends import base, ndarray
3838
from . import tracetab as ttab
39+
from ..util import get_var_name
3940

4041
TEMPLATES = {
4142
'table': ('CREATE TABLE IF NOT EXISTS [{table}] '
@@ -244,7 +245,7 @@ def get_values(self, varname, burn=0, thin=1):
244245
if thin < 1:
245246
raise ValueError('Only positive thin values are supported '
246247
'in SQLite backend.')
247-
varname = str(varname)
248+
varname = get_var_name(varname)
248249

249250
statement_args = {'chain': self.chain}
250251
if burn == 0 and thin == 1:

pymc3/blocking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import numpy as np
2222
import collections
2323

24+
from .util import get_var_name
25+
2426
__all__ = ['ArrayOrdering', 'DictToArrayBijection', 'DictToVarBijection']
2527

2628
VarMap = collections.namedtuple('VarMap', 'var, slc, shp, dtyp')
@@ -237,7 +239,7 @@ class DictToVarBijection:
237239
"""
238240

239241
def __init__(self, var, idx, dpoint):
240-
self.var = str(var)
242+
self.var = get_var_name(var)
241243
self.idx = idx
242244
self.dpt = dpoint
243245

pymc3/distributions/distribution.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424
import theano.tensor as tt
2525
from theano import function
26-
from ..util import get_repr_for_variable
26+
from ..util import get_repr_for_variable, get_var_name
2727
import theano
2828
from ..memoize import memoize
2929
from ..model import (
@@ -174,6 +174,9 @@ def _str_repr(self, name=None, dist=None, formatting='plain'):
174174
return "{var_name} ~ {distr_name}({params})".format(var_name=name,
175175
distr_name=dist._distr_name_for_repr(), params=param_string)
176176

177+
def __str__(self, **kwargs):
178+
return self._str_repr(formatting="plain", **kwargs)
179+
177180
def _repr_latex_(self, **kwargs):
178181
"""Magic method name for IPython to use for LaTeX formatting."""
179182
return self._str_repr(formatting="latex", **kwargs)
@@ -728,7 +731,7 @@ def draw_values(params, point=None, size=None):
728731
missing_inputs = set([j for j, p in symbolic_params])
729732
while to_eval or missing_inputs:
730733
if to_eval == missing_inputs:
731-
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
734+
raise ValueError('Cannot resolve inputs for {}'.format([get_var_name(params[j]) for j in to_eval]))
732735
to_eval = set(missing_inputs)
733736
missing_inputs = set()
734737
for param_idx in to_eval:

pymc3/distributions/posterior_predictive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from ..exceptions import IncorrectArgumentsError
4444
from ..vartypes import theano_constant
45-
from ..util import dataset_to_point_dict, chains_and_samples
45+
from ..util import dataset_to_point_dict, chains_and_samples, get_var_name
4646

4747
# Failing tests:
4848
# test_mixture_random_shape::test_mixture_random_shape
@@ -460,7 +460,7 @@ def draw_values(self) -> List[np.ndarray]:
460460
if to_eval == missing_inputs:
461461
raise ValueError(
462462
"Cannot resolve inputs for {}".format(
463-
[str(trace.varnames[j]) for j in to_eval]
463+
[get_var_name(trace.varnames[j]) for j in to_eval]
464464
)
465465
)
466466
to_eval = set(missing_inputs)
@@ -493,7 +493,7 @@ def draw_values(self) -> List[np.ndarray]:
493493
return [self.evaluated[j] for j in params]
494494

495495
def init(self) -> None:
496-
"""This method carries out the initialization phase of sampling
496+
"""This method carries out the initialization phase of sampling
497497
from the posterior predictive distribution. Notably it initializes the
498498
``_DrawValuesContext`` bookkeeping object and evaluates the "fast drawable"
499499
parts of the model."""
@@ -567,7 +567,7 @@ def draw_value(self, param, trace: Optional[_TraceDict] = None, givens=None):
567567
The value or distribution. Constants or shared variables
568568
will be converted to an array and returned. Theano variables
569569
are evaluated. If `param` is a pymc3 random variable, draw
570-
values from it and return that (as ``np.ndarray``), unless a
570+
values from it and return that (as ``np.ndarray``), unless a
571571
value is specified in the ``trace``.
572572
trace: pm.MultiTrace, optional
573573
A dictionary from pymc3 variable names to samples of their values

pymc3/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .theanof import gradient, hessian, inputvars, generator
3737
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
3838
from .blocking import DictToArrayBijection, ArrayOrdering
39-
from .util import get_transformed_name
39+
from .util import get_transformed_name, get_var_name
4040
from .exceptions import ImputationWarning
4141

4242
__all__ = [
@@ -80,6 +80,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
8080
def _repr_latex_(self, **kwargs):
8181
return self._str_repr(formatting="latex", **kwargs)
8282

83+
def __str__(self, **kwargs):
84+
return self._str_repr(formatting="plain", **kwargs)
85+
8386
__latex__ = _repr_latex_
8487

8588

@@ -1365,6 +1368,9 @@ def _str_repr(self, formatting="plain", **kwargs):
13651368
for n, d in zip(names, distrs)]
13661369
return "\n".join(rv_reprs)
13671370

1371+
def __str__(self, **kwargs):
1372+
return self._str_repr(formatting="plain", **kwargs)
1373+
13681374
def _repr_latex_(self, **kwargs):
13691375
return self._str_repr(formatting="latex", **kwargs)
13701376

@@ -1477,7 +1483,8 @@ def Point(*args, **kwargs):
14771483
except Exception as e:
14781484
raise TypeError("can't turn {} and {} into a dict. {}".format(args, kwargs, e))
14791485
return dict(
1480-
(str(k), np.array(v)) for k, v in d.items() if str(k) in map(str, model.vars)
1486+
(get_var_name(k), np.array(v)) for k, v in d.items()
1487+
if get_var_name(k) in map(get_var_name, model.vars)
14811488
)
14821489

14831490

@@ -1869,6 +1876,14 @@ def Deterministic(name, var, model=None, dims=None):
18691876
model.add_random_variable(var, dims)
18701877
var._repr_latex_ = functools.partial(_repr_deterministic_rv, var, formatting='latex')
18711878
var.__latex__ = var._repr_latex_
1879+
1880+
# simply assigning var.__str__ is not enough, since str() will default to the class-
1881+
# defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028
1882+
old_type = type(var)
1883+
new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,),
1884+
{'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')})
1885+
var.__class__ = new_type
1886+
18721887
return var
18731888

18741889

pymc3/model_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from theano.compile import SharedVariable
2222
from theano.tensor import Tensor
2323

24-
from .util import get_default_varnames
24+
from .util import get_default_varnames, get_var_name
2525
from .model import ObservedRV
2626
import pymc3 as pm
2727

@@ -83,7 +83,7 @@ def _filter_parents(self, var, parents) -> Set[VarName]:
8383
if self.transform_map[p] != var.name:
8484
keep.add(self.transform_map[p])
8585
else:
86-
raise AssertionError('Do not know what to do with {}'.format(str(p)))
86+
raise AssertionError('Do not know what to do with {}'.format(get_var_name(p)))
8787
return keep
8888

8989
def get_parents(self, var: Tensor) -> Set[VarName]:

pymc3/step_methods/arraystep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ..model import modelcontext
1717
from ..theanof import inputvars
1818
from ..blocking import ArrayOrdering, DictToArrayBijection
19+
from ..util import get_var_name
1920
import numpy as np
2021
from numpy.random import uniform
2122
from enum import IntEnum, unique
@@ -175,7 +176,7 @@ def __init__(self, vars, shared, blocked=True):
175176
"""
176177
self.vars = vars
177178
self.ordering = ArrayOrdering(vars)
178-
self.shared = {str(var): shared for var, shared in shared.items()}
179+
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
179180
self.blocked = blocked
180181
self.bij = None
181182

pymc3/tests/sampler_fixtures.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pymc3 as pm
16+
from pymc3.util import get_var_name
1617
import numpy as np
1718
import numpy.testing as npt
1819
from scipy import stats
@@ -145,7 +146,7 @@ def setup_class(cls):
145146
cls.trace = pm.sample(cls.n_samples, tune=cls.tune, step=cls.step, cores=cls.chains)
146147
cls.samples = {}
147148
for var in cls.model.unobserved_RVs:
148-
cls.samples[str(var)] = cls.trace.get_values(var, burn=cls.burn)
149+
cls.samples[get_var_name(var)] = cls.trace.get_values(var, burn=cls.burn)
149150

150151
def test_neff(self):
151152
if hasattr(self, 'min_n_eff'):

pymc3/tests/test_distributions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,7 +1771,7 @@ def test_bound():
17711771
BoundPoissonPositionalArgs = Bound(Poisson, upper=6)("x", 2.0)
17721772

17731773

1774-
class TestLatex:
1774+
class TestStrAndLatexRepr:
17751775
def setup_class(self):
17761776
# True parameter values
17771777
alpha, sigma = 1, 1
@@ -1800,30 +1800,46 @@ def setup_class(self):
18001800
# Likelihood (sampling distribution) of observations
18011801
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
18021802
self.distributions = [alpha, sigma, mu, b, Z, Y_obs]
1803-
self.expected = (
1803+
self.expected_latex = (
18041804
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
18051805
r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$",
18061806
r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$",
18071807
r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
18081808
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
18091809
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
18101810
)
1811+
self.expected_str = (
1812+
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
1813+
r"sigma ~ HalfNormal(sigma=1.0)",
1814+
r"mu ~ Deterministic(alpha, Constant, beta)",
1815+
r"beta ~ Normal(mu=0.0, sigma=10.0)",
1816+
r"Z ~ MvNormal(mu=array, chol_cov=array)",
1817+
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
1818+
)
18111819

18121820
def test__repr_latex_(self):
1813-
for distribution, tex in zip(self.distributions, self.expected):
1821+
for distribution, tex in zip(self.distributions, self.expected_latex):
18141822
assert distribution._repr_latex_() == tex
18151823

18161824
model_tex = self.model._repr_latex_()
18171825

1818-
for tex in self.expected: # make sure each variable is in the model
1826+
for tex in self.expected_latex: # make sure each variable is in the model
18191827
for segment in tex.strip("$").split(r"\sim"):
18201828
assert segment in model_tex
18211829

18221830
def test___latex__(self):
1823-
for distribution, tex in zip(self.distributions, self.expected):
1831+
for distribution, tex in zip(self.distributions, self.expected_latex):
18241832
assert distribution._repr_latex_() == distribution.__latex__()
18251833
assert self.model._repr_latex_() == self.model.__latex__()
18261834

1835+
def test___str__(self):
1836+
for distribution, str_repr in zip(self.distributions, self.expected_str):
1837+
assert distribution.__str__() == str_repr
1838+
1839+
model_str = self.model.__str__()
1840+
for str_repr in self.expected_str:
1841+
assert str_repr in model_str
1842+
18271843

18281844
def test_discrete_trafo():
18291845
with pytest.raises(ValueError) as err:

pymc3/tests/test_starting.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .models import simple_model, non_normal, simple_arbitrary_det
2020
from .helpers import select_by_precision
2121

22+
from pytest import raises
23+
2224

2325
def test_accuracy_normal():
2426
_, model, (mu, _) = simple_model()
@@ -83,3 +85,23 @@ def test_find_MAP():
8385

8486
close_to(map_est2['mu'], 0, tol)
8587
close_to(map_est2['sigma'], 1, tol)
88+
89+
90+
def test_allinmodel():
91+
model1 = Model()
92+
model2 = Model()
93+
with model1:
94+
x1 = Normal('x1', mu=0, sigma=1)
95+
y1 = Normal('y1', mu=0, sigma=1)
96+
with model2:
97+
x2 = Normal('x2', mu=0, sigma=1)
98+
y2 = Normal('y2', mu=0, sigma=1)
99+
100+
starting.allinmodel([x1, y1], model1)
101+
starting.allinmodel([x1], model1)
102+
with raises(ValueError, match=r"Some variables not in the model: \['x2', 'y2'\]"):
103+
starting.allinmodel([x2, y2], model1)
104+
with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
105+
starting.allinmodel([x2, y1], model1)
106+
with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
107+
starting.allinmodel([x2], model1)

pymc3/tuning/scaling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ..model import modelcontext, Point
1818
from ..theanof import hessian_diag, inputvars
1919
from ..blocking import DictToArrayBijection, ArrayOrdering
20+
from ..util import get_var_name
2021

2122
__all__ = ['find_hessian', 'trace_cov', 'guess_scaling']
2223

@@ -135,7 +136,7 @@ def trace_cov(trace, vars=None, model=None):
135136
vars = trace.varnames
136137

137138
def flat_t(var):
138-
x = trace[str(var)]
139+
x = trace[get_var_name(var)]
139140
return x.reshape((x.shape[0], np.prod(x.shape[1:], dtype=int)))
140141

141142
return np.cov(np.concatenate(list(map(flat_t, vars)), 1).T)

pymc3/tuning/starting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..theanof import inputvars
2929
import theano.gradient as tg
3030
from ..blocking import DictToArrayBijection, ArrayOrdering
31-
from ..util import update_start_vals, get_default_varnames
31+
from ..util import update_start_vals, get_default_varnames, get_var_name
3232

3333
import warnings
3434
from inspect import getargspec
@@ -196,6 +196,7 @@ def nan_to_high(x):
196196
def allinmodel(vars, model):
197197
notin = [v for v in vars if v not in model.vars]
198198
if notin:
199+
notin = list(map(get_var_name, notin))
199200
raise ValueError("Some variables not in the model: " + str(notin))
200201

201202

0 commit comments

Comments
 (0)