You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I still do not get pymc_experimental running when adding a cycle and/or a seasonal to the model. Below an example code. In this case, the model consists of an integrated random walk + cycle + seasonal with 1 harmonic + measurement noise. I can generate the data associated with this model, but when trying to integrate into pymc I get an error message from build_statespace_graph. The error message points to inconsistencies in the shape of the cyclic and/or seasonal component (shape (2,) versus shape (1,) ) in base code. The error appears when I add any combination of cyclic and/or seasonal components. Before the last error fix related to cyclic components documented on Github it only appeared when a cyclic component was added.
from pymc_experimental.statespace import structural as st
from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import pytensor
import pytensor.tensor as pt
import numpy as np
import pandas as pd
from patsy import dmatrix
from pymc_experimental.statespace.core.representation import PytensorRepresentation
import xarray as xr
%reload_ext autoreload
%autoreload complete
from importlib.metadata import version
print('pymc version = ', version('pymc'))
print('pytensor version = ', version('pytensor'))
print('pandas version = ', version('pandas'))
print('pandas version = ', pd.version)
print('arviz version = ', version('arviz'))
print('numpy version = ', version('numpy'))
print('pytensor version = ', version('pytensor'))
#print('blackjax version = ', version('blackjax'))
print('nutpie version = ', version('nutpie'))
print('xarray version = ', version('xarray'))
xs = np.zeros((n_simulations, steps, k_states))
ys = np.zeros((n_simulations, steps))
for i in range(n_simulations):
x, y = simulate_from_numpy_model(mod, rng, param_dict, steps)
xs[i] = x
ys[i] = y
return xs, ys
measurement_error = st.MeasurementError(name="obs")
IRW = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
cycle = st.CycleComponent(name="annual_cycle", cycle_length=12, innovations=True) # cycle length is the period in number of samples; non-integer periods are allowed
SA_cycle = st.FrequencySeasonality(
name="SA_cycle", season_length=5.347, n=1, innovations=True # season_length is the period in units of number of sampels; non-integer values allowed
)
time = np.arange(144)/12
data = pd.DataFrame({'time': time, 'meas': y})
nobs = len(data['meas'])
dt = np.mean(np.diff(data['time'])) # sampling period in units of years
mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
mod += st.FrequencySeasonality(name='SA_cycle', season_length=5.347, n=1, innovations=True)
mod += st.MeasurementError(name="obs")
model = mod.build(name="IRW+cycle+measurement_error")
with model_1:
model.build_statespace_graph(data['meas'], mode="JAX")
TypeError Traceback (most recent call last)
Cell In[12], line 2
1 with model_1:
----> 2 model.build_statespace_graph(data['meas'], mode="JAX")
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:779, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, return_updates, include_smoother)
721 """
722 Given a parameter vector theta, constructs the full computational graph describing the state space model and
723 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman
(...)
775 If return_updates is False, the method will return None.
776 """
777 pm_mod = modelcontext(None)
--> 779 self._insert_random_variables()
780 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
782 self.data_len = data.shape[0]
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:637, in PyMCStateSpace._insert_random_variables(self)
633 matrices = list(self._unpack_statespace_with_placeholders())
634 replacement_dict = {
635 var: pt.atleast_1d(pymc_model[name]) for name, var in self._name_to_variable.items()
636 }
--> 637 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/replace.py:205, in graph_replace(outputs, replace, strict)
197 raise ValueError(f"{key} is not a part of graph")
199 sorted_replacements = sorted(
200 fg_replace.items(),
201 # sort based on the fg toposort, if a variable has no owner, it goes first
202 key=partial(toposort_key, fg, toposort),
203 reverse=True,
204 )
--> 205 fg.replace_all(sorted_replacements, import_missing=True)
206 if as_list:
207 return list(fg.outputs)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:515, in FunctionGraph.replace_all(self, pairs, **kwargs)
513 """Replace variables in the FunctionGraph according to (var, new_var) pairs in a list."""
514 for var, new_var in pairs:
--> 515 self.replace(var, new_var, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:508, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing)
501 raise AssertionError(
502 "The replacement variable has a test value with "
503 "a shape different from the original variable's "
504 f"test value. Original: {tval_shape}, new: {new_tval_shape}"
505 )
507 for node, i in list(self.clients[var]):
--> 508 self.change_node_input(
509 node, i, new_var, reason=reason, import_missing=import_missing
510 )
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:428, in FunctionGraph.change_node_input(self, node, i, new_var, reason, import_missing, check)
426 r = node.inputs[i]
427 if check and not r.type.is_super(new_var.type):
--> 428 raise TypeError(
429 f"The type of the replacement ({new_var.type}) must be "
430 f"compatible with the type of the original Variable ({r.type})."
431 )
432 node.inputs[i] = new_var
434 if r is new_var:
TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(2,))).
The text was updated successfully, but these errors were encountered:
rklees
changed the title
including cycles generates error messages from build_statespace_graph
including cyclic or seasonal components causes error messages from build_statespace_graph since last bug fix
Dec 21, 2023
I still do not get pymc_experimental running when adding a cycle and/or a seasonal to the model. Below an example code. In this case, the model consists of an integrated random walk + cycle + seasonal with 1 harmonic + measurement noise. I can generate the data associated with this model, but when trying to integrate into pymc I get an error message from build_statespace_graph. The error message points to inconsistencies in the shape of the cyclic and/or seasonal component (shape (2,) versus shape (1,) ) in base code. The error appears when I add any combination of cyclic and/or seasonal components. Before the last error fix related to cyclic components documented on Github it only appeared when a cyclic component was added.
Here is the code and the error message:
import jax
jax.config.update("jax_platform_name", "cpu")
import numpyro
#import blackjax
import nutpie
numpyro.set_host_device_count(4)
import sys
sys.path.append("..")
print(sys.path)
import pymc_experimental.statespace
import importlib
importlib.reload(pymc_experimental.statespace.structural)
from pymc_experimental.statespace import structural as st
from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import pytensor
import pytensor.tensor as pt
import numpy as np
import pandas as pd
from patsy import dmatrix
from pymc_experimental.statespace.core.representation import PytensorRepresentation
import xarray as xr
%reload_ext autoreload
%autoreload complete
from importlib.metadata import version
print('pymc version = ', version('pymc'))
print('pytensor version = ', version('pytensor'))
print('pandas version = ', version('pandas'))
print('pandas version = ', pd.version)
print('arviz version = ', version('arviz'))
print('numpy version = ', version('numpy'))
print('pytensor version = ', version('pytensor'))
#print('blackjax version = ', version('blackjax'))
print('nutpie version = ', version('nutpie'))
print('xarray version = ', version('xarray'))
plt.rcParams.update(
{
"figure.figsize": (14, 4),
"figure.dpi": 144,
"figure.constrained_layout.use": True,
"axes.grid": True,
"grid.linewidth": 0.5,
"grid.linestyle": "--",
"axes.spines.top": False,
"axes.spines.bottom": False,
"axes.spines.left": False,
"axes.spines.right": False,
}
)
def unpack_statespace(ssm):
return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]
def unpack_symbolic_matrices_with_params(mod, param_dict):
f_matrices = pytensor.function(
list(mod._name_to_variable.values()), unpack_statespace(mod.ssm), on_unused_input="ignore"
)
x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict)
return x0, P0, c, d, T, Z, R, H, Q
def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
"""
Helper function to visualize the components outside of a PyMC model context
"""
x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict)
Z_time_varies = Z.ndim == 3
def simulate_many_trajectories(mod, rng, param_dict, n_simulations, steps=100):
k_states = mod.k_states
k_posdef = mod.k_posdef
seed = sum(map(ord, "Structural Timeseries"))
rng = np.random.default_rng(seed)
measurement_error = st.MeasurementError(name="obs")
IRW = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
cycle = st.CycleComponent(name="annual_cycle", cycle_length=12, innovations=True) # cycle length is the period in number of samples; non-integer periods are allowed
SA_cycle = st.FrequencySeasonality(
name="SA_cycle", season_length=5.347, n=1, innovations=True # season_length is the period in units of number of sampels; non-integer values allowed
)
param_dict = {
"initial_trend": np.zeros((2,)), "sigma_trend": np.array([0.2]),
"annual_cycle": np.array([10., 0.]), "sigma_annual_cycle": np.array([1.0]),
"SA_cycle": np.array([20., 0.]), "sigma_SA_cycle": np.array([0.5]),
"sigma_obs": np.array([0.1]),
}
mod = IRW + cycle + SA_cycle + measurement_error
x, y = simulate_from_numpy_model(mod, rng, param_dict, steps=144)
plt.figure(figsize=(10,5))
plt.plot(y), plt.title('IRW plus annual cycle plus SA_cycle')
plt.figure(figsize=(10,5))
plt.plot(x[:, 0]), plt.title('level component')
plt.figure(figsize=(10,5))
plt.plot(x[:, 1]), plt.title('trend component')
plt.figure(figsize=(10,5))
plt.plot(x[:, 2]), plt.title('annual cycle component')
plt.figure(figsize=(10,5))
plt.plot(x[:, 4]), plt.title('SA cycle component')
time = np.arange(144)/12
data = pd.DataFrame({'time': time, 'meas': y})
nobs = len(data['meas'])
dt = np.mean(np.diff(data['time'])) # sampling period in units of years
mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
mod += st.FrequencySeasonality(name='SA_cycle', season_length=5.347, n=1, innovations=True)
mod += st.MeasurementError(name="obs")
model = mod.build(name="IRW+cycle+measurement_error")
with pm.Model(coords=coords) as model_1:
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)
annual_cycle = pm.Normal("annual_cycle", sigma=5)
sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)
SA_cycle = pm.Normal("SA_cycle", sigma=5)
sigma_SA_cycle = pm.Gamma("sigma_SA_cycle", alpha=2, beta=5)
sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=('observed_state',))
with model_1:
model.build_statespace_graph(data['meas'], mode="JAX")
TypeError Traceback (most recent call last)
Cell In[12], line 2
1 with model_1:
----> 2 model.build_statespace_graph(data['meas'], mode="JAX")
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:779, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, return_updates, include_smoother)
721 """
722 Given a parameter vector
theta
, constructs the full computational graph describing the state space model and723 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman
(...)
775 If
return_updates
is False, the method will return None.776 """
777 pm_mod = modelcontext(None)
--> 779 self._insert_random_variables()
780 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
782 self.data_len = data.shape[0]
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:637, in PyMCStateSpace._insert_random_variables(self)
633 matrices = list(self._unpack_statespace_with_placeholders())
634 replacement_dict = {
635 var: pt.atleast_1d(pymc_model[name]) for name, var in self._name_to_variable.items()
636 }
--> 637 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/replace.py:205, in graph_replace(outputs, replace, strict)
197 raise ValueError(f"{key} is not a part of graph")
199 sorted_replacements = sorted(
200 fg_replace.items(),
201 # sort based on the fg toposort, if a variable has no owner, it goes first
202 key=partial(toposort_key, fg, toposort),
203 reverse=True,
204 )
--> 205 fg.replace_all(sorted_replacements, import_missing=True)
206 if as_list:
207 return list(fg.outputs)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:515, in FunctionGraph.replace_all(self, pairs, **kwargs)
513 """Replace variables in the
FunctionGraph
according to(var, new_var)
pairs in a list."""514 for var, new_var in pairs:
--> 515 self.replace(var, new_var, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:508, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing)
501 raise AssertionError(
502 "The replacement variable has a test value with "
503 "a shape different from the original variable's "
504 f"test value. Original: {tval_shape}, new: {new_tval_shape}"
505 )
507 for node, i in list(self.clients[var]):
--> 508 self.change_node_input(
509 node, i, new_var, reason=reason, import_missing=import_missing
510 )
File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:428, in FunctionGraph.change_node_input(self, node, i, new_var, reason, import_missing, check)
426 r = node.inputs[i]
427 if check and not r.type.is_super(new_var.type):
--> 428 raise TypeError(
429 f"The type of the replacement ({new_var.type}) must be "
430 f"compatible with the type of the original Variable ({r.type})."
431 )
432 node.inputs[i] = new_var
434 if r is new_var:
TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(2,))).
The text was updated successfully, but these errors were encountered: