forked from pymc-devs/pymc-examples
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlasso_missing.py
62 lines (46 loc) · 1.83 KB
/
lasso_missing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pandas as pd
from numpy.ma import masked_values
import pymc3 as pm
# Load the test-score data. Every NaN, in every column, is replaced by the
# sentinel -999 so that missing entries can be recognised later by value.
test_scores = pd.read_csv(pm.get_data("test_scores.csv")).fillna(-999)
# Columns, in unpacking order: test score, gender, number of siblings,
# previous disability, age at testing, mother with HS education or better,
# and hearing loss identified by 3 months of age.
_columns = ["score", "male", "siblings", "prev_disab", "age_test", "mother_hs", "early_ident"]
_as_float = test_scores[_columns].astype(float)
# Transpose so each column becomes one 1-D array.
score, male, siblings, disability, age, mother_hs, early_ident = _as_float.values.T
with pm.Model() as model:
    # Impute missing covariate values. Each covariate with gaps gets its own
    # likelihood, and its data are passed as a masked array in which the -999
    # sentinel entries are masked; those masked entries are then treated as
    # unknowns (exposed as ``missing_values``) rather than as observed data.
    sib_mean = pm.Exponential("sib_mean", 1.0)
    # Mask the sentinel for siblings too: the loader fills NaNs with -999 in
    # *every* column, and an unmasked -999 would be scored as an observed
    # Poisson count, which lies outside the Poisson's support.
    # NOTE(review): if siblings actually has gaps, its (discrete) missing
    # values are not in the explicit step lists; presumably pm.sample assigns
    # them a step method automatically -- confirm.
    siblings_imp = pm.Poisson(
        "siblings_imp", sib_mean, observed=masked_values(siblings, value=-999)
    )

    p_disab = pm.Beta("p_disab", 1.0, 1.0)
    disability_imp = pm.Bernoulli(
        "disability_imp", p_disab, observed=masked_values(disability, value=-999)
    )

    p_mother = pm.Beta("p_mother", 1.0, 1.0)
    mother_imp = pm.Bernoulli(
        "mother_imp", p_mother, observed=masked_values(mother_hs, value=-999)
    )

    # Residual scale of the test-score likelihood.
    s = pm.HalfCauchy("s", 5.0, testval=5)
    # Laplace prior on the 7 regression coefficients (intercept + 6
    # covariates); a Laplace prior is the Bayesian counterpart of the lasso
    # (L1) penalty.
    beta = pm.Laplace("beta", 0.0, 100.0, shape=7, testval=0.1)

    # Linear predictor. The *_imp terms carry imputed values wherever the
    # corresponding covariate was missing.
    # NOTE(review): male, age and early_ident are used as-is; any -999
    # sentinel in them would enter the regression as a real value.
    expected_score = (
        beta[0]
        + beta[1] * male
        + beta[2] * siblings_imp
        + beta[3] * disability_imp
        + beta[4] * age
        + beta[5] * mother_imp
        + beta[6] * early_ident
    )

    observed_score = pm.Normal("observed_score", expected_score, s, observed=score)
with model:
    # Optimise to the posterior mode first; it is reused as the sampler's
    # starting point and to set the NUTS scaling.
    start = pm.find_MAP()
    # Continuous parameters are updated with NUTS.
    _continuous = [beta, s, p_disab, p_mother, sib_mean]
    step1 = pm.NUTS(_continuous, scaling=start)
    # Imputed entries of the two Bernoulli covariates are updated with a
    # binary Gibbs-Metropolis step.
    _binary_missing = [
        mother_imp.missing_values,
        disability_imp.missing_values,
    ]
    step2 = pm.BinaryGibbsMetropolis(_binary_missing)
def run(n=5000):
    """Draw posterior samples from ``model`` and return them.

    Parameters
    ----------
    n : int or "short"
        Number of draws. The string ``"short"`` is shorthand for a quick
        100-draw run.

    Returns
    -------
    The trace object produced by ``pm.sample``.
    """
    if n == "short":
        n = 100
    with model:
        # NUTS for the continuous parameters plus binary Gibbs for the imputed
        # Bernoulli entries, both started from the MAP point computed above.
        trace = pm.sample(n, step=[step1, step2], start=start)
    # Return the samples; previously they were discarded, so callers could
    # not inspect the posterior.
    return trace


if __name__ == "__main__":
    run()