diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index 5119726e..8c8284c5 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -87,12 +87,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals col_names = data.columns _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names) - if isinstance(data.index, pd.RangeIndex): - if obs_coords is not None: - warnings.warn(NO_TIME_INDEX_WARNING) - return preprocess_numpy_data(data.values, n_obs, obs_coords) - - elif isinstance(data.index, pd.DatetimeIndex): + if isinstance(data.index, pd.DatetimeIndex): if data.index.freq is None: warnings.warn(NO_FREQ_INFO_WARNING) data.index.freq = data.index.inferred_freq @@ -100,10 +95,30 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals index = data.index return data.values, index + elif isinstance(data.index, pd.RangeIndex): + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + return preprocess_numpy_data(data.values, n_obs, obs_coords) + + elif isinstance(data.index, pd.MultiIndex): + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + + raise NotImplementedError("MultiIndex panel data is not currently supported.") + else: - raise IndexError( - f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}" - ) + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + + index = data.index + if not np.issubdtype(index.dtype, np.integer): + raise IndexError("Provided index is not an integer index.") + + index_diff = index.to_series().diff().dropna().values + if not (index_diff == 1).all(): + raise IndexError("Provided index is not monotonic increasing.") + + return preprocess_numpy_data(data.values, n_obs, obs_coords) def add_data_to_active_model(values, index, data_dims=None): diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 8e2fea58..40aaec12 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -8,6 +8,7 @@ import pytest from pymc_extras.statespace.models import structural +from pymc_extras.statespace.models.structural import LevelTrendComponent from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_DIMS, FILTER_OUTPUT_NAMES, @@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model): with warning: pymc_model = create_model(f) assert TIME_DIM in pymc_model.coords + + +def make_model(index): + n = len(index) + a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4)) + + mod = LevelTrendComponent(order=2, innovations_order=[0, 1]) + ss_mod = mod.build(name="a", verbose=False) + + initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values() + coords = ss_mod.coords + + with pm.Model(coords=coords) as model: + P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5) + P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims) + + initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims) + sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims) + + with pytest.warns(UserWarning, match="No time index found on the supplied data"): + ss_mod.build_statespace_graph( + a["A"], + mode="JAX", + ) + return model + + +def test_integer_index(): + index = np.arange(8).astype(int) + model = make_model(index) + assert TIME_DIM in model.coords + np.testing.assert_allclose(model.coords[TIME_DIM], index) + + +def test_float_index_raises(): + index = np.linspace(0, 1, 8) + + with pytest.raises(IndexError, match="Provided index is not an integer index"): + make_model(index) + + +def test_non_strictly_monotone_index_raises(): + # Decreases + index = [0, 1, 2, 1, 2, 3] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + + # Has gaps + index = [0, 1, 2, 3, 5, 6] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + + # Has duplicates + index = [0, 1, 1, 2, 3, 4] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + + +def test_multiindex_raises(): + index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)]) + with pytest.raises( + NotImplementedError, match="MultiIndex panel data is not currently supported" + ): + make_model(index)