Skip to content

Commit d3a74ec

Browse files
typehinting and refactoring load function
1 parent 064e3ae commit d3a74ec

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pymc_experimental/model_builder.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class ModelBuilder:
2828
"""
2929
ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
3030
and help with deployment.
31-
32-
Extends the pymc.Model class.
3331
"""
3432

3533
_model_type = "BaseClass"
@@ -65,7 +63,7 @@ def __init__(
6563
self.idata = None # inference data object
6664
self.data = data
6765

68-
def build(self):
66+
def build(self) -> None:
6967
"""
7068
Builds the defined model.
7169
"""
@@ -136,7 +134,7 @@ def create_sample_input():
136134

137135
raise NotImplementedError
138136

139-
def save(self, fname):
137+
def save(self, fname: str) -> None:
140138
"""
141139
Saves inference data of the model.
142140
@@ -160,8 +158,9 @@ def save(self, fname):
160158
self.idata.to_netcdf(file)
161159

162160
@classmethod
163-
def load(cls, fname):
161+
def load(cls, fname: str):
164162
"""
163+
Creates a ModelBuilder instance from a file,
165164
Loads inference data for the model.
166165
167166
Parameters
@@ -171,7 +170,7 @@ def load(cls, fname):
171170
172171
Returns
173172
-------
174-
Returns an instance of pm.Model, that is loaded from local data.
173+
Returns an instance of ModelBuilder.
175174
176175
Raises
177176
------
@@ -188,21 +187,23 @@ def load(cls, fname):
188187

189188
filepath = Path(str(fname))
190189
idata = az.from_netcdf(filepath)
191-
self = cls(
190+
model_builder = cls(
192191
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
193192
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
194193
idata.fit_data.to_dataframe(),
195194
)
196-
self.idata = idata
197-
self.build()
198-
if self.id != idata.attrs["id"]:
195+
model_builder.idata = idata
196+
model_builder.build()
197+
if model_builder.id != idata.attrs["id"]:
199198
raise ValueError(
200199
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
201200
)
202201

203-
return self.model
202+
return model_builder
204203

205-
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
204+
def fit(
205+
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
206+
) -> az.InferenceData:
206207
"""
207208
As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
208209
Sets attrs to inference data of the model.
@@ -248,7 +249,7 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
248249
def predict(
249250
self,
250251
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
251-
):
252+
) -> dict:
252253
"""
253254
Uses model to predict on unseen data and return point prediction of all the samples
254255
@@ -288,7 +289,7 @@ def predict(
288289
def predict_posterior(
289290
self,
290291
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
291-
):
292+
) -> Dict[str, np.array]:
292293
"""
293294
Uses model to predict samples on unseen data.
294295

pymc_experimental/tests/test_model_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class test_ModelBuilder(ModelBuilder):
2828
_model_type = "LinearModel"
2929
version = "0.1"
3030

31-
def build_model(self, model_instance, model_config, data=None):
31+
def build_model(self, model_instance: ModelBuilder, model_config: dict, data: dict = None):
3232
model_instance.model_config = model_config
3333
model_instance.data = data
3434
self.model_config = model_config
@@ -86,7 +86,7 @@ def create_sample_input(self):
8686
return data, model_config, sampler_config
8787

8888
@staticmethod
89-
def initial_build_and_fit(check_idata=True):
89+
def initial_build_and_fit(check_idata=True) -> ModelBuilder:
9090
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
9191
model_builder = test_ModelBuilder(model_config, sampler_config, data)
9292
model_builder.idata = model_builder.fit(data=data)
@@ -113,8 +113,7 @@ def test_save_load():
113113
test_builder = test_ModelBuilder.initial_build_and_fit()
114114
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
115115
test_builder.save(temp.name)
116-
test_builder2 = test_ModelBuilder.initial_build_and_fit()
117-
test_builder2.model = test_ModelBuilder.load(temp.name)
116+
test_builder2 = test_ModelBuilder.load(temp.name)
118117
assert test_builder.idata.groups() == test_builder2.idata.groups()
119118

120119
x_pred = np.random.uniform(low=0, high=1, size=100)

0 commit comments

Comments
 (0)