Skip to content

Commit bca5360

Browse files
committed
added and modified statespace tests
1 parent dc02857 commit bca5360

File tree

1 file changed

+54
-97
lines changed

1 file changed

+54
-97
lines changed

tests/statespace/test_statespace.py

Lines changed: 54 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -128,32 +128,51 @@ def ss_mod_no_exog_dt(rng):
128128

129129

130130
@pytest.fixture(scope="session")
131-
def exog_ss_mod(rng):
132-
ll = st.LevelTrendComponent()
133-
reg = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
134-
mod = (ll + reg).build(verbose=False)
131+
def exog_data(rng):
132+
# simulate data
133+
df = pd.DataFrame(
134+
{
135+
"date": pd.date_range(start="2023-05-01", end="2023-05-10", freq="D"),
136+
"x1": rng.choice(2, size=10, replace=True).astype(float),
137+
"y": rng.normal(size=(10,)),
138+
}
139+
)
135140

136-
return mod
141+
df.loc[[1, 3, 9], ["y"]] = np.nan
142+
return df.set_index("date")
137143

138144

139145
@pytest.fixture(scope="session")
140-
def exog_pymc_mod(exog_ss_mod, rng):
141-
y = rng.normal(size=(100, 1)).astype(floatX)
142-
X = rng.normal(size=(100, 3)).astype(floatX)
146+
def exog_ss_mod(exog_data):
147+
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
148+
exog = st.RegressionComponent(
149+
name="exog", # Name of this exogenous variable component
150+
k_exog=1, # Only one exogenous variable now
151+
innovations=False, # Typically fixed effect (no stochastic evolution)
152+
state_names=exog_data[["x1"]].columns.tolist(),
153+
)
143154

144-
with pm.Model(coords=exog_ss_mod.coords) as m:
145-
exog_data = pm.Data("data_exog", X)
146-
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
147-
P0_sigma = pm.Exponential("P0_sigma", 1)
148-
P0 = pm.Deterministic(
149-
"P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
155+
combined_model = level_trend + exog
156+
return combined_model.build()
157+
158+
159+
@pytest.fixture(scope="session")
160+
def exog_pymc_mod(exog_ss_mod, exog_data):
161+
# define pymc model
162+
with pm.Model(coords=exog_ss_mod.coords) as struct_model:
163+
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
164+
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
165+
166+
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
167+
168+
data_exog = pm.Data(
169+
"data_exog", exog_data["x1"].values[:, None], dims=["time", "exog_state"]
150170
)
151-
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
171+
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
152172

153-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
154-
exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)
173+
exog_ss_mod.build_statespace_graph(exog_data["y"])
155174

156-
return m
175+
return struct_model
157176

158177

159178
@pytest.fixture(scope="session")
@@ -844,10 +863,14 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
844863
assert forecast_idx[0] == (t0 + delta)
845864

846865

866+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
867+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
847868
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
848-
@pytest.mark.parametrize("start", [None, -1, 10])
869+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
870+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
871+
@pytest.mark.parametrize("start", [None, -1, 5])
849872
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
850-
scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
873+
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
851874
scenario.iloc[5, 0] = 1e9
852875

853876
forecast_idata = exog_ss_mod.forecast(
@@ -856,14 +879,14 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
856879

857880
components = exog_ss_mod.extract_components_from_idata(forecast_idata)
858881
level = components.forecast_latent.sel(state="LevelTrend[level]")
859-
betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])
882+
betas = components.forecast_latent.sel(state=["exog[x1]"])
860883

861884
scenario.index.name = "time"
862885
scenario_xr = (
863886
scenario.unstack()
864887
.to_xarray()
865888
.rename({"level_0": "state"})
866-
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
889+
.assign_coords(state=["exog[x1]"])
867890
)
868891

869892
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
@@ -872,91 +895,25 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
872895
assert_allclose(regression_effect, regression_effect_expected)
873896

874897

