Skip to content

Commit ffbe432

Browse files
authored
Merge pull request #3389 from jmloyola/add_data_container
Add data container and pm.set_data
2 parents 5b39caf + c349569 commit ffbe432

File tree

9 files changed

+1795
-133
lines changed

9 files changed

+1795
-133
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### New features
66

7+
- Add data container class (`Data`) that wraps the theano SharedVariable class and let the model be aware of its inputs and outputs.
8+
- Add function `set_data` to update variables defined as `Data`.
79
- `Mixture` now supports mixtures of multidimensional probability distributions, not just lists of 1D distributions.
810
- `GLM.from_formula` and `LinearComponent.from_formula` can extract variables from the calling scope. Customizable via the new `eval_env` argument. Fixing #3382.
911

docs/source/notebooks/bayesian_neural_network_advi.ipynb

Lines changed: 86 additions & 130 deletions
Large diffs are not rendered by default.

docs/source/notebooks/data_container.ipynb

Lines changed: 675 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/table_of_contents_tutorials.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Gallery.contents = {
44
"PyMC3_and_Theano.rst": "Basics",
55
"Probability_Distributions.rst": "Basics",
66
"Gaussian_Processes.rst": "Basics",
7+
"data_container": "Basics",
78
"sampling_compound_step": "Deep dives",
89
"sampler-stats": "Deep dives",
910
"Diagnosing_biased_Inference_with_Divergences": "Deep dives",

pymc3/data.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
'get_data',
1313
'GeneratorAdapter',
1414
'Minibatch',
15-
'align_minibatches'
15+
'align_minibatches',
16+
'Data',
1617
]
1718

1819

@@ -385,3 +386,67 @@ def align_minibatches(batches=None):
385386
raise TypeError('{b} is not a Minibatch')
386387
for rng in Minibatch.RNG[id(b)]:
387388
rng.seed()
389+
390+
391+
class Data:
392+
"""Data container class that wraps the theano SharedVariable class
393+
and let the model be aware of its inputs and outputs.
394+
395+
Parameters
396+
----------
397+
name : str
398+
The name for this variable
399+
value
400+
A value to associate with this variable
401+
402+
Examples
403+
--------
404+
405+
.. code:: ipython
406+
407+
>>> import pymc3 as pm
408+
>>> import numpy as np
409+
>>> # We generate 10 datasets
410+
>>> true_mu = [np.random.randn() for _ in range(10)]
411+
>>> observed_data = [mu + np.random.randn(20) for mu in true_mu]
412+
413+
>>> with pm.Model() as model:
414+
... data = pm.Data('data', observed_data[0])
415+
... mu = pm.Normal('mu', 0, 10)
416+
... pm.Normal('y', mu=mu, sigma=1, observed=data)
417+
418+
.. code:: ipython
419+
420+
>>> # Generate one trace for each dataset
421+
>>> traces = []
422+
>>> for data_vals in observed_data:
423+
... with model:
424+
... # Switch out the observed dataset
425+
... pm.set_data({'data': data_vals})
426+
... traces.append(pm.sample())
427+
428+
To set the value of the data container variable, check out
429+
:func:`pm.set_data()`.
430+
431+
For more information, take a look at this example notebook
432+
https://docs.pymc.io/notebooks/data_container.html
433+
"""
434+
def __new__(self, name, value):
435+
# `pm.model.pandas_to_array` takes care of parameter `value` and
436+
# transforms it to something digestible for pymc3
437+
shared_object = theano.shared(pm.model.pandas_to_array(value), name)
438+
439+
# To draw the node for this variable in the graphviz Digraph we need
440+
# its shape.
441+
shared_object.dshape = tuple(shared_object.shape.eval())
442+
443+
# Add data container to the named variables of the model.
444+
try:
445+
model = pm.Model.get_context()
446+
except TypeError:
447+
raise TypeError("No model on context stack, which is needed to "
448+
"instantiate a data container. Add variable "
449+
"inside a 'with model:' block.")
450+
model.add_random_variable(shared_object)
451+
452+
return shared_object

0 commit comments

Comments
 (0)