diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 6d251334..3f48f3f9 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -15,6 +15,7 @@ import hashlib import json +from abc import abstractmethod from pathlib import Path from typing import Dict, Union @@ -24,12 +25,10 @@ import pymc as pm -class ModelBuilder(pm.Model): +class ModelBuilder: """ ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models and help with deployment. - - Extends the pymc.Model class. """ _model_type = "BaseClass" @@ -37,21 +36,21 @@ class ModelBuilder(pm.Model): def __init__( self, - model_config: Dict, - sampler_config: Dict, - data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, + data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], + model_config: Dict = None, + sampler_config: Dict = None, ): """ Initializes model configuration and sampler configuration for the model Parameters ---------- - model_config : Dictionary + model_config : Dictionary, optional dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method. - sampler_config : Dictionary - dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. - data : Dictionary + data : Dictionary, required It is the data we need to train the model on. + sampler_config : Dictionary, optional + dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. Examples -------- >>> class LinearModel(ModelBuilder): @@ -59,21 +58,18 @@ def __init__( >>> model = LinearModel(model_config, sampler_config) """ - super().__init__() + if sampler_config is None: + sampler_config = {} + if model_config is None: + model_config = {} self.model_config = model_config # parameters for priors etc. 
- self.sample_config = sampler_config # parameters for sampling - self.idata = None # inference data object + self.sampler_config = sampler_config # parameters for sampling self.data = data - self.build() - - def build(self): - """ - Builds the defined model. - """ - - with self: - self.build_model(self.model_config, self.data) + self.idata = ( + None # inference data object placeholder, idata is generated during build execution + ) + @abstractmethod def _data_setter( self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True ): @@ -100,8 +96,9 @@ def _data_setter( raise NotImplementedError - @classmethod - def create_sample_input(cls): + @staticmethod + @abstractmethod + def create_sample_input(): """ Needs to be implemented by the user in the inherited class. Returns examples for data, model_config, sampler_config. @@ -135,7 +132,7 @@ def create_sample_input(cls): raise NotImplementedError - def save(self, fname): + def save(self, fname: str) -> None: """ Saves inference data of the model. @@ -159,8 +156,9 @@ def save(self, fname): self.idata.to_netcdf(file) @classmethod - def load(cls, fname): + def load(cls, fname: str): """ + Creates a ModelBuilder instance from a file, Loads inference data for the model. Parameters @@ -170,7 +168,7 @@ def load(cls, fname): Returns ------- - Returns the inference data that is loaded from local system. + Returns an instance of ModelBuilder. 
Raises ------ @@ -187,22 +185,25 @@ def load(cls, fname): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) - self = cls( - json.loads(idata.attrs["model_config"]), - json.loads(idata.attrs["sampler_config"]), - idata.fit_data.to_dataframe(), + model_builder = cls( + model_config=json.loads(idata.attrs["model_config"]), + sampler_config=json.loads(idata.attrs["sampler_config"]), + data=idata.fit_data.to_dataframe(), ) - self.idata = idata - if self.id != idata.attrs["id"]: + model_builder.idata = idata + model_builder.build() + if model_builder.id != idata.attrs["id"]: raise ValueError( - f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'" + f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" ) - return self + return model_builder - def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None): + def fit( + self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None + ) -> az.InferenceData: """ - As the name suggests fit can be used to fit a model using the data that is passed as a parameter. + Fit a model using the data passed as a parameter. Sets attrs to inference data of the model. Parameter @@ -223,22 +224,22 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None Initializing NUTS using jitter+adapt_diag... 
""" + # If a new data was provided, assign it to the model if data is not None: self.data = data - self._data_setter(data) - if self.basic_RVs == []: - self.build() + self.build() + self._data_setter(data) - with self: - self.idata = pm.sample(**self.sample_config) + with self.model: + self.idata = pm.sample(**self.sampler_config) self.idata.extend(pm.sample_prior_predictive()) self.idata.extend(pm.sample_posterior_predictive(self.idata)) self.idata.attrs["id"] = self.id self.idata.attrs["model_type"] = self._model_type self.idata.attrs["version"] = self.version - self.idata.attrs["sampler_config"] = json.dumps(self.sample_config) + self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) self.idata.attrs["model_config"] = json.dumps(self.model_config) self.idata.add_groups(fit_data=self.data.to_xarray()) return self.idata @@ -246,7 +247,8 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None def predict( self, data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, - ): + extend_idata: bool = True, + ) -> dict: """ Uses model to predict on unseen data and return point prediction of all the samples @@ -254,6 +256,8 @@ def predict( --------- data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series It is the data we need to make prediction on using the model. + extend_idata : Boolean determining whether the predictions should be added to inference data object. + Defaults to True. 
Returns ------- @@ -275,7 +279,8 @@ def predict( with self.model: # sample with new input data post_pred = pm.sample_posterior_predictive(self.idata) - + if extend_idata: + self.idata.extend(post_pred) # reshape output post_pred = self._extract_samples(post_pred) for key in post_pred: @@ -286,7 +291,8 @@ def predict( def predict_posterior( self, data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, - ): + extend_idata: bool = True, + ) -> Dict[str, np.array]: """ Uses model to predict samples on unseen data. @@ -294,8 +300,8 @@ def predict_posterior( --------- data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series It is the data we need to make prediction on using the model. - point_estimate : bool - Adds point like estimate used as mean passed as + extend_idata : Boolean determining whether the predictions should be added to inference data object. + Defaults to True. Returns ------- @@ -317,6 +323,8 @@ def predict_posterior( with self.model: # sample with new input data post_pred = pm.sample_posterior_predictive(self.idata) + if extend_idata: + self.idata.extend(post_pred) # reshape output post_pred = self._extract_samples(post_pred) @@ -357,5 +365,4 @@ def id(self) -> str: hasher.update(str(self.model_config.values()).encode()) hasher.update(self.version.encode()) hasher.update(self._model_type.encode()) - # hasher.update(str(self.sample_config.values()).encode()) return hasher.hexdigest()[:16] diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 1dd67e62..20845fee 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import hashlib import sys import tempfile @@ -29,26 +28,28 @@ class test_ModelBuilder(ModelBuilder): _model_type = "LinearModel" version = "0.1" - def build_model(self, model_config, data=None): - if data is not None: - x = pm.MutableData("x", data["input"].values) - y_data = pm.MutableData("y_data", data["output"].values) + def build(self): + + with pm.Model() as self.model: + if self.data is not None: + x = pm.MutableData("x", self.data["input"].values) + y_data = pm.MutableData("y_data", self.data["output"].values) - # prior parameters - a_loc = model_config["a_loc"] - a_scale = model_config["a_scale"] - b_loc = model_config["b_loc"] - b_scale = model_config["b_scale"] - obs_error = model_config["obs_error"] + # prior parameters + a_loc = self.model_config["a_loc"] + a_scale = self.model_config["a_scale"] + b_loc = self.model_config["b_loc"] + b_scale = self.model_config["b_scale"] + obs_error = self.model_config["obs_error"] - # priors - a = pm.Normal("a", a_loc, sigma=a_scale) - b = pm.Normal("b", b_loc, sigma=b_scale) - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) + # priors + a = pm.Normal("a", a_loc, sigma=a_scale) + b = pm.Normal("b", b_loc, sigma=b_scale) + obs_error = pm.HalfNormal("σ_model_fmc", obs_error) - # observed data - if data is not None: - y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) + # observed data + if self.data is not None: + y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) def _data_setter(self, data: pd.DataFrame): with self.model: @@ -57,7 +58,7 @@ def _data_setter(self, data: pd.DataFrame): pm.set_data({"y_data": data["output"].values}) @classmethod - def create_sample_input(cls): + def create_sample_input(self): x = np.linspace(start=0, stop=1, num=100) y = 5 * x + 3 y = y + np.random.normal(0, 1, len(x)) @@ -81,14 +82,36 @@ def create_sample_input(cls): return data, model_config, sampler_config @staticmethod - def 
initial_build_and_fit(check_idata=True) + def initial_build_and_fit(check_idata=True) -> ModelBuilder: + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() - model = test_ModelBuilder(model_config, sampler_config, data) - model.fit() + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) + model_builder.idata = model_builder.fit(data=data) if check_idata: - assert model.idata is not None - assert "posterior" in model.idata.groups() - return model + assert model_builder.idata is not None + assert "posterior" in model_builder.idata.groups() + return model_builder + + +def test_empty_sampler_config(): + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() + sampler_config = {} + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) + model_builder.idata = model_builder.fit(data=data) + assert model_builder.idata is not None + assert "posterior" in model_builder.idata.groups() + + +def test_empty_model_config(): + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() + model_config = {} + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) + assert model_builder.model_config == {} def test_fit(): @@ -105,16 +128,16 @@ def test_fit(): sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
) def test_save_load(): - model = test_ModelBuilder.initial_build_and_fit(False) + test_builder = test_ModelBuilder.initial_build_and_fit() temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model.save(temp.name) - model2 = test_ModelBuilder.load(temp.name) - assert model.idata.groups() == model2.idata.groups() + test_builder.save(temp.name) + test_builder2 = test_ModelBuilder.load(temp.name) + assert test_builder.idata.groups() == test_builder2.idata.groups() x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred1 = model.predict(prediction_data) - pred2 = model2.predict(prediction_data) + pred1 = test_builder.predict(prediction_data) + pred2 = test_builder2.predict(prediction_data) assert pred1["y_model"].shape == pred2["y_model"].shape temp.close() @@ -163,7 +186,7 @@ def test_extract_samples(): def test_id(): data, model_config, sampler_config = test_ModelBuilder.create_sample_input() - model = test_ModelBuilder(model_config, sampler_config, data) + model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config, data=data) expected_id = hashlib.sha256( str(model_config.values()).encode() + model.version.encode() + model._model_type.encode()