Skip to content

Commit c272052

Browse files
step 2: removing duplications and adapting BayesianEstimator
1 parent db13bde commit c272052

File tree

3 files changed

+10
-22
lines changed

3 files changed

+10
-22
lines changed

pymc_experimental/bayesian_estimator_linearmodel.py

+7-18
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class BayesianEstimator(ModelBuilder):
6060

6161
def __init__(
6262
self,
63+
data: Union[np.ndarray, pd.DataFrame, pd.Series] = None,
6364
model_config: Dict = None,
6465
sampler_config: Dict = None,
6566
):
@@ -75,16 +76,9 @@ def __init__(
7576
"""
7677
if model_config is None:
7778
model_config = self.default_model_config
78-
self.model_config = model_config
79-
8079
if sampler_config is None:
8180
sampler_config = self.default_sampler_config
82-
self.sampler_config = sampler_config
83-
84-
self.model = None # Set by build_model
85-
self.output_var = None # Set by build_model
86-
self.idata = None # idata is generated during fitting
87-
self.is_fitted_ = False
81+
super().__init__(data=data, model_config=model_config, sampler_config=sampler_config)
8882

8983
@property
9084
@abstractmethod
@@ -103,16 +97,11 @@ def _validate_data(self, X, y=None):
10397
return check_array(X, accept_sparse=False)
10498

10599
@abstractmethod
106-
def build_model(self) -> None:
107-
"""
108-
Build the PYMC model. The model is built with placeholder data.
109-
Actual data will be set by _data_setter when fitting or evaluating the model.
110-
Data array size can change but number of dimensions must stay the same.
100+
def build_model(
101+
model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
102+
model_config: Dict[str, Union[int, float, Dict]] = None,
103+
) -> None:
111104

112-
Returns:
113-
----------
114-
None
115-
"""
116105
raise NotImplementedError
117106

118107
@abstractmethod
@@ -462,7 +451,7 @@ def _data_setter(self, X, y=None):
462451
pm.set_data({"y_data": y.squeeze()})
463452

464453
@classmethod
465-
def create_sample_input(cls, nsamples=100):
454+
def generate_model_data(cls, nsamples=100, data=None):
466455
x = np.linspace(start=0, stop=1, num=nsamples)
467456
y = 5 * x + 3
468457
y = y + np.random.normal(0, 1, len(x))

pymc_experimental/model_builder.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,10 @@ def build_model(
182182
model_config: Dict[str, Union[int, float, Dict]] = None,
183183
) -> None:
184184
"""
185-
Needs to be implemented by the user in the child class.
186185
Creates an instance of pm.Model based on provided model_data and model_config, and
187186
attaches it to self.
188187
189-
Required Parameters
188+
Parameters
190189
----------
191190
model_data : dict
192191
Preformated data that is going to be used in the model. For efficiency reasons it should contain only the necesary data columns,

pymc_experimental/tests/test_bayesian_estimator_linearmodel.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
@pytest.fixture(scope="module")
3636
def sample_input():
37-
x, y = LinearModel.create_sample_input()
37+
x, y = LinearModel.generate_model_data()
3838
return x, y
3939

4040

@@ -53,7 +53,7 @@ def fitted_linear_model_instance(sample_input):
5353

5454

5555
def test_save_without_fit_raises_runtime_error():
56-
x, y = LinearModel.create_sample_input()
56+
x, y = LinearModel.generate_model_data()
5757
test_model = LinearModel()
5858
with pytest.raises(RuntimeError):
5959
test_model.save("saved_model")

0 commit comments

Comments
 (0)