Skip to content

Commit 8dff969

Browse files
Use model variables named "sigma" as standard deviations rather than variances (#296)
* Model variables named "sigma" no longer expect variances * Rename variables "sigma" to "sigma2" in `create_structural_model_and_equivalent_statsmodel` test function
1 parent f26a6c9 commit 8dff969

File tree

4 files changed

+30
-28
lines changed

4 files changed

+30
-28
lines changed

pymc_experimental/statespace/models/SARIMAX.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,14 @@ def make_symbolic_graph(self) -> None:
514514
state_cov = self.make_and_register_variable(
515515
"sigma_state", shape=(self.k_posdef,), dtype=floatX
516516
)
517-
self.ssm[state_cov_idx] = state_cov
517+
self.ssm[state_cov_idx] = state_cov**2
518518

519519
if self.measurement_error:
520520
obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog)
521521
obs_cov = self.make_and_register_variable(
522522
"sigma_obs", shape=(self.k_endog,), dtype=floatX
523523
)
524-
self.ssm[obs_cov_idx] = obs_cov
524+
self.ssm[obs_cov_idx] = obs_cov**2
525525

526526
# The initial conditions have to be done last in the case of stationary initialization, because it will depend
527527
# on c, T, R and Q

pymc_experimental/statespace/models/structural.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def make_symbolic_graph(self) -> None:
818818
sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,))
819819
diag_idx = np.diag_indices(self.k_posdef)
820820
idx = np.s_["state_cov", diag_idx[0], diag_idx[1]]
821-
self.ssm[idx] = sigma_trend
821+
self.ssm[idx] = sigma_trend**2
822822

823823

824824
class MeasurementError(Component):
@@ -884,7 +884,7 @@ def make_symbolic_graph(self) -> None:
884884
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=(self.k_endog,))
885885
diag_idx = np.diag_indices(self.k_endog)
886886
idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]
887-
self.ssm[idx] = error_sigma
887+
self.ssm[idx] = error_sigma**2
888888

889889

890890
class AutoregressiveComponent(Component):
@@ -991,7 +991,7 @@ def make_symbolic_graph(self) -> None:
991991
self.ssm[ar_idx] = ar_params
992992

993993
cov_idx = ("state_cov", *np.diag_indices(1))
994-
self.ssm[cov_idx] = sigma_ar
994+
self.ssm[cov_idx] = sigma_ar**2
995995

996996

997997
class TimeSeasonality(Component):
@@ -1175,7 +1175,7 @@ def make_symbolic_graph(self) -> None:
11751175
self.ssm["selection", 0, 0] = 1
11761176
season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
11771177
cov_idx = ("state_cov", *np.diag_indices(1))
1178-
self.ssm[cov_idx] = season_sigma
1178+
self.ssm[cov_idx] = season_sigma**2
11791179

11801180

11811181
class FrequencySeasonality(Component):
@@ -1273,7 +1273,7 @@ def make_symbolic_graph(self) -> None:
12731273

12741274
if self.innovations:
12751275
sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1276-
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season
1276+
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season**2
12771277
self.ssm["selection", :, :] = np.eye(self.k_states)
12781278

12791279
def populate_component_properties(self):
@@ -1453,7 +1453,7 @@ def make_symbolic_graph(self) -> None:
14531453

14541454
if self.innovations:
14551455
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1456-
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle
1456+
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
14571457

14581458
def populate_component_properties(self):
14591459
self.state_names = [f"{self.name}_{f}" for f in ["Sin", "Cos"]]
@@ -1556,7 +1556,7 @@ def make_symbolic_graph(self) -> None:
15561556
f"sigma_beta_{self.name}", (self.k_states,)
15571557
)
15581558
row_idx, col_idx = np.diag_indices(self.k_states)
1559-
self.ssm["state_cov", row_idx, col_idx] = sigma_beta
1559+
self.ssm["state_cov", row_idx, col_idx] = sigma_beta**2
15601560

15611561
def populate_component_properties(self) -> None:
15621562
self.shock_names = self.state_names

pymc_experimental/tests/statespace/test_SARIMAX.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
296296
),
297297
)
298298

299-
pm.Deterministic("sigma_state", pt.as_tensor_variable(np.array([param_d["sigma2"]])))
299+
pm.Deterministic(
300+
"sigma_state", pt.as_tensor_variable(np.sqrt(np.array([param_d["sigma2"]])))
301+
)
300302

301303
mod._insert_random_variables()
302304
matrices = pm.draw(mod.subbed_ssm)

