When I build a PyMC state-space model with pymc-experimental, adding a cycle component to the model always raises this error:
"TypeError: The type of the replacement (Vector(float64, shape=(2,))) must be compatible with the type of the original Variable (Vector(float64, shape=(1,)))."
The code is below. The error message refers to the line "annual_cycle = ...".
# Import aliases assumed from the names used below
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# `data` is the observed time series (not shown here)

# Structural model: integrated random walk trend + stochastic annual cycle + measurement error
mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
mod += st.MeasurementError(name="obs")
model = mod.build(name="IRW+cycle+measurement_error")

# Inspect the dims the model expects for each parameter
model.param_dims
initial_trend_dims, sigma_trend_dims, annual_cycle_dims, sigma_obs_dims, P0_dims = model.param_dims.values()
coords = model.coords

with pm.Model(coords=coords) as model_1:
    # Diagonal initial state covariance
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)
    # The TypeError refers to this line
    annual_cycle = pm.Normal("annual_cycle", sigma=5, dims=annual_cycle_dims)
    sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)
    sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=['observed_state'])
    model.build_statespace_graph(data, mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)
I guess the problem is that annual_cycle should have shape = (2,) but has shape = (1,). Might that point to a bug in the base code? For instance, in line 1329, init_state gets shape=(1,), though I wonder whether it should read shape=(2,). However, I found that this change alone does not fix the problem; more adaptations may be needed.
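To illustrate why I would expect shape (2,), here is a minimal sketch of the textbook trigonometric cycle, written in plain Python and independent of pymc-experimental. I am assuming the CycleComponent follows this standard formulation; the helper names (cycle_transition, step, simulate) are made up for this example and are not part of the library. The cycle carries two states that are rotated by the angle 2*pi/cycle_length at every step, so its initial state needs two elements.

```python
import math


def cycle_transition(cycle_length):
    """2x2 rotation matrix of a deterministic trigonometric cycle (hypothetical helper)."""
    lam = 2.0 * math.pi / cycle_length  # angular frequency of the cycle
    c, s = math.cos(lam), math.sin(lam)
    return [[c, s], [-s, c]]


def step(T, state):
    """Apply the 2x2 transition matrix T to a two-element state."""
    return [
        T[0][0] * state[0] + T[0][1] * state[1],
        T[1][0] * state[0] + T[1][1] * state[1],
    ]


def simulate(initial_state, cycle_length, n_steps):
    """Propagate the cycle states without innovations."""
    if len(initial_state) != 2:
        # A one-element initial state cannot be multiplied by the 2x2 transition matrix
        raise ValueError("the cycle needs an initial state of shape (2,)")
    T = cycle_transition(cycle_length)
    states = [list(initial_state)]
    for _ in range(n_steps):
        states.append(step(T, states[-1]))
    return states


# With cycle_length=12 the two states return to their start after 12 steps
states = simulate([1.0, 0.0], cycle_length=12, n_steps=12)
```

Passing a one-element initial state to simulate raises the ValueError, which is analogous to the shape mismatch in the error message above, although the actual check inside the library is of course different.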