|
1 |
| -from .autocorrplot import autocorrplot |
2 |
| -from .compareplot import compareplot |
3 |
| -from .forestplot import forestplot |
4 |
| -from .kdeplot import kdeplot |
5 |
| -from .posteriorplot import plot_posterior, plot_posterior_predictive_glm |
6 |
| -from .traceplot import traceplot |
7 |
| -from .energyplot import energyplot |
8 |
| -from .densityplot import densityplot |
9 |
| -from .pairplot import pairplot |
| 1 | +"""PyMC3 Plotting. |
| 2 | +
|
| 3 | +Plots are delegated to the ArviZ library, a general purpose library for |
| 4 | +"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/ |
| 5 | +for details on plots. |
| 6 | +""" |
| 7 | +import functools |
| 8 | +import sys |
| 9 | +import warnings |
| 10 | +try: |
| 11 | + import arviz as az |
| 12 | +except ImportError: # arviz is optional, throw exception when used |
| 13 | + |
| 14 | + class _ImportWarner: |
| 15 | + __all__ = [] |
| 16 | + |
| 17 | + def __init__(self, attr): |
| 18 | + self.attr = attr |
| 19 | + |
| 20 | + def __call__(self, *args, **kwargs): |
| 21 | + raise ImportError( |
| 22 | + "ArviZ is not installed. In order to use `{0.attr}`:\npip install arviz".format(self) |
| 23 | + ) |
| 24 | + |
| 25 | + class _ArviZ: |
| 26 | + def __getattr__(self, attr): |
| 27 | + return _ImportWarner(attr) |
| 28 | + |
| 29 | + |
| 30 | + az = _ArviZ() |
| 31 | + |
| 32 | +def map_args(func): |
| 33 | + swaps = [ |
| 34 | + ('varnames', 'var_names') |
| 35 | + ] |
| 36 | + @functools.wraps(func) |
| 37 | + def wrapped(*args, **kwargs): |
| 38 | + for (old, new) in swaps: |
| 39 | + if old in kwargs and new not in kwargs: |
| 40 | + warnings.warn('Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8'.format(old=old, new=new)) |
| 41 | + kwargs[new] = kwargs.pop(old) |
| 42 | + return func(*args, **kwargs) |
| 43 | + return wrapped |
| 44 | + |
| 45 | +# pymc3 custom plots: override these names for custom behavior |
| 46 | +autocorrplot = map_args(az.plot_autocorr) |
| 47 | +compareplot = map_args(az.plot_compare) |
| 48 | +forestplot = map_args(az.plot_forest) |
| 49 | +kdeplot = map_args(az.plot_kde) |
| 50 | +plot_posterior = map_args(az.plot_posterior) |
| 51 | +traceplot = map_args(az.plot_trace) |
| 52 | +energyplot = map_args(az.plot_energy) |
| 53 | +densityplot = map_args(az.plot_density) |
| 54 | +pairplot = map_args(az.plot_pair) |
| 55 | + |
| 56 | +from .posteriorplot import plot_posterior_predictive_glm |
| 57 | + |
| 58 | + |
| 59 | +# Access to arviz plots: base plots provided by arviz |
| 60 | +for plot in az.plots.__all__: |
| 61 | + setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot))) |
| 62 | + |
| 63 | +__all__ = tuple(az.plots.__all__) + ( |
| 64 | + 'autocorrplot', |
| 65 | + 'compareplot', |
| 66 | + 'forestplot', |
| 67 | + 'kdeplot', |
| 68 | + 'plot_posterior', |
| 69 | + 'traceplot', |
| 70 | + 'energyplot', |
| 71 | + 'densityplot', |
| 72 | + 'pairplot', |
| 73 | + 'plot_posterior_predictive_glm', |
| 74 | +) |
0 commit comments