pymc_experimental/tests/statespace/test_structural.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ def create_structural_model_and_equivalent_statsmodel(
196196
components = []
197197

198198
if irregular:
199-
sigma = np.abs(rng.normal(size=(1,))).astype(floatX)
200-
params["sigma_irregular"] = sigma
201-
sm_params["sigma2.irregular"] = sigma.item()
199+
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
200+
params["sigma_irregular"] = np.sqrt(sigma2)
201+
sm_params["sigma2.irregular"] = sigma2.item()
202202
expected_param_dims["sigma_irregular"] += ("observed_state",)
203203

204204
comp = st.MeasurementError("irregular")
@@ -255,7 +255,7 @@ def create_structural_model_and_equivalent_statsmodel(
255255
).astype(floatX),
256256
np.zeros(2, dtype=floatX),
257257
)
258-
sigma_level_value = np.abs(rng.normal(size=(2,)))[
258+
sigma_level_value2 = np.abs(rng.normal(size=(2,)))[
259259
np.array(level_trend_innov_order, dtype="bool")
260260
]
261261
max_order = np.flatnonzero(level_value)[-1].item() + 1
@@ -267,9 +267,9 @@ def create_structural_model_and_equivalent_statsmodel(
267267

268268
if sum(level_trend_innov_order) > 0:
269269
expected_param_dims["sigma_trend"] += ("trend_shock",)
270-
params["sigma_trend"] = sigma_level_value
270+
params["sigma_trend"] = np.sqrt(sigma_level_value2)
271271

272-
sigma_level_value = sigma_level_value.tolist()
272+
sigma_level_value = sigma_level_value2.tolist()
273273
if stochastic_level:
274274
sigma = sigma_level_value.pop(0)
275275
sm_params["sigma2.level"] = sigma
@@ -298,9 +298,9 @@ def create_structural_model_and_equivalent_statsmodel(
298298
sm_init.update(seasonal_dict)
299299

300300
if stochastic_seasonal:
301-
sigma = np.abs(rng.normal(size=(1,))).astype(floatX)
302-
params["sigma_seasonal"] = sigma
303-
sm_params["sigma2.seasonal"] = sigma
301+
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
302+
params["sigma_seasonal"] = np.sqrt(sigma2)
303+
sm_params["sigma2.seasonal"] = sigma2
304304
expected_coords[SHOCK_DIM] += [
305305
"seasonal",
306306
]
@@ -343,9 +343,9 @@ def create_structural_model_and_equivalent_statsmodel(
343343
state_count += 1
344344

345345
if has_innov:
346-
sigma = np.abs(rng.normal(size=(1,))).astype(floatX)
347-
params[f"sigma_seasonal_{s}"] = sigma
348-
sm_params[f"sigma2.freq_seasonal_{s}({n})"] = sigma
346+
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
347+
params[f"sigma_seasonal_{s}"] = np.sqrt(sigma2)
348+
sm_params[f"sigma2.freq_seasonal_{s}({n})"] = sigma2
349349
expected_coords[SHOCK_DIM] += state_names
350350
expected_coords[SHOCK_AUX_DIM] += state_names
351351

@@ -374,12 +374,12 @@ def create_structural_model_and_equivalent_statsmodel(
374374
sm_init["cycle.auxilliary"] = init_cycle[1]
375375

376376
if stochastic_cycle:
377-
sigma = np.abs(rng.normal(size=(1,))).astype(floatX)
378-
params["sigma_cycle"] = sigma
377+
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
378+
params["sigma_cycle"] = np.sqrt(sigma2)
379379
expected_coords[SHOCK_DIM] += state_names
380380
expected_coords[SHOCK_AUX_DIM] += state_names
381381

382-
sm_params["sigma2.cycle"] = sigma
382+
sm_params["sigma2.cycle"] = sigma2
383383

384384
if damped_cycle:
385385
rho = rng.beta(1, 1, size=(1,)).astype(floatX)
@@ -398,18 +398,18 @@ def create_structural_model_and_equivalent_statsmodel(
398398
if autoregressive is not None:
399399
ar_names = [f"L{i+1}.data" for i in range(autoregressive)]
400400
ar_params = rng.normal(size=(autoregressive,)).astype(floatX)
401-
sigma = np.abs(rng.normal(size=(1,))).astype(floatX)
401+
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
402402

403403
params["ar_params"] = ar_params
404-
params["sigma_ar"] = sigma
404+
params["sigma_ar"] = np.sqrt(sigma2)
405405
expected_param_dims["ar_params"] += (AR_PARAM_DIM,)
406406
expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
407407
expected_coords[ALL_STATE_DIM] += ar_names
408408
expected_coords[ALL_STATE_AUX_DIM] += ar_names
409409
expected_coords[SHOCK_DIM] += ["ar_innovation"]
410410
expected_coords[SHOCK_AUX_DIM] += ["ar_innovation"]
411411

412-
sm_params["sigma2.ar"] = sigma
412+
sm_params["sigma2.ar"] = sigma2
413413
for i, rho in enumerate(ar_params):
414414
sm_init[f"ar.L{i+1}"] = 0
415415
sm_params[f"ar.L{i+1}"] = rho

0 commit comments

Comments
 (0)