Skip to content

Commit 539659d

Browse files
adding and updating doctests
fixing indentation issue, adding exception for pytest, adding forgotten decorator to generate_model_data, making doctest more user-manual like, renaming example model for consistency, changing YourClass to MyClass for consistency
1 parent c272052 commit 539659d

File tree

3 files changed

+162
-42
lines changed

3 files changed

+162
-42
lines changed

pymc_experimental/bayesian_estimator_linearmodel.py

+43
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,25 @@ def default_sampler_config(self):
414414
}
415415

416416
def build_model(self):
417+
"""
418+
Build the PyMC model.
419+
420+
Returns
421+
-------
422+
None
423+
424+
Examples
425+
--------
426+
>>> self.build_model()
427+
>>> assert self.model is not None
428+
>>> assert isinstance(self.model, pm.Model)
429+
>>> assert "intercept" in self.model.named_vars
430+
>>> assert "slope" in self.model.named_vars
431+
>>> assert "σ_model_fmc" in self.model.named_vars
432+
>>> assert "y_model" in self.model.named_vars
433+
>>> assert "y_hat" in self.model.named_vars
434+
>>> assert self.output_var == "y_hat"
435+
"""
417436
cfg = self.model_config
418437

419438
# The model is built with placeholder data.
@@ -452,6 +471,30 @@ def _data_setter(self, X, y=None):
452471

453472
@classmethod
454473
def generate_model_data(cls, nsamples=100, data=None):
474+
"""
475+
Generate model data for linear regression.
476+
477+
Parameters
478+
----------
479+
nsamples : int, optional
480+
The number of samples to generate. Default is 100.
481+
data : np.ndarray, optional
482+
An optional data array to add noise to.
483+
484+
Returns
485+
-------
486+
tuple
487+
A tuple of two np.ndarrays representing the feature matrix and target vector, respectively.
488+
489+
Examples
490+
--------
491+
>>> import numpy as np
492+
>>> x, y = cls.generate_model_data()
493+
>>> assert isinstance(x, np.ndarray)
494+
>>> assert isinstance(y, np.ndarray)
495+
>>> assert x.shape == (100, 1)
496+
>>> assert y.shape == (100,)
497+
"""
455498
x = np.linspace(start=0, stop=1, num=nsamples)
456499
y = 5 * x + 3
457500
y = y + np.random.normal(0, 1, len(x))

pymc_experimental/model_builder.py

+118-41
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def __init__(
5555
dictionary of parameters that initialise sampler configuration. Class-default defined by the user default_sampler_config method.
5656
Examples
5757
--------
58-
>>> class LinearModel(ModelBuilder):
58+
>>> class MyModel(ModelBuilder):
5959
>>> ...
60-
>>> model = LinearModel(model_config, sampler_config)
60+
>>> model = MyModel(model_config, sampler_config)
6161
"""
6262

6363
if sampler_config is None:
@@ -159,6 +159,12 @@ def generate_model_data(
159159
) -> pd.DataFrame:
160160
"""
161161
Returns a default dataset for a class; it can be used as a hint to the data formatting required for the class.
162+
If data is not None, the dataset will be created from its contents.
163+
164+
Parameters:
165+
data : Union[np.ndarray, pd.DataFrame, pd.Series], optional
166+
dataset that will replace the default sample data
167+
162168
163169
Examples
164170
--------
@@ -178,16 +184,16 @@ def generate_model_data(
178184

179185
@abstractmethod
180186
def build_model(
181-
model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
187+
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
182188
model_config: Dict[str, Union[int, float, Dict]] = None,
183189
) -> None:
184190
"""
185-
Creates an instance of pm.Model based on provided model_data and model_config, and
191+
Creates an instance of pm.Model based on provided data and model_config, and
186192
attaches it to self.
187193
188194
Parameters
189195
----------
190-
model_data : dict
196+
data : dict
191197
Preformatted data that is going to be used in the model. For efficiency reasons it should contain only the necessary data columns,
192198
not entire available dataset since it's going to be encoded into data used to recreate the model.
193199
If not provided uses data from self.data
@@ -207,7 +213,34 @@ def build_model(
207213
raise NotImplementedError
208214

209215
def sample_model(self, **kwargs):
216+
"""
217+
Sample from the PyMC model.
218+
219+
Parameters
220+
----------
221+
**kwargs : dict
222+
Additional keyword arguments to pass to the PyMC sampler.
223+
224+
Returns
225+
-------
226+
xarray.Dataset
227+
The PyMC samples dataset.
228+
229+
Raises
230+
------
231+
RuntimeError
232+
If the PyMC model hasn't been built yet.
210233
234+
Examples
235+
--------
236+
>>> self.build_model()
237+
>>> idata = self.sample_model(draws=100, tune=10)
238+
>>> assert isinstance(idata, xr.Dataset)
239+
>>> assert "posterior" in idata
240+
>>> assert "prior" in idata
241+
>>> assert "observed_data" in idata
242+
>>> assert "log_likelihood" in idata
243+
"""
211244
if self.model is None:
212245
raise RuntimeError(
213246
"The model hasn't been built yet, call .build_model() first or call .fit() instead."
@@ -223,6 +256,34 @@ def sample_model(self, **kwargs):
223256
return idata
224257

225258
def set_idata_attrs(self, idata=None):
259+
"""
260+
Set attributes on an InferenceData object.
261+
262+
Parameters
263+
----------
264+
idata : arviz.InferenceData, optional
265+
The InferenceData object to set attributes on.
266+
267+
Raises
268+
------
269+
RuntimeError
270+
If no InferenceData object is provided.
271+
272+
Returns
273+
-------
274+
None
275+
276+
Examples
277+
--------
278+
>>> model = MyModel(ModelBuilder)
279+
>>> idata = az.InferenceData(your_dataset)
280+
>>> model.set_idata_attrs(idata=idata)
281+
>>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
282+
>>> assert "model_type" in idata.attrs
283+
>>> assert "version" in idata.attrs
284+
>>> assert "sampler_config" in idata.attrs
285+
>>> assert "model_config" in idata.attrs
286+
"""
226287
if idata is None:
227288
idata = self.idata
228289
if idata is None:
@@ -235,22 +296,33 @@ def set_idata_attrs(self, idata=None):
235296

236297
def save(self, fname: str) -> None:
237298
"""
238-
Saves inference data of the model.
299+
Save the model's inference data to a file.
239300
240301
Parameters
241302
----------
242-
fname : string
243-
This denotes the name with path from where idata should be saved.
303+
fname : str
304+
The name and path of the file to save the inference data with model parameters.
305+
306+
Returns
307+
-------
308+
None
309+
310+
Raises
311+
------
312+
RuntimeError
313+
If the model hasn't been fit yet (no inference data available).
244314
245315
Examples
246316
--------
247-
>>> class LinearModel(ModelBuilder):
248-
>>> ...
249-
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
250-
>>> model = LinearModel(model_config, sampler_config)
251-
>>> idata = model.fit(data)
252-
>>> name = './mymodel.nc'
253-
>>> model.save(name)
317+
This method is meant to be overridden and implemented by subclasses.
318+
It should not be called directly on the base abstract class or its instances.
319+
320+
>>> class MyModel(ModelBuilder):
321+
>>> def __init__(self):
322+
>>> super().__init__()
323+
>>> model = MyModel()
324+
>>> model.fit(data)
325+
>>> model.save('model_results.nc') # This will call the overridden method in MyModel
254326
"""
255327
if self.idata is not None and "posterior" in self.idata:
256328
file = Path(str(fname))
@@ -259,33 +331,32 @@ def save(self, fname: str) -> None:
259331
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
260332

261333
@classmethod
262-
def load(cls, fname: str):
334+
def load(cls, fname: str) -> "ModelBuilder":
263335
"""
264-
Creates a ModelBuilder instance from a file,
265-
Loads inference data for the model.
336+
Create a ModelBuilder instance from a file and load inference data for the model.
266337
267338
Parameters
268339
----------
269-
fname : string
270-
This denotes the name with path from where idata should be loaded from.
340+
fname : str
341+
The name and path from which the inference data should be loaded.
271342
272343
Returns
273344
-------
274-
Returns an instance of ModelBuilder.
345+
ModelBuilder
346+
An instance of the ModelBuilder class.
275347
276348
Raises
277349
------
278350
ValueError
279-
If the inference data that is loaded doesn't match with the model.
351+
If the loaded inference data does not match the model.
280352
281353
Examples
282354
--------
283-
>>> class LinearModel(ModelBuilder):
355+
>>> class MyModel(ModelBuilder):
284356
>>> ...
285357
>>> name = './mymodel.nc'
286-
>>> imported_model = LinearModel.load(name)
358+
>>> imported_model = MyModel.load(name)
287359
"""
288-
289360
filepath = Path(str(fname))
290361
idata = az.from_netcdf(filepath)
291362
model_builder = cls(
@@ -297,7 +368,7 @@ def load(cls, fname: str):
297368
model_builder.build_model(model_builder.data, model_builder.model_config)
298369
if model_builder.id != idata.attrs["id"]:
299370
raise ValueError(
300-
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
371+
f"The file '{fname}' does not contain inference data of the same model or configuration as '{cls._model_type}'"
301372
)
302373

303374
return model_builder
@@ -331,28 +402,27 @@ def fit(
331402
332403
Examples
333404
--------
334-
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
335-
>>> model = LinearModel(model_config, sampler_config)
405+
>>> model = MyModel()
336406
>>> idata = model.fit(data)
337407
Auto-assigning NUTS sampler...
338408
Initializing NUTS using jitter+adapt_diag...
339409
"""
340410

341411
# If a new data was provided, assign it to the model
342412
if data is not None:
343-
self.data = data
344-
self.model_data = self.generate_model_data(data=self.data)
413+
self.data = self.generate_model_data(data=self.data)
414+
345415
if self.model_config is None:
346416
self.model_config = self.default_model_config
347417
if self.sampler_config is None:
348418
self.sampler_config = self.default_sampler_config
349419
if self.model is None:
350-
self.build_model(self.model_data, self.model_config)
420+
self.build_model(self.data, self.model_config)
351421

352422
self.sampler_config["progressbar"] = progressbar
353423
self.sampler_config["random_seed"] = random_seed
354-
355-
self.idata = self.sample_model(**self.sampler_config)
424+
temp_sampler_config = {**self.sampler_config, **kwargs}
425+
self.idata = self.sample_model(**temp_sampler_config)
356426
self.idata.add_groups(fit_data=self.data.to_xarray())
357427
return self.idata
358428

@@ -377,8 +447,7 @@ def predict(
377447
378448
Examples
379449
--------
380-
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
381-
>>> model = LinearModel(model_config, sampler_config)
450+
>>> model = MyModel()
382451
>>> idata = model.fit(data)
383452
>>> x_pred = []
384453
>>> prediction_data = pd.DataFrame({'input':x_pred})
@@ -398,7 +467,7 @@ def predict_posterior(
398467
Generate posterior predictive samples on unseen data.
399468
400469
Parameters
401-
---------
470+
----------
402471
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
403472
It is the data we need to make prediction on using the model.
404473
extend_idata : Boolean determining whether the predictions should be added to inference data object.
@@ -412,8 +481,7 @@ def predict_posterior(
412481
413482
Examples
414483
--------
415-
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
416-
>>> model = LinearModel(model_config, sampler_config)
484+
>>> model = MyModel()
417485
>>> idata = model.fit(data)
418486
>>> x_pred = []
419487
>>> prediction_data = pd.DataFrame({'input': x_pred})
@@ -450,13 +518,22 @@ def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
450518
@property
451519
def id(self) -> str:
452520
"""
453-
It creates a hash value to match the model version using last 16 characters of hash encoding.
521+
Generate a unique hash value for the model.
522+
523+
The hash value is created using the last 16 characters of the SHA256 hash encoding, based on the model configuration,
524+
version, and model type.
454525
455526
Returns
456527
-------
457-
Returns string of length 16 characters contains unique hash of the model
458-
"""
528+
str
529+
A string of length 16 characters containing a unique hash of the model.
459530
531+
Examples
532+
--------
533+
>>> model = MyModel()
534+
>>> model.id
535+
'0123456789abcdef'
536+
"""
460537
hasher = hashlib.sha256()
461538
hasher.update(str(self.model_config.values()).encode())
462539
hasher.update(self.version.encode())

pytest.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[pytest]
22
filterwarnings =
33
error
4-
ignore:.*?(\b(pkg_resources\.declare_namespace)\b).*:DeprecationWarning
4+
ignore:.*(\b(pkg_resources\.declare_namespace|np\.bool8)\b).*:DeprecationWarning
55
ignore::UserWarning:arviz.data.inference_data
66
ignore::DeprecationWarning:pkg_resources

0 commit comments

Comments
 (0)