(weibull_aft)=
:::{post} January 17, 2023
:tags: censored, survival analysis, weibull
:category: intermediate, how-to
:author: Junpeng Lao, George Ho, Chris Fonnesbeck
:::
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import statsmodels.api as sm
print(f"Running on PyMC v{pm.__version__}")
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
np.random.seed(RANDOM_SEED)
az.style.use("arviz-darkgrid")
The {ref}`previous example notebook on Bayesian parametric survival analysis <bayes_param_survival_pymc3>` introduced two different accelerated failure time (AFT) models: Weibull and log-linear. In this notebook, we present three different parameterizations of the Weibull AFT model.
The data set we'll use is the `flchain` R data set, which comes from a medical study investigating the effect of serum free light chain (FLC) on lifespan. Read the full documentation of the data by running `print(sm.datasets.get_rdataset(package='survival', dataname='flchain').__doc__)`.
# Fetch and clean data
data = (
    sm.datasets.get_rdataset(package="survival", dataname="flchain")
    .data.sample(500)  # Limit ourselves to 500 observations
    .reset_index(drop=True)
)
y = data.futime.values
censored = ~data["death"].values.astype(bool)
y[:5]
censored[:5]
We have a unique problem when modelling censored data. Strictly speaking, we never observe a failure time for a censored subject: we only know that the event had not yet occurred at the recorded time. How can we include this information in our model?
One way to do this is by making use of `pm.Potential`. The PyMC2 docs explain its usage very well. Essentially, declaring `pm.Potential('x', logp)` will add `logp` to the log-likelihood of the model.
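As a minimal sketch of what that potential term represents (plain Python, independent of PyMC; the helper names and toy numbers are illustrative assumptions, not part of the original notebook): an observed failure time contributes its log-density, while a right-censored time contributes the log-probability of surviving past it.

```python
import math


def weibull_logpdf(x, alpha, beta):
    """Log density of a Weibull with shape alpha and scale beta."""
    return (
        math.log(alpha / beta)
        + (alpha - 1.0) * math.log(x / beta)
        - (x / beta) ** alpha
    )


def weibull_logsf(x, alpha, beta):
    """Log survival function: log P(T > x) = -(x / beta) ** alpha."""
    return -((x / beta) ** alpha)


def censored_loglik(times, is_censored, alpha, beta):
    """Sum log-densities for events and log-survival terms for censored times."""
    total = 0.0
    for t, c in zip(times, is_censored):
        if c:
            total += weibull_logsf(t, alpha, beta)
        else:
            total += weibull_logpdf(t, alpha, beta)
    return total


# Toy data (hypothetical): two observed failures and one right-censored time.
toy_times = [2.0, 5.0, 4.0]
toy_censored = [False, False, True]
toy_loglik = censored_loglik(toy_times, toy_censored, alpha=1.0, beta=4.0)
```

In a PyMC model, the event terms would come from the observed likelihood and the censored terms from the potential.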
+++
The first parameterization is an intuitive, straightforward one based directly on the Weibull survival function. It is probably the first to come to mind.
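For reference, with shape $\alpha$ and scale $\beta$ (the convention used by `pm.Weibull`), the survival function and its logarithm are:

$$
S(x) = P(X > x) = \exp\left(-\left(\frac{x}{\beta}\right)^{\alpha}\right),
\qquad
\log S(x) = -\left(\frac{x}{\beta}\right)^{\alpha}.
$$

The second expression is the log complementary CDF that each censored observation contributes.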
def weibull_lccdf(x, alpha, beta):
    """Log complementary cdf of Weibull distribution."""
    return -((x / beta) ** alpha)
with pm.Model() as model_1:
    alpha_sd = 10.0
    mu = pm.Normal("mu", mu=0, sigma=100)
    alpha_raw = pm.Normal("a0", mu=0, sigma=0.1)
    alpha = pm.Deterministic("alpha", pt.exp(alpha_sd * alpha_raw))
    beta = pm.Deterministic("beta", pt.exp(mu / alpha))
    y_obs = pm.Weibull("y_obs", alpha=alpha, beta=beta, observed=y[~censored])
    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], alpha, beta))
with model_1:
    # Change init to avoid divergences
    data_1 = pm.sample(target_accept=0.9, init="adapt_diag")
az.plot_trace(data_1, var_names=["alpha", "beta"])
az.summary(data_1, var_names=["alpha", "beta"], round_to=2)
In the second parameterization, note that, confusingly, the shape parameter `alpha` is now called `r`, and `alpha` instead denotes a parameter with a normal prior; we maintain this notation to stay faithful to the original implementation in Stan. In this parameterization, we still model the same parameters `alpha` (now `r`) and `beta`.
For more information, see this Stan example model and the corresponding documentation.
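The scale in this parameterization is $\beta = \exp(-\alpha / r)$. One way to read this, assuming a rate-style Weibull survival function $\exp(-\lambda x^{r})$ with $\lambda = e^{\alpha}$ (our reading of the Stan-style model, stated as an assumption), is that it coincides with the shape-scale form $\exp(-(x/\beta)^{r})$. A quick numerical check in plain Python, with hypothetical function names and parameter values:

```python
import math


def sf_shape_scale(x, r, beta):
    """Weibull survival function in shape-scale form: exp(-(x / beta) ** r)."""
    return math.exp(-((x / beta) ** r))


def sf_rate_form(x, r, alpha):
    """Weibull survival function in rate form: exp(-exp(alpha) * x ** r)."""
    return math.exp(-math.exp(alpha) * x**r)


# Hypothetical parameter values, chosen only for illustration.
r, alpha = 1.5, -2.0
beta = math.exp(-alpha / r)  # the transformation used for the scale

# The two forms agree at every time point we try.
for x in (0.5, 1.0, 3.0, 10.0):
    assert math.isclose(sf_shape_scale(x, r, beta), sf_rate_form(x, r, alpha))
```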
with pm.Model() as model_2:
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    r = pm.Gamma("r", alpha=1, beta=0.001, initval=0.25)
    beta = pm.Deterministic("beta", pt.exp(-alpha / r))
    y_obs = pm.Weibull("y_obs", alpha=r, beta=beta, observed=y[~censored])
    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], r, beta))
with model_2:
    # Increase target_accept to avoid divergences
    data_2 = pm.sample(target_accept=0.9)
az.plot_trace(data_2, var_names=["r", "beta"])
az.summary(data_2, var_names=["r", "beta"], round_to=2)
In the third parameterization, we model the log-linear error distribution with a Gumbel distribution instead of modelling the survival function directly. For more information, see this blog post.
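Because a `pm.Potential` term is added to the model's log-probability, a censored log-time would normally contribute the logarithm of the survival function. The following is a minimal plain-Python sketch of a numerically stable log-survival for a Gumbel distribution with CDF $\exp(-e^{-(y-\mu)/\sigma})$. The name `gumbel_logsf` is hypothetical, and the sketch is not part of the original notebook.

```python
import math


def gumbel_logsf(y, mu, sigma):
    """Log of the Gumbel survival function, log(1 - exp(-exp(-z)))."""
    z = (y - mu) / sigma
    u = math.exp(-z)  # the CDF is exp(-u), so the survival function is 1 - exp(-u)
    # -expm1(-u) evaluates 1 - exp(-u) without cancellation when u is tiny.
    return math.log(-math.expm1(-u))


# At y == mu the survival probability is 1 - exp(-1).
logsf_at_mode = gumbel_logsf(0.0, mu=0.0, sigma=1.0)

# Far in the right tail, a naive 1 - exp(-exp(-z)) rounds to 0.0,
# whereas this form stays finite.
logsf_tail = gumbel_logsf(40.0, mu=0.0, sigma=1.0)
```

The same idea carries over to tensor code by swapping the `math` calls for their `pytensor.tensor` counterparts.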
logtime = np.log(y)
def gumbel_sf(y, mu, sigma):
    """Gumbel survival function."""
    return 1.0 - pt.exp(-pt.exp(-(y - mu) / sigma))
with pm.Model() as model_3:
    s = pm.HalfNormal("s", tau=5.0)
    gamma = pm.Normal("gamma", mu=0, sigma=5)
    y_obs = pm.Gumbel("y_obs", mu=gamma, beta=s, observed=logtime[~censored])
    # pm.Potential adds its argument to the model log-probability,
    # so the censored term is the log of the survival function.
    y_cens = pm.Potential(
        "y_cens", pt.log(gumbel_sf(y=logtime[censored], mu=gamma, sigma=s))
    )
with model_3:
    # Change init to avoid divergences
    data_3 = pm.sample(init="adapt_diag")
az.plot_trace(data_3)
az.summary(data_3, round_to=2)
- Originally collated by Junpeng Lao on Apr 21, 2018. See original code here.
- Authored and ported to Jupyter notebook by George Ho on Jul 15, 2018.
- Updated for compatibility with PyMC v5 by Chris Fonnesbeck on Jan 16, 2023.
%load_ext watermark
%watermark -n -u -v -iv -w
:::{include} ../page_footer.md
:::