Skip to content

Commit 31db489

Browse files
ColCarroll authored and twiecki committed
Change signature of get_data, use more often
1 parent c579611 commit 31db489

File tree

9 files changed

+601
-512
lines changed

9 files changed

+601
-512
lines changed

docs/source/notebooks/GLM.ipynb

Lines changed: 108 additions & 109 deletions
Large diffs are not rendered by default.

docs/source/notebooks/dawid-skene.ipynb

Lines changed: 36 additions & 39 deletions
Large diffs are not rendered by default.

docs/source/notebooks/hierarchical_partial_pooling.ipynb

Lines changed: 301 additions & 201 deletions
Large diffs are not rendered by default.

docs/source/notebooks/profiling.ipynb

Lines changed: 135 additions & 140 deletions
Large diffs are not rendered by default.

pymc3/data.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,33 @@
1-
import pkgutil
2-
import io
31
from copy import copy
2+
import io
3+
import os
4+
import pkgutil
5+
46
import numpy as np
7+
import pymc3 as pm
58
import theano.tensor as tt
69
import theano
7-
import pymc3 as pm
810

911
__all__ = [
10-
'get_data_file',
12+
'get_data',
1113
'GeneratorAdapter',
1214
'DataSampler'
1315
]
1416

1517

16-
def get_data_file(pkg, path):
17-
"""Returns a file object for a package data file.
18+
def get_data(filename):
19+
"""Returns a BytesIO object for a package data file.
1820
1921
Parameters
2022
----------
21-
pkg : str
22-
dotted package hierarchy. e.g. "pymc3.examples"
23-
path : str
24-
file path within package. e.g. "data/wells.dat"
25-
Returns
23+
filename : str
24+
file to load
25+
Returns
2626
-------
2727
BytesIO of the data
2828
"""
29-
30-
return io.BytesIO(pkgutil.get_data(pkg, path))
29+
data_pkg = 'pymc3.examples'
30+
return io.BytesIO(pkgutil.get_data(data_pkg, os.path.join('data', filename)))
3131

3232

3333
class GenTensorVariable(tt.TensorVariable):

pymc3/examples/GHME_2013.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,10 @@
22
import pandas as pd
33
import matplotlib.pyplot as plt
44

5-
from pymc3 import HalfCauchy, Model, Normal, get_data_file, sample
5+
from pymc3 import HalfCauchy, Model, Normal, get_data, sample
66
from pymc3.distributions.timeseries import GaussianRandomWalk
77

8-
data = pd.read_csv(get_data_file('pymc3.examples', 'data/pancreatitis.csv'))
8+
data = pd.read_csv(get_data('pancreatitis.csv'))
99
countries = ['CYP', 'DNK', 'ESP', 'FIN', 'GBR', 'ISL']
1010
data = data[data.area.isin(countries)]
1111

pymc3/examples/baseball.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import theano
99

10-
data = np.loadtxt( 'data/efron-morris-75-data.tsv', delimiter="\t", skiprows=1, usecols=(2,3) )
10+
data = np.loadtxt(pm.get_data('efron-morris-75-data.tsv'), delimiter="\t", skiprows=1, usecols=(2,3))
1111

1212
atBats = data[:,0].astype(theano.config.floatX)
1313
hits = data[:,1].astype(theano.config.floatX)
@@ -35,4 +35,3 @@ def run( n=100000 ):
3535

3636
if __name__ == '__main__':
3737
run()
38-

pymc3/examples/lasso_missing.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,7 @@
33
from numpy.ma import masked_values
44

55
# Import data, filling missing values with sentinels (-999)
6-
test_scores = pd.read_csv(pm.get_data_file(
7-
'pymc3.examples', 'data/test_scores.csv')).fillna(-999)
6+
test_scores = pd.read_csv(pm.get_data('test_scores.csv')).fillna(-999)
87

98
# Extract variables: test score, gender, number of siblings, previous disability, age,
109
# mother with HS education or better, hearing loss identified by 3 months

pymc3/tests/test_examples.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@
1212

1313
def get_city_data():
1414
"""Helper to get city data"""
15-
data = pd.read_csv(pm.get_data_file('pymc3.examples', 'data/srrs2.dat'))
16-
cty_data = pd.read_csv(pm.get_data_file('pymc3.examples', 'data/cty.dat'))
15+
data = pd.read_csv(pm.get_data('srrs2.dat'))
16+
cty_data = pd.read_csv(pm.get_data('cty.dat'))
1717

1818
data = data[data.state == 'MN']
1919

@@ -30,8 +30,8 @@ def get_city_data():
3030

3131
class TestARM5_4(SeededTest):
3232
def build_model(self):
33-
wells = pm.get_data_file('pymc3.examples', 'data/wells.dat')
34-
data = pd.read_csv(wells, delimiter=u' ', index_col=u'id', dtype={u'switch': np.int8})
33+
data = pd.read_csv(pm.get_data('wells.dat'),
34+
delimiter=u' ', index_col=u'id', dtype={u'switch': np.int8})
3535
data.dist /= 100
3636
data.educ /= 4
3737
col = data.columns

0 commit comments

Comments
 (0)