Skip to content

Commit 99f30aa

Browse files
Fix initial state size in CycleComponent (#288)
* Cycle initial state has shape (2,) and two named dimensions Test initial state against statsmodels * Add tests of coords and dims for structural models * Add tests of params_info for structural models
1 parent 656b800 commit 99f30aa

File tree

2 files changed

+261
-53
lines changed

2 files changed

+261
-53
lines changed

pymc_experimental/statespace/models/structural.py

+40-27
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from pymc_experimental.statespace.utils.constants import (
1919
ALL_STATE_AUX_DIM,
2020
ALL_STATE_DIM,
21+
AR_PARAM_DIM,
2122
LONG_MATRIX_NAMES,
2223
OBS_STATE_DIM,
2324
POSITION_DERIVATIVE_NAMES,
25+
TIME_DIM,
2426
)
2527

2628
_log = logging.getLogger("pymc.experimental.statespace")
@@ -786,7 +788,7 @@ def populate_component_properties(self):
786788
self.state_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
787789
self.param_dims = {"initial_trend": ("trend_state",)}
788790
self.coords = {"trend_state": self.state_names}
789-
self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": "None"}}
791+
self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}}
790792

791793
if self.k_posdef > 0:
792794
self.param_names += ["sigma_trend"]
@@ -871,7 +873,11 @@ def populate_component_properties(self):
871873
self.param_names = [f"sigma_{self.name}"]
872874
self.param_dims = {f"sigma_{self.name}": (OBS_STATE_DIM,)}
873875
self.param_info = {
874-
f"sigma_{self.name}": {"shape": (1,), "constraints": "Positive", "dims": "None"}
876+
f"sigma_{self.name}": {
877+
"shape": (1,),
878+
"constraints": "Positive",
879+
"dims": (OBS_STATE_DIM,),
880+
}
875881
}
876882

877883
def make_symbolic_graph(self) -> None:
@@ -959,11 +965,15 @@ def populate_component_properties(self):
959965
self.state_names = [f"L{i + 1}.data" for i in range(self.k_states)]
960966
self.shock_names = [f"{self.name}_innovation"]
961967
self.param_names = ["ar_params", "sigma_ar"]
962-
self.param_dims = {"ar_params": ("ar_lags",)}
963-
self.coords = {"ar_lags": self.ar_lags}
968+
self.param_dims = {"ar_params": (AR_PARAM_DIM,)}
969+
self.coords = {AR_PARAM_DIM: self.ar_lags.tolist()}
964970

965971
self.param_info = {
966-
"ar_params": {"shape": (self.k_states,), "constraints": "None", "dims": "(ar_lags, )"},
972+
"ar_params": {
973+
"shape": (self.k_states,),
974+
"constraints": None,
975+
"dims": (AR_PARAM_DIM,),
976+
},
967977
"sigma_ar": {"shape": (1,), "constraints": "Positive", "dims": None},
968978
}
969979

@@ -1133,19 +1143,19 @@ def populate_component_properties(self):
11331143
self.param_info = {
11341144
f"{self.name}_coefs": {
11351145
"shape": (self.k_states,),
1136-
"constraints": "None",
1137-
"dims": f"({self.name}_state, )",
1146+
"constraints": None,
1147+
"dims": (f"{self.name}_state",),
11381148
}
11391149
}
1140-
self.param_dims = {f"{self.name}_coefs": (f"{self.name}_periods",)}
1150+
self.param_dims = {f"{self.name}_coefs": (f"{self.name}_state",)}
11411151
self.coords = {f"{self.name}_state": self.state_names}
11421152

11431153
if self.innovations:
11441154
self.param_names += [f"sigma_{self.name}"]
11451155
self.param_info[f"sigma_{self.name}"] = {
11461156
"shape": (1,),
11471157
"constraints": "Positive",
1148-
"dims": "None",
1158+
"dims": None,
11491159
}
11501160
self.shock_names = [f"{self.name}"]
11511161

