-
-
Notifications
You must be signed in to change notification settings - Fork 59
Model Builder refactoring #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
8cc3383
2e4507d
064e3ae
d3a74ec
51cb782
63f6be5
9c481aa
b4e6349
68bc518
ac4ec1a
7b8c253
c023abe
c6c5cb5
b6cbda2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
|
||
import hashlib | ||
import json | ||
from abc import abstractmethod | ||
from pathlib import Path | ||
from typing import Dict, Union | ||
|
||
|
@@ -24,12 +25,10 @@ | |
import pymc as pm | ||
|
||
|
||
class ModelBuilder(pm.Model): | ||
class ModelBuilder: | ||
""" | ||
ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models | ||
and help with deployment. | ||
|
||
Extends the pymc.Model class. | ||
""" | ||
|
||
_model_type = "BaseClass" | ||
|
@@ -38,8 +37,8 @@ class ModelBuilder(pm.Model): | |
def __init__( | ||
self, | ||
model_config: Dict, | ||
sampler_config: Dict, | ||
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], | ||
sampler_config: Dict = None, | ||
): | ||
""" | ||
Initializes model configuration and sampler configuration for the model | ||
|
@@ -48,10 +47,10 @@ def __init__( | |
---------- | ||
model_config : Dictionary | ||
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method. | ||
sampler_config : Dictionary | ||
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. | ||
data : Dictionary | ||
It is the data we need to train the model on. | ||
sampler_config : Dictionary | ||
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. | ||
Examples | ||
-------- | ||
>>> class LinearModel(ModelBuilder): | ||
|
@@ -61,19 +60,20 @@ def __init__( | |
|
||
super().__init__() | ||
self.model_config = model_config # parameters for priors etc. | ||
self.sample_config = sampler_config # parameters for sampling | ||
self.idata = None # inference data object | ||
self.sampler_config = sampler_config # parameters for sampling | ||
self.data = data | ||
self.build() | ||
self.idata = ( | ||
None # inference data object placeholder, idata is generated during build execution | ||
) | ||
|
||
def build(self): | ||
def build(self) -> None: | ||
""" | ||
Builds the defined model. | ||
""" | ||
|
||
with self: | ||
self.build_model(self.model_config, self.data) | ||
self.build_model(self, self.model_config, self.data) | ||
|
||
@abstractmethod | ||
def _data_setter( | ||
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True | ||
): | ||
|
@@ -100,8 +100,10 @@ def _data_setter( | |
|
||
raise NotImplementedError | ||
|
||
@classmethod | ||
def create_sample_input(cls): | ||
# need a discussion if it's really needed. | ||
@staticmethod | ||
@abstractmethod | ||
def create_sample_input(): | ||
""" | ||
Needs to be implemented by the user in the inherited class. | ||
Returns examples for data, model_config, sampler_config. | ||
|
@@ -135,7 +137,7 @@ def create_sample_input(cls): | |
|
||
raise NotImplementedError | ||
|
||
def save(self, fname): | ||
def save(self, fname: str) -> None: | ||
""" | ||
Saves inference data of the model. | ||
|
||
|
@@ -159,8 +161,9 @@ def save(self, fname): | |
self.idata.to_netcdf(file) | ||
|
||
@classmethod | ||
def load(cls, fname): | ||
def load(cls, fname: str): | ||
""" | ||
Creates a ModelBuilder instance from a file, | ||
Loads inference data for the model. | ||
|
||
Parameters | ||
|
@@ -170,7 +173,7 @@ def load(cls, fname): | |
|
||
Returns | ||
------- | ||
Returns the inference data that is loaded from local system. | ||
Returns an instance of ModelBuilder. | ||
|
||
Raises | ||
------ | ||
|
@@ -187,22 +190,29 @@ def load(cls, fname): | |
|
||
filepath = Path(str(fname)) | ||
idata = az.from_netcdf(filepath) | ||
self = cls( | ||
json.loads(idata.attrs["model_config"]), | ||
json.loads(idata.attrs["sampler_config"]), | ||
idata.fit_data.to_dataframe(), | ||
if "sampler_config" in idata.attrs: | ||
sampler_config = json.loads(idata.attrs["sampler_config"]) | ||
else: | ||
sampler_config = None | ||
model_builder = cls( | ||
model_config=json.loads(idata.attrs["model_config"]), | ||
sampler_config=sampler_config, | ||
data=idata.fit_data.to_dataframe(), | ||
) | ||
self.idata = idata | ||
if self.id != idata.attrs["id"]: | ||
model_builder.idata = idata | ||
model_builder.build() | ||
if model_builder.id != idata.attrs["id"]: | ||
raise ValueError( | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'" | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" | ||
) | ||
|
||
return self | ||
return model_builder | ||
|
||
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None): | ||
def fit( | ||
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None | ||
) -> az.InferenceData: | ||
""" | ||
As the name suggests fit can be used to fit a model using the data that is passed as a parameter. | ||
Fit a model using the data passed as a parameter. | ||
Sets attrs to inference data of the model. | ||
|
||
Parameter | ||
|
@@ -225,35 +235,40 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None | |
|
||
if data is not None: | ||
self.data = data | ||
self._data_setter(data) | ||
|
||
if self.basic_RVs == []: | ||
self.build() | ||
self._data_setter(data) | ||
|
||
with self: | ||
self.idata = pm.sample(**self.sample_config) | ||
with self.model: | ||
if self.sampler_config is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in init you can just set sampler_config to {} if it's None. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. true, good catch! |
||
self.idata = pm.sample(**self.sampler_config) | ||
else: | ||
self.idata = pm.sample() | ||
self.idata.extend(pm.sample_prior_predictive()) | ||
self.idata.extend(pm.sample_posterior_predictive(self.idata)) | ||
|
||
self.idata.attrs["id"] = self.id | ||
self.idata.attrs["model_type"] = self._model_type | ||
self.idata.attrs["version"] = self.version | ||
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config) | ||
if self.sampler_config is not None: | ||
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) | ||
self.idata.attrs["model_config"] = json.dumps(self.model_config) | ||
self.idata.add_groups(fit_data=self.data.to_xarray()) | ||
return self.idata | ||
|
||
def predict( | ||
self, | ||
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
): | ||
extend_idata: bool = True, | ||
) -> dict: | ||
""" | ||
Uses model to predict on unseen data and return point prediction of all the samples | ||
|
||
Parameters | ||
--------- | ||
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series | ||
It is the data we need to make prediction on using the model. | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
||
Returns | ||
------- | ||
|
@@ -275,7 +290,8 @@ def predict( | |
|
||
with self.model: # sample with new input data | ||
post_pred = pm.sample_posterior_predictive(self.idata) | ||
|
||
if extend_idata: | ||
self.idata.extend(post_pred) | ||
# reshape output | ||
post_pred = self._extract_samples(post_pred) | ||
for key in post_pred: | ||
|
@@ -286,16 +302,17 @@ def predict( | |
def predict_posterior( | ||
self, | ||
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, | ||
): | ||
extend_idata: bool = True, | ||
) -> Dict[str, np.array]: | ||
""" | ||
Uses model to predict samples on unseen data. | ||
|
||
Parameters | ||
--------- | ||
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series | ||
It is the data we need to make prediction on using the model. | ||
point_estimate : bool | ||
Adds point like estimate used as mean passed as | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
||
Returns | ||
------- | ||
|
@@ -317,6 +334,8 @@ def predict_posterior( | |
|
||
with self.model: # sample with new input data | ||
post_pred = pm.sample_posterior_predictive(self.idata) | ||
if extend_idata: | ||
self.idata.extend(post_pred) | ||
|
||
# reshape output | ||
post_pred = self._extract_samples(post_pred) | ||
|
@@ -357,5 +376,4 @@ def id(self) -> str: | |
hasher.update(str(self.model_config.values()).encode()) | ||
hasher.update(self.version.encode()) | ||
hasher.update(self._model_type.encode()) | ||
# hasher.update(str(self.sample_config.values()).encode()) | ||
return hasher.hexdigest()[:16] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,6 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import hashlib | ||
import sys | ||
import tempfile | ||
|
@@ -29,26 +28,39 @@ class test_ModelBuilder(ModelBuilder): | |
_model_type = "LinearModel" | ||
version = "0.1" | ||
|
||
def build_model(self, model_config, data=None): | ||
if data is not None: | ||
x = pm.MutableData("x", data["input"].values) | ||
y_data = pm.MutableData("y_data", data["output"].values) | ||
|
||
# prior parameters | ||
a_loc = model_config["a_loc"] | ||
a_scale = model_config["a_scale"] | ||
b_loc = model_config["b_loc"] | ||
b_scale = model_config["b_scale"] | ||
obs_error = model_config["obs_error"] | ||
|
||
# priors | ||
a = pm.Normal("a", a_loc, sigma=a_scale) | ||
b = pm.Normal("b", b_loc, sigma=b_scale) | ||
obs_error = pm.HalfNormal("σ_model_fmc", obs_error) | ||
|
||
# observed data | ||
if data is not None: | ||
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) | ||
def build_model( | ||
self, | ||
model_instance: ModelBuilder, | ||
model_config: dict, | ||
data: dict = None, | ||
sampler_config: dict = None, | ||
): | ||
model_instance.model_config = model_config | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can just keep the |
||
model_instance.data = data | ||
self.model_config = model_config | ||
self.sampler_config = sampler_config | ||
self.data = data | ||
|
||
with pm.Model() as model_instance.model: | ||
if data is not None: | ||
x = pm.MutableData("x", data["input"].values) | ||
y_data = pm.MutableData("y_data", data["output"].values) | ||
|
||
# prior parameters | ||
a_loc = model_config["a_loc"] | ||
a_scale = model_config["a_scale"] | ||
b_loc = model_config["b_loc"] | ||
b_scale = model_config["b_scale"] | ||
obs_error = model_config["obs_error"] | ||
|
||
# priors | ||
a = pm.Normal("a", a_loc, sigma=a_scale) | ||
b = pm.Normal("b", b_loc, sigma=b_scale) | ||
obs_error = pm.HalfNormal("σ_model_fmc", obs_error) | ||
|
||
# observed data | ||
if data is not None: | ||
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) | ||
|
||
def _data_setter(self, data: pd.DataFrame): | ||
with self.model: | ||
|
@@ -57,7 +69,7 @@ def _data_setter(self, data: pd.DataFrame): | |
pm.set_data({"y_data": data["output"].values}) | ||
|
||
@classmethod | ||
def create_sample_input(cls): | ||
def create_sample_input(self): | ||
x = np.linspace(start=0, stop=1, num=100) | ||
y = 5 * x + 3 | ||
y = y + np.random.normal(0, 1, len(x)) | ||
|
@@ -81,14 +93,14 @@ def create_sample_input(cls): | |
return data, model_config, sampler_config | ||
|
||
@staticmethod | ||
def initial_build_and_fit(check_idata=True): | ||
def initial_build_and_fit(check_idata=True) -> ModelBuilder: | ||
data, model_config, sampler_config = test_ModelBuilder.create_sample_input() | ||
model = test_ModelBuilder(model_config, sampler_config, data) | ||
model.fit() | ||
model_builder = test_ModelBuilder(model_config, sampler_config, data) | ||
model_builder.idata = model_builder.fit(data=data) | ||
if check_idata: | ||
assert model.idata is not None | ||
assert "posterior" in model.idata.groups() | ||
return model | ||
assert model_builder.idata is not None | ||
assert "posterior" in model_builder.idata.groups() | ||
return model_builder | ||
|
||
|
||
def test_fit(): | ||
|
@@ -105,16 +117,16 @@ def test_fit(): | |
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." | ||
) | ||
def test_save_load(): | ||
model = test_ModelBuilder.initial_build_and_fit(False) | ||
test_builder = test_ModelBuilder.initial_build_and_fit() | ||
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) | ||
model.save(temp.name) | ||
model2 = test_ModelBuilder.load(temp.name) | ||
assert model.idata.groups() == model2.idata.groups() | ||
test_builder.save(temp.name) | ||
test_builder2 = test_ModelBuilder.load(temp.name) | ||
assert test_builder.idata.groups() == test_builder2.idata.groups() | ||
|
||
x_pred = np.random.uniform(low=0, high=1, size=100) | ||
prediction_data = pd.DataFrame({"input": x_pred}) | ||
pred1 = model.predict(prediction_data) | ||
pred2 = model2.predict(prediction_data) | ||
pred1 = test_builder.predict(prediction_data) | ||
pred2 = test_builder2.predict(prediction_data) | ||
assert pred1["y_model"].shape == pred2["y_model"].shape | ||
temp.close() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make model_config also optional and move down.