Skip to content

Commit 0a0f544

Browse files
michaelraczyckimbjosephMax Josephtwiecki
authored
docstrings update in model_builder.py (pymc-devs#148)
* docstrings update in model_builder.py * bringing back accidentally removed line from example * Return posterior predictive samples from all chains in `ModelBuilder` (pymc-devs#140) * Return posterior predictive samples from all chains This fixes a bug where only values from one chain were returned. It also refactors the prediction logic to reduce duplication, and makes the output of predict_posterior() consistent in type and shape with the output of pymc.sample_posterior_predictive(). * keep attributes even when computing posterior means * Add/test combined arg, revert method order * Fix import order. --------- Co-authored-by: Max Joseph <[email protected]> Co-authored-by: Thomas Wiecki <[email protected]> * fixing merge conflicts --------- Co-authored-by: Max Joseph <[email protected]> Co-authored-by: Max Joseph <[email protected]> Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 03ad8dd commit 0a0f544

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

pymc_experimental/model_builder.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,16 @@ def create_sample_input():
133133
>>> 'target_accept': 0.95,
134134
>>> }
135135
>>> return data, model_config, sampler_config
136-
"""
137136
137+
Returns
138+
-------
139+
data : dict
140+
The data we want to train the model on
141+
model_config : dict
142+
A set of parameters for predictor distributions that allow to save and recreate the model
143+
sampler_config : dict
144+
A set of default settings for sampler config, customization of contents of sampler_config allows introducing new settings to the sampler
145+
"""
138146
raise NotImplementedError
139147

140148
@abstractmethod
@@ -149,11 +157,17 @@ def build_model(
149157
150158
Required Parameters
151159
----------
152-
model_data - preformated data that is going to be used in the model.
153-
For efficiency reasons it should contain only the necesary data columns, not entire available
154-
dataset since it's going to be encoded into data used to recreate the model.
155-
model_config - dictionary where keys are strings representing names of parameters of the model, values are
156-
dictionaries of parameters needed for creating model parameters (see example in create_model_input)
160+
model_data : dict
161+
Preformated data that is going to be used in the model. For efficiency reasons it should contain only the necesary data columns,
162+
not entire available dataset since it's going to be encoded into data used to recreate the model.
163+
model_config : dict
164+
Dictionary where keys are strings representing names of parameters of the model, values are dictionaries of parameters
165+
needed for creating model parameters
166+
167+
See Also
168+
--------
169+
create_model_input : Creates all required input for the model builder based on the data given. Shows the examples of data structures on which the specific
170+
inherited version of model builder operates on.
157171
158172
Returns:
159173
----------
@@ -233,9 +247,9 @@ def load(cls, fname: str):
233247

234248
def fit(
235249
self,
250+
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
236251
progressbar: bool = True,
237252
random_seed: RandomState = None,
238-
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
239253
**kwargs: Any,
240254
) -> az.InferenceData:
241255
"""
@@ -244,8 +258,15 @@ def fit(
244258
245259
Parameter
246260
---------
247-
data : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
248-
It is the data we need to train the model on.
261+
data : dict
262+
Dictionary of string and either of numpy array, pandas dataframe or pandas Series. It is the data we need to train the model on.
263+
progressbar : bool
264+
Specifies whether the fit progressbar should be displayed
265+
random_seed : RandomState
266+
Provides sampler with initial random seed for obtaining reproducible samples
267+
**kwargs : Any
268+
Custom sampler settings can be provided in form of keyword arguments. The recommended way is to add custom settings to sampler_config provided by
269+
create_sample_input, because arguments provided in the form of kwargs will not be saved into the model, therefore will not be available after loading the model
249270
250271
Returns
251272
-------

0 commit comments

Comments
 (0)