@@ -1270,27 +1280,27 @@ def populate_component_properties(self):
12701280
self.state_names = [f"{self.name}_{f}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]]
12711281
self.param_names = [f"{self.name}"]
12721282

1273-
self.param_dims = {self.name: (f"{self.name}_initial_state",)}
1283+
self.param_dims = {self.name: (f"{self.name}_state",)}
12741284
self.param_info = {
12751285
f"{self.name}": {
12761286
"shape": (self.k_states - int(self.last_state_not_identified),),
1277-
"constraints": "None",
1278-
"dims": f"({self.name}_initial_state, )",
1287+
"constraints": None,
1288+
"dims": (f"{self.name}_state",),
12791289
}
12801290
}
12811291

12821292
init_state_idx = np.arange(self.k_states, dtype=int)
12831293
if self.last_state_not_identified:
12841294
init_state_idx = init_state_idx[:-1]
1285-
self.coords = {f"{self.name}_initial_state": [self.state_names[i] for i in init_state_idx]}
1295+
self.coords = {f"{self.name}_state": [self.state_names[i] for i in init_state_idx]}
12861296

12871297
if self.innovations:
12881298
self.shock_names = self.state_names.copy()
12891299
self.param_names += [f"sigma_{self.name}"]
12901300
self.param_info[f"sigma_{self.name}"] = {
12911301
"shape": (1,),
12921302
"constraints": "Positive",
1293-
"dims": "None",
1303+
"dims": None,
12941304
}
12951305

12961306

@@ -1421,10 +1431,12 @@ def __init__(
14211431
def make_symbolic_graph(self) -> None:
14221432
self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
14231433
self.ssm["selection", :, :] = np.eye(self.k_states)
1434+
self.param_dims = {self.name: (f"{self.name}_state",)}
1435+
self.coords = {f"{self.name}_state": self.state_names}
14241436

1425-
init_state = self.make_and_register_variable(f"{self.name}", shape=(1,))
1437+
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,))
14261438

1427-
self.ssm["initial_state", 0] = init_state
1439+
self.ssm["initial_state", :] = init_state
14281440

14291441
if self.estimate_cycle_length:
14301442
lamb = self.make_and_register_variable(f"{self.name}_length", shape=(1,))
@@ -1440,18 +1452,18 @@ def make_symbolic_graph(self) -> None:
14401452
self.ssm["transition", :, :] = T
14411453

14421454
if self.innovations:
1443-
sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1444-
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season
1455+
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1456+
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle
14451457

14461458
def populate_component_properties(self):
14471459
self.state_names = [f"{self.name}_{f}" for f in ["Sin", "Cos"]]
14481460
self.param_names = [f"{self.name}"]
14491461

14501462
self.param_info = {
14511463
f"{self.name}": {
1452-
"shape": (1,),
1453-
"constraints": "None",
1454-
"dims": None,
1464+
"shape": (2,),
1465+
"constraints": None,
1466+
"dims": (f"{self.name}_state",),
14551467
}
14561468
}
14571469

@@ -1476,7 +1488,7 @@ def populate_component_properties(self):
14761488
self.param_info[f"sigma_{self.name}"] = {
14771489
"shape": (1,),
14781490
"constraints": "Positive",
1479-
"dims": "None",
1491+
"dims": None,
14801492
}
14811493
self.shock_names = self.state_names.copy()
14821494

@@ -1551,15 +1563,16 @@ def populate_component_properties(self) -> None:
15511563

15521564
self.param_names = [f"beta_{self.name}", f"data_{self.name}"]
15531565
self.param_dims = {
1554-
f"beta_{self.name}": "exog_state",
1555-
f"data_{self.name}": ("time", "exog_state"),
1566+
f"beta_{self.name}": ("exog_state",),
1567+
f"data_{self.name}": (TIME_DIM, "exog_state"),
15561568
}
1569+
15571570
self.param_info = {
1558-
f"beta_{self.name}": {"shape": (1,), "constraints": "None", "dims": ("exog_state",)},
1571+
f"beta_{self.name}": {"shape": (1,), "constraints": None, "dims": ("exog_state",)},
15591572
f"data_{self.name}": {
15601573
"shape": (None, self.k_states),
1561-
"constraints": "None",
1562-
"dims": ("time", "exog_state"),
1574+
"constraints": None,
1575+
"dims": (TIME_DIM, "exog_state"),
15631576
},
15641577
}
15651578
self.coords = {f"exog_state": self.state_names}

0 commit comments

Comments
 (0)