875-
@pytest.mark.filterwarnings("ignore:Provided data contains missing values.")
898+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
876899
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
877-
def test_foreacast_valid_index(rng):
900+
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
901+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
902+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
903+
def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
878904
# Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
879-
880-
index = pd.date_range(start="2023-05-01", end="2025-01-29", freq="D")
881-
T, k = len(index), 2
882-
data = np.zeros((T, k))
883-
idx = rng.choice(T, size=10, replace=False)
884-
cols = rng.choice(k, size=10, replace=True)
885-
886-
data[idx, cols] = 1
887-
888-
df_holidays = pd.DataFrame(data, index=index, columns=["Holiday 1", "Holiday 2"])
889-
890-
data = rng.normal(size=(T, 1))
891-
nan_locs = rng.choice(T, size=10, replace=False)
892-
data[nan_locs] = np.nan
893-
y = pd.DataFrame(data, index=index, columns=["sales"])
894-
895-
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
896-
weekly_seasonality = st.TimeSeasonality(
897-
season_length=7,
898-
state_names=["Sun", "Mon", "Tues", "Wed", "Thu", "Fri", "Sat"],
899-
innovations=True,
900-
remove_first_state=False,
901-
)
902-
quarterly_seasonality = st.FrequencySeasonality(season_length=365, n=2, innovations=True)
903-
ar1 = st.AutoregressiveComponent(order=1)
904-
me = st.MeasurementError()
905-
906-
exog = st.RegressionComponent(
907-
name="exog", # Name of this exogenous variable component
908-
k_exog=2, # Only one exogenous variable now
909-
innovations=False, # Typically fixed effect (no stochastic evolution)
910-
state_names=df_holidays.columns.tolist(),
911-
)
912-
913-
combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
914-
ss_mod = combined_model.build()
915-
916-
with pm.Model(coords=ss_mod.coords) as struct_model:
917-
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=10, dims=["state"])
918-
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
919-
920-
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
921-
# sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"]) # Applied to the level only
922-
923-
Seasonal_coefs = pm.ZeroSumNormal(
924-
"Seasonal[s=7]_coefs", sigma=0.5, dims=["Seasonal[s=7]_state"]
925-
) # DOW dev. from weekly mean
926-
sigma_Seasonal = pm.Gamma(
927-
"sigma_Seasonal[s=7]", alpha=2, beta=1
928-
) # How much this dev. can dev.
929-
930-
Frequency_coefs = pm.Normal(
931-
"Frequency[s=365, n=2]", mu=0, sigma=0.5, dims=["Frequency[s=365, n=2]_state"]
932-
) # amplitudes in short-term (weekly noise culprit)
933-
sigma_Frequency = pm.Gamma(
934-
"sigma_Frequency[s=365, n=2]", alpha=2, beta=1
935-
) # smoothness & adaptability over time
936-
937-
ar_params = pm.Laplace("ar_params", mu=0, b=0.2, dims=["ar_lag"])
938-
sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=1)
939-
940-
sigma_measurement_error = pm.HalfStudentT("sigma_MeasurementError", nu=3, sigma=1)
941-
942-
data_exog = pm.Data("data_exog", df_holidays.values, dims=["time", "exog_state"])
943-
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
944-
945-
ss_mod.build_statespace_graph(y, mode="JAX")
946-
905+
with exog_pymc_mod:
947906
idata = pm.sample_prior_predictive()
948907

949-
post = ss_mod.sample_conditional_prior(idata)
950-
951908
# Define start date and forecast period
952-
start_date, n_periods = pd.to_datetime("2024-4-15"), 8
909+
start_date, n_periods = pd.to_datetime("2023-05-05"), 5
953910

954911
# Extract exogenous data for the forecast period
955912
scenario = {
956913
"data_exog": pd.DataFrame(
957-
df_holidays.loc[start_date:].iloc[:n_periods], columns=df_holidays.columns
914+
exog_data[["x1"]].loc[start_date:].iloc[:n_periods], columns=exog_data[["x1"]].columns
958915
)
959916
}
960917

961918
# Generate the forecast
962-
forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)
919+
forecasts = exog_ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)

0 commit comments

Comments
 (0)