Skip to content

Commit dacb6f4

Browse files
Pr 451 - modified and added tests to statespace (#466)
* Use `set_data` in forecast
* Ignore new numpy matmul warnings in tests
* Tracking down data bug
* resolved merge conflicts environment-test.yml
* added and modified statespace tests
* added logic to update static shape of target when forecasting with exogenous variables
* make sure updated dummy target has correct dimensions
* Revert test env name change
* Add some checks to `test_foreacast_valid_index`

---------

Co-authored-by: jessegrabowski <[email protected]>
1 parent 04b838a commit dacb6f4

File tree

4 files changed

+105
-35
lines changed

4 files changed

+105
-35
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
10271027
provided when the model was built.
10281028
data_dims: str or tuple of str, optional
10291029
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
1030+
scenario: dict[str, pd.DataFrame], optional
1031+
Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
1032+
in the model. pm.set_data is used to replace training data with new values.
10301033
10311034
Returns
10321035
-------
@@ -1567,8 +1570,10 @@ def _validate_forecast_args(
15671570
raise ValueError(
15681571
"Integer start must be within the range of the data index used to fit the model."
15691572
)
1570-
if periods is None and end is None:
1571-
raise ValueError("Must specify one of either periods or end")
1573+
if periods is None and end is None and not use_scenario_index:
1574+
raise ValueError(
1575+
"Must specify one of either periods or end unless use_scenario_index=True"
1576+
)
15721577
if periods is not None and end is not None:
15731578
raise ValueError("Must specify exactly one of either periods or end")
15741579
if scenario is None and use_scenario_index:
@@ -2060,9 +2065,18 @@ def forecast(
20602065

20612066
with pm.Model(coords=temp_coords) as forecast_model:
20622067
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2068+
scenario=scenario,
20632069
data_dims=["data_time", OBS_STATE_DIM],
20642070
)
20652071

2072+
for name in self.data_names:
2073+
if name in scenario.keys():
2074+
pm.set_data(
2075+
{"data": np.zeros((len(forecast_index), self.k_endog))},
2076+
coords={"data_time": np.arange(len(forecast_index))},
2077+
)
2078+
break
2079+
20662080
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
20672081
mu, cov = grouped_outputs[group_idx]
20682082

@@ -2073,17 +2087,6 @@ def forecast(
20732087
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
20742088
)
20752089

2076-
if scenario is not None:
2077-
sub_dict = {
2078-
forecast_model[data_name]: pt.as_tensor_variable(
2079-
scenario.get(data_name), name=data_name
2080-
)
2081-
for data_name in self.data_names
2082-
}
2083-
2084-
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
2085-
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
2086-
20872090
_ = LinearGaussianStateSpace(
20882091
"forecast",
20892092
x0,

tests/statespace/test_SARIMAX.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S):
252252
"ignore:Non-stationary starting autoregressive parameters found",
253253
"ignore:Non-invertible starting seasonal moving average",
254254
"ignore:Non-stationary starting seasonal autoregressive",
255+
"ignore:divide by zero encountered in matmul:RuntimeWarning",
256+
"ignore:overflow encountered in matmul:RuntimeWarning",
257+
"ignore:invalid value encountered in matmul:RuntimeWarning",
255258
)
256259
def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
257260
sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
@@ -361,6 +364,9 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
361364
"ignore:Non-invertible starting MA parameters found.",
362365
"ignore:Non-stationary starting autoregressive parameters found",
363366
"ignore:Maximum Likelihood optimization failed to converge.",
367+
"ignore:divide by zero encountered in matmul:RuntimeWarning",
368+
"ignore:overflow encountered in matmul:RuntimeWarning",
369+
"ignore:invalid value encountered in matmul:RuntimeWarning",
364370
)
365371
def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
366372
if (d + D) > 0:

tests/statespace/test_statespace.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -128,32 +128,51 @@ def ss_mod_no_exog_dt(rng):
128128

129129

130130
@pytest.fixture(scope="session")
131-
def exog_ss_mod(rng):
132-
ll = st.LevelTrendComponent()
133-
reg = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
134-
mod = (ll + reg).build(verbose=False)
131+
def exog_data(rng):
132+
# simulate data
133+
df = pd.DataFrame(
134+
{
135+
"date": pd.date_range(start="2023-05-01", end="2023-05-10", freq="D"),
136+
"x1": rng.choice(2, size=10, replace=True).astype(float),
137+
"y": rng.normal(size=(10,)),
138+
}
139+
)
135140

136-
return mod
141+
df.loc[[1, 3, 9], ["y"]] = np.nan
142+
return df.set_index("date")
137143

138144

139145
@pytest.fixture(scope="session")
140-
def exog_pymc_mod(exog_ss_mod, rng):
141-
y = rng.normal(size=(100, 1)).astype(floatX)
142-
X = rng.normal(size=(100, 3)).astype(floatX)
146+
def exog_ss_mod(exog_data):
147+
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
148+
exog = st.RegressionComponent(
149+
name="exog", # Name of this exogenous variable component
150+
k_exog=1, # Only one exogenous variable now
151+
innovations=False, # Typically fixed effect (no stochastic evolution)
152+
state_names=exog_data[["x1"]].columns.tolist(),
153+
)
143154

144-
with pm.Model(coords=exog_ss_mod.coords) as m:
145-
exog_data = pm.Data("data_exog", X)
146-
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
147-
P0_sigma = pm.Exponential("P0_sigma", 1)
148-
P0 = pm.Deterministic(
149-
"P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
155+
combined_model = level_trend + exog
156+
return combined_model.build()
157+
158+
159+
@pytest.fixture(scope="session")
160+
def exog_pymc_mod(exog_ss_mod, exog_data):
161+
# define pymc model
162+
with pm.Model(coords=exog_ss_mod.coords) as struct_model:
163+
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
164+
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
165+
166+
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
167+
168+
data_exog = pm.Data(
169+
"data_exog", exog_data["x1"].values[:, None], dims=["time", "exog_state"]
150170
)
151-
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
171+
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
152172

153-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
154-
exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)
173+
exog_ss_mod.build_statespace_graph(exog_data["y"])
155174

156-
return m
175+
return struct_model
157176

158177

159178
@pytest.fixture(scope="session")
@@ -844,10 +863,14 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
844863
assert forecast_idx[0] == (t0 + delta)
845864

846865

866+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
867+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
847868
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
848-
@pytest.mark.parametrize("start", [None, -1, 10])
869+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
870+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
871+
@pytest.mark.parametrize("start", [None, -1, 5])
849872
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
850-
scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
873+
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
851874
scenario.iloc[5, 0] = 1e9
852875

853876
forecast_idata = exog_ss_mod.forecast(
@@ -856,17 +879,50 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
856879

857880
components = exog_ss_mod.extract_components_from_idata(forecast_idata)
858881
level = components.forecast_latent.sel(state="LevelTrend[level]")
859-
betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])
882+
betas = components.forecast_latent.sel(state=["exog[x1]"])
860883

861884
scenario.index.name = "time"
862885
scenario_xr = (
863886
scenario.unstack()
864887
.to_xarray()
865888
.rename({"level_0": "state"})
866-
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
889+
.assign_coords(state=["exog[x1]"])
867890
)
868891

869892
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
870893
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
871894

872895
assert_allclose(regression_effect, regression_effect_expected)
896+
897+
898+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
899+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
900+
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
901+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
902+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
903+
def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
904+
# Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
905+
with exog_pymc_mod:
906+
idata = pm.sample_prior_predictive()
907+
908+
# Define start date and forecast period
909+
start_date, n_periods = pd.to_datetime("2023-05-05"), 5
910+
911+
# Extract exogenous data for the forecast period
912+
scenario = {
913+
"data_exog": pd.DataFrame(
914+
exog_data[["x1"]].loc[start_date:].iloc[:n_periods], columns=exog_data[["x1"]].columns
915+
)
916+
}
917+
918+
# Generate the forecast
919+
forecasts = exog_ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)
920+
assert "forecast_latent" in forecasts
921+
assert "forecast_observed" in forecasts
922+
923+
assert (forecasts.coords["time"].values == scenario["data_exog"].index.values).all()
924+
assert not np.any(np.isnan(forecasts.forecast_latent.values))
925+
assert not np.any(np.isnan(forecasts.forecast_observed.values))
926+
927+
assert forecasts.forecast_latent.shape[2] == n_periods
928+
assert forecasts.forecast_observed.shape[2] == n_periods

tests/statespace/test_structural.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,11 @@ def test_autoregressive_model(order, rng):
594594
@pytest.mark.parametrize("s", [10, 25, 50])
595595
@pytest.mark.parametrize("innovations", [True, False])
596596
@pytest.mark.parametrize("remove_first_state", [True, False])
597+
@pytest.mark.filterwarnings(
598+
"ignore:divide by zero encountered in matmul:RuntimeWarning",
599+
"ignore:overflow encountered in matmul:RuntimeWarning",
600+
"ignore:invalid value encountered in matmul:RuntimeWarning",
601+
)
597602
def test_time_seasonality(s, innovations, remove_first_state, rng):
598603
def random_word(rng):
599604
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))

0 commit comments

Comments
 (0)