-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow specification of dims instead of shape #3551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
e8199eb
169feef
57909b0
6cbcdb7
e603855
a0753ac
e7cb979
b9acb96
fcb709c
26c3fc8
37da459
8864bbe
1bdcd38
b178207
9a943ca
10c617b
39b6d92
903ee61
68a863f
2e73535
0236ccd
a4c832b
80aaa35
aec6d9c
add54c4
072c6a4
abfeba9
7a5c327
cc25d47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,13 +22,15 @@ | |
import pymc3 as pm | ||
import theano.tensor as tt | ||
import theano | ||
import pandas as pd | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
__all__ = [ | ||
'get_data', | ||
'GeneratorAdapter', | ||
'Minibatch', | ||
'align_minibatches', | ||
'Data', | ||
'TidyData', | ||
] | ||
|
||
|
||
|
@@ -479,7 +481,7 @@ class Data: | |
https://docs.pymc.io/notebooks/data_container.html | ||
""" | ||
|
||
def __new__(self, name, value): | ||
def __new__(self, name, value, *, dims=None, export_dims=False): | ||
if isinstance(value, list): | ||
value = np.array(value) | ||
|
||
|
@@ -497,10 +499,130 @@ def __new__(self, name, value): | |
# transforms it to something digestible for pymc3 | ||
shared_object = theano.shared(pm.model.pandas_to_array(value), name) | ||
|
||
if isinstance(dims, str): | ||
dims = (dims,) | ||
if dims is not None and len(dims) != shared_object.ndim: | ||
raise ValueError('Length of `dims` must match the dimensionality ' | ||
'of the dataset.') | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
coords = {} | ||
if isinstance(value, (pd.Series, pd.DataFrame)): | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name = None | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if dims is not None: | ||
name = dims[0] | ||
if (name is None | ||
and value.index.name is not None | ||
and value.index.name.isidentifier()): | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name = value.index.name | ||
if name is not None: | ||
coords[name] = value.index | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(value, pd.DataFrame): | ||
name = None | ||
if dims is not None: | ||
name = dims[1] | ||
if (name is None | ||
and value.columns.name is not None | ||
and value.columns.name.isidentifier()): | ||
name = value.columns.name | ||
if name is not None: | ||
coords[name] = value.columns | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(value, np.ndarray) and dims is not None: | ||
if len(dims) != value.ndim: | ||
raise ValueError('Invalid data shape %s. The rank of the dataset ' | ||
'must match the length of `dims`.' | ||
% value.shape) | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for size, dim in zip(value.shape, dims): | ||
coord = model.coords.get(dim, None) | ||
if coord is None: | ||
coords[dim] = pd.RangeIndex(size, name=dim) | ||
|
||
if export_dims: | ||
model.add_coords(coords) | ||
|
||
# To draw the node for this variable in the graphviz Digraph we need | ||
# its shape. | ||
shared_object.dshape = tuple(shared_object.shape.eval()) | ||
if dims is not None: | ||
shape_dims = model.shape_from_dims(dims) | ||
if shared_object.dshape != shape_dims: | ||
raise ValueError('Invalid shape. It is %s but the dimensions ' | ||
'suggest %s.' | ||
% (shared_object.dshape, shape_dims)) | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
model.add_random_variable(shared_object, dims=dims) | ||
return shared_object | ||
|
||
model.add_random_variable(shared_object) | ||
|
||
return shared_object | ||
class _IndexAccessor:
    """Look up integer category codes for the columns of a tidy dataset.

    Wraps an owner object that exposes a ``data`` frame and a
    ``_col_as_category`` method. Indexing the accessor with a column name
    yields the codes of that column under the owner's categorical dtype.
    """

    def __init__(self, data):
        # `data` is the owning container, not the raw frame itself.
        self._data = data

    def __getitem__(self, key):
        """Return the category codes of the column (or index level) ``key``."""
        owner = self._data
        cat_dtype = owner._col_as_category(key)
        # Index levels are turned into columns so they can be selected by name.
        flat = owner.data.reset_index()
        column = flat.loc[:, key]
        return pd.Categorical(column, dtype=cat_dtype).codes
|
||
|
||
class TidyData: | ||
Review comment: This new class is user-facing, right? In that case it probably needs some docstrings to explain what it does and how to use it. It could also be useful to update the […] Reply: The class is prefixed with […] Reply: I think […] Reply: This is supposed to be user-facing, but I'm not sure the implementation is good as it is. We could merge the PR without it at first and go from there. Say we have a dataset like this: […]
I would like to extract from that four different dimensions and their corresponding coordinates: […]
And also mappings that allow us to index an array with dim=treatment etc. to get one with dim=observation, so that we can run linear regressions like this: […]
where […]
where mu now has dim=observation. |
||
def __init__(self, data, copy_data=True, import_dims=None, model=None):
    """Wrap a tidy dataset and optionally register its coordinates.

    Parameters
    ----------
    data
        The dataset to wrap. The class calls ``reset_index``, ``loc``,
        ``columns``, ``index`` and ``copy`` on it, so it is presumably a
        ``pandas.DataFrame`` -- TODO confirm against callers.
    copy_data : bool
        When true (the default), a copy of ``data`` is stored so that later
        changes to the caller's object cannot diverge from the cached
        category dtypes and shared variables. When false, ``data`` is
        stored by reference.
    import_dims
        Names of columns or index levels whose coordinates should be
        added to the model. ``None`` skips the model entirely.
    model
        Forwarded to ``pm.model.modelcontext``; only consulted when
        ``import_dims`` is given.
    """
    # Honour `copy_data`: the flag used to be accepted but ignored.
    self.data = data.copy() if copy_data else data
    self._shared_vars = {}
    self._category_cols = {}
    self._index_dict = _IndexAccessor(self)

    if import_dims is not None:
        model = pm.model.modelcontext(model)
        coords = self._extract_coords(import_dims)
        model.add_coords(coords)
|
||
@property
def idxs(self):
    """Return the index accessor stored on this instance."""
    accessor = self._index_dict
    return accessor
|
||
def _col_as_category(self, key):
    """Return the categorical dtype for column (or index level) ``key``.

    The dtype is computed once per key and then served from
    ``self._category_cols``. A column that is already categorical keeps its
    dtype; any other column is converted with ``astype('category')``.
    """
    cache = self._category_cols
    try:
        return cache[key]
    except KeyError:
        pass

    # Index levels become ordinary columns so they can be selected by name.
    column = self.data.reset_index().loc[:, key]
    if column.dtype.name == 'category':
        dtype = column.dtype
    else:
        dtype = column.astype('category').dtype
    cache[key] = dtype
    return dtype
|
||
def __getitem__(self, key):
    """Return a theano shared variable holding the values of column ``key``.

    The shared variable is created on first access and then reused from
    ``self._shared_vars``.

    Raises
    ------
    KeyError
        If ``key`` is not one of ``self.data.columns``.
    """
    frame = self.data
    if key not in frame.columns:
        raise KeyError('Unknown column %s' % key)

    cache = self._shared_vars
    if key not in cache:
        cache[key] = theano.shared(frame.loc[:, key].values)
    return cache[key]
|
||
def _extract_coords(self, dims):
    """Build a ``{dim name: pandas index}`` mapping for the requested dims.

    Parameters
    ----------
    dims
        A single name or an iterable of names. Each must be the name of
        the frame's index, a level of its index, or one of its columns.

    Returns
    -------
    dict
        The frame's own index for a dim that matches ``data.index.name``;
        a ``pandas.CategoricalIndex`` over the column's categories for
        every other dim.

    Raises
    ------
    KeyError
        If any requested name matches neither the index nor a column.
    """
    # A bare string is one dim name, not an iterable of characters; this
    # matches how `dims` is normalised elsewhere in the code base.
    if isinstance(dims, str):
        dims = (dims,)

    data = self.data
    pending = set(dims)
    coords = {}

    index_name = data.index.name
    if index_name is not None and index_name in pending:
        pending.remove(index_name)
        coords[index_name] = data.index

    # We want to iterate over index columns of a multi index as well
    for col in data.reset_index().columns:
        if col not in pending:
            continue
        pending.remove(col)
        category = self._col_as_category(col)
        cat = pd.Categorical(category.categories, dtype=category)
        coords[col] = pd.CategoricalIndex(cat, name=col)

    if pending:
        raise KeyError('Unknown columns: %s' % pending)

    return coords
|
||
@property
def columns(self):
    """Return the ``columns`` attribute of the wrapped dataset."""
    frame = self.data
    return frame.columns
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,17 +56,32 @@ def __new__(cls, name, *args, **kwargs): | |
"a 'with model:' block, or use the '.dist' syntax " | ||
"for a standalone distribution.") | ||
|
||
if isinstance(name, string_types): | ||
data = kwargs.pop('observed', None) | ||
cls.data = data | ||
if isinstance(data, ObservedRV) or isinstance(data, FreeRV): | ||
raise TypeError("observed needs to be data but got: {}".format(type(data))) | ||
total_size = kwargs.pop('total_size', None) | ||
dist = cls.dist(*args, **kwargs) | ||
return model.Var(name, dist, data, total_size) | ||
else: | ||
if not isinstance(name, string_types): | ||
raise TypeError("Name needs to be a string but got: {}".format(name)) | ||
|
||
data = kwargs.pop('observed', None) | ||
Review comment: Is it possible to add some comments explaining what the new code below does? It would be helpful for future us and for newcomers, I think. |
||
cls.data = data | ||
if isinstance(data, ObservedRV) or isinstance(data, FreeRV): | ||
raise TypeError("observed needs to be data but got: {}".format(type(data))) | ||
total_size = kwargs.pop('total_size', None) | ||
|
||
dims = kwargs.pop('dims', None) | ||
has_shape = 'shape' in kwargs | ||
shape = kwargs.pop('shape', None) | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if dims is not None: | ||
if shape is not None: | ||
raise ValueError("Specify only one of 'dims' and 'shape'") | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(dims, string_types): | ||
dims = (dims,) | ||
shape = model.shape_from_dims(dims) | ||
|
||
# Some distributions do not accept shape=None | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if has_shape or shape is not None: | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dist = cls.dist(*args, **kwargs, shape=shape) | ||
else: | ||
dist = cls.dist(*args, **kwargs) | ||
return model.Var(name, dist, data, total_size, dims=dims) | ||
|
||
def __getnewargs__(self):
    """Return the one-element argument tuple ``(_Unpickling,)`` for pickling."""
    return (_Unpickling,)
|
||
|
@@ -77,7 +92,7 @@ def dist(cls, *args, **kwargs): | |
return dist | ||
|
||
def __init__(self, shape, dtype, testval=None, defaults=(), | ||
transform=None, broadcastable=None): | ||
transform=None, broadcastable=None, dims=None): | ||
self.shape = np.atleast_1d(shape) | ||
if False in (np.floor(self.shape) == self.shape): | ||
raise TypeError("Expected int elements in shape") | ||
|
Uh oh!
There was an error while loading. Please reload this page.