-
-
Notifications
You must be signed in to change notification settings - Fork 59
Model Builder refactoring #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
8cc3383
2e4507d
064e3ae
d3a74ec
51cb782
63f6be5
9c481aa
b4e6349
68bc518
ac4ec1a
7b8c253
c023abe
c6c5cb5
b6cbda2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,8 +37,8 @@ class ModelBuilder: | |
def __init__( | ||
self, | ||
model_config: Dict, | ||
sampler_config: Dict, | ||
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], | ||
sampler_config: Dict = None, | ||
): | ||
""" | ||
Initializes model configuration and sampler configuration for the model | ||
|
@@ -47,10 +47,10 @@ def __init__( | |
---------- | ||
model_config : Dictionary | ||
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method. | ||
sampler_config : Dictionary | ||
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. | ||
data : Dictionary | ||
It is the data we need to train the model on. | ||
sampler_config : Dictionary | ||
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. | ||
Examples | ||
-------- | ||
>>> class LinearModel(ModelBuilder): | ||
|
@@ -60,9 +60,11 @@ def __init__( | |
|
||
super().__init__() | ||
self.model_config = model_config # parameters for priors etc. | ||
self.sample_config = sampler_config # parameters for sampling | ||
self.idata = None # inference data object | ||
self.sampler_config = sampler_config # parameters for sampling | ||
self.data = data | ||
self.idata = ( | ||
None # inference data object placeholder, idata is generated during build execution | ||
) | ||
|
||
def build(self) -> None: | ||
""" | ||
|
@@ -188,16 +190,20 @@ def load(cls, fname: str): | |
|
||
filepath = Path(str(fname)) | ||
idata = az.from_netcdf(filepath) | ||
if "sampler_config" in idata.attrs: | ||
sampler_config = json.loads(idata.attrs["sampler_config"]) | ||
else: | ||
sampler_config = None | ||
model_builder = cls( | ||
json.loads(idata.attrs["model_config"]), | ||
json.loads(idata.attrs["sampler_config"]), | ||
idata.fit_data.to_dataframe(), | ||
model_config=json.loads(idata.attrs["model_config"]), | ||
sampler_config=sampler_config, | ||
data=idata.fit_data.to_dataframe(), | ||
) | ||
model_builder.idata = idata | ||
model_builder.build() | ||
if model_builder.id != idata.attrs["id"]: | ||
raise ValueError( | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'" | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" | ||
) | ||
|
||
return model_builder | ||
|
@@ -206,7 +212,7 @@ def fit( | |
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None | ||
) -> az.InferenceData: | ||
""" | ||
As the name suggests fit can be used to fit a model using the data that is passed as a parameter. | ||
Fit a model using the data passed as a parameter. | ||
Sets attrs to inference data of the model. | ||
|
||
Parameter | ||
|
@@ -233,21 +239,26 @@ def fit( | |
self._data_setter(data) | ||
|
||
with self.model: | ||
self.idata = pm.sample(**self.sample_config) | ||
if self.sampler_config is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In `__init__` you can just set `sampler_config` to `{}` if it's `None`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. True, good catch! |
||
self.idata = pm.sample(**self.sampler_config) | ||
else: | ||
self.idata = pm.sample() | ||
self.idata.extend(pm.sample_prior_predictive()) | ||
self.idata.extend(pm.sample_posterior_predictive(self.idata)) | ||
|
||
self.idata.attrs["id"] = self.id | ||
self.idata.attrs["model_type"] = self._model_type | ||
self.idata.attrs["version"] = self.version | ||
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config) | ||
if self.sampler_config is not None: | ||
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) | ||
self.idata.attrs["model_config"] = json.dumps(self.model_config) | ||
self.idata.add_groups(fit_data=self.data.to_xarray()) | ||
return self.idata | ||
|
||
def predict( | ||
self, | ||
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
extend_idata: bool = True, | ||
) -> dict: | ||
""" | ||
Uses model to predict on unseen data and return point prediction of all the samples | ||
|
@@ -256,6 +267,8 @@ def predict( | |
--------- | ||
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series | ||
It is the data we need to make prediction on using the model. | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
||
Returns | ||
------- | ||
|
@@ -277,7 +290,8 @@ def predict( | |
|
||
with self.model: # sample with new input data | ||
post_pred = pm.sample_posterior_predictive(self.idata) | ||
|
||
if extend_idata: | ||
self.idata.extend(post_pred) | ||
# reshape output | ||
post_pred = self._extract_samples(post_pred) | ||
for key in post_pred: | ||
|
@@ -288,6 +302,7 @@ def predict( | |
def predict_posterior( | ||
self, | ||
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
extend_idata: bool = True, | ||
) -> Dict[str, np.array]: | ||
""" | ||
Uses model to predict samples on unseen data. | ||
|
@@ -296,8 +311,8 @@ def predict_posterior( | |
--------- | ||
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series | ||
It is the data we need to make prediction on using the model. | ||
point_estimate : bool | ||
Adds point like estimate used as mean passed as | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
||
Returns | ||
------- | ||
|
@@ -319,6 +334,8 @@ def predict_posterior( | |
|
||
with self.model: # sample with new input data | ||
post_pred = pm.sample_posterior_predictive(self.idata) | ||
if extend_idata: | ||
self.idata.extend(post_pred) | ||
|
||
# reshape output | ||
post_pred = self._extract_samples(post_pred) | ||
|
@@ -359,5 +376,4 @@ def id(self) -> str: | |
hasher.update(str(self.model_config.values()).encode()) | ||
hasher.update(self.version.encode()) | ||
hasher.update(self._model_type.encode()) | ||
# hasher.update(str(self.sample_config.values()).encode()) | ||
return hasher.hexdigest()[:16] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,10 +28,17 @@ class test_ModelBuilder(ModelBuilder): | |
_model_type = "LinearModel" | ||
version = "0.1" | ||
|
||
def build_model(self, model_instance: ModelBuilder, model_config: dict, data: dict = None): | ||
def build_model( | ||
self, | ||
model_instance: ModelBuilder, | ||
model_config: dict, | ||
data: dict = None, | ||
sampler_config: dict = None, | ||
): | ||
model_instance.model_config = model_config | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can just keep the |
||
model_instance.data = data | ||
self.model_config = model_config | ||
self.sampler_config = sampler_config | ||
self.data = data | ||
|
||
with pm.Model() as model_instance.model: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make `model_config` optional as well and move it down.