From 78cc4c363ecc64453106f06bae6be23da2faaa2a Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 15 Mar 2025 21:58:44 +0800 Subject: [PATCH 1/3] Make index check less strict --- pymc_extras/statespace/utils/data_tools.py | 12 ++++----- tests/statespace/test_coord_assignment.py | 29 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index 5119726e..87589c52 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -87,12 +87,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals col_names = data.columns _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names) - if isinstance(data.index, pd.RangeIndex): - if obs_coords is not None: - warnings.warn(NO_TIME_INDEX_WARNING) - return preprocess_numpy_data(data.values, n_obs, obs_coords) - - elif isinstance(data.index, pd.DatetimeIndex): + if isinstance(data.index, pd.DatetimeIndex): if data.index.freq is None: warnings.warn(NO_FREQ_INFO_WARNING) data.index.freq = data.index.inferred_freq @@ -100,6 +95,11 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals index = data.index return data.values, index + elif isinstance(data.index, pd.Index): + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + return preprocess_numpy_data(data.values, n_obs, obs_coords) + else: raise IndexError( f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}" diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 8e2fea58..938527c6 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -8,6 +8,7 @@ import pytest from pymc_extras.statespace.models import structural +from pymc_extras.statespace.models.structural import LevelTrendComponent from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_DIMS, FILTER_OUTPUT_NAMES, @@ -114,3 +115,31 @@ def test_data_index_is_coord(f, warning, create_model): with warning: pymc_model = create_model(f) assert TIME_DIM in pymc_model.coords + + +def test_integer_index(): + a = pd.DataFrame( + index=np.arange(8), columns=["A", "B", "C", "D"], data=np.arange(32).reshape(8, 4) + ) + + mod = LevelTrendComponent(order=2, innovations_order=[0, 1]) + ss_mod = mod.build(name="a", verbose=False) + + initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values() + coords = ss_mod.coords + + with pm.Model(coords=coords) as model_1: + P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5) + P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims) + + initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims) + sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims) + + with pytest.warns(UserWarning, match="No time index found on the supplied data"): + ss_mod.build_statespace_graph( + a["A"], + mode="JAX", + ) + + assert TIME_DIM in model_1.coords + np.testing.assert_allclose(model_1.coords[TIME_DIM], a.index) From 14b635c43ed451426ce3c4bf5e23aefdc19dbc8d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 19 Mar 2025 15:02:12 +0800 Subject: [PATCH 2/3] Validate generic index values --- pymc_extras/statespace/utils/data_tools.py | 26 +++++++++-- tests/statespace/test_coord_assignment.py | 50 +++++++++++++++++++--- 2 files changed, 65 insertions(+), 11 deletions(-) diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index 87589c52..b85d7e67 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -95,15 +95,33 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals index = data.index return data.values, index - elif isinstance(data.index, pd.Index): + elif isinstance(data.index, pd.RangeIndex): if obs_coords is not None: warnings.warn(NO_TIME_INDEX_WARNING) return preprocess_numpy_data(data.values, n_obs, obs_coords) + elif isinstance(data.index, pd.MultiIndex): + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + + raise NotImplementedError("MultiIndex panel data is not currently supported.") + else: - raise IndexError( - f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}" - ) + if obs_coords is not None: + warnings.warn(NO_TIME_INDEX_WARNING) + + index = data.index + if not np.issubdtype(index.dtype, np.integer): + raise IndexError("Provided index is not an integer index.") + + if not index.is_monotonic_increasing: + raise IndexError("Provided index is not monotonic increasing.") + + index_diff = index.to_series().diff().dropna().values + if not (index_diff == 1).all(): + raise IndexError("Provided index is not monotonic increasing.") + + return preprocess_numpy_data(data.values, n_obs, obs_coords) def add_data_to_active_model(values, index, data_dims=None): diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 938527c6..40aaec12 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -117,10 +117,9 @@ def test_data_index_is_coord(f, warning, create_model): assert TIME_DIM in pymc_model.coords -def test_integer_index(): - a = pd.DataFrame( - index=np.arange(8), columns=["A", "B", "C", "D"], data=np.arange(32).reshape(8, 4) - ) +def make_model(index): + n = len(index) + a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4)) mod = LevelTrendComponent(order=2, innovations_order=[0, 1]) ss_mod = mod.build(name="a", verbose=False) @@ -128,7 +127,7 @@ def test_integer_index(): initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values() coords = ss_mod.coords - with pm.Model(coords=coords) as model_1: + with pm.Model(coords=coords) as model: P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5) P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims) @@ -140,6 +139,43 @@ def test_integer_index(): a["A"], mode="JAX", ) + return model + + +def test_integer_index(): + index = np.arange(8).astype(int) + model = make_model(index) + assert TIME_DIM in model.coords + np.testing.assert_allclose(model.coords[TIME_DIM], index) + + +def test_float_index_raises(): + index = np.linspace(0, 1, 8) + + with pytest.raises(IndexError, match="Provided index is not an integer index"): + make_model(index) + + +def test_non_strictly_monotone_index_raises(): + # Decreases + index = [0, 1, 2, 1, 2, 3] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + + # Has gaps + index = [0, 1, 2, 3, 5, 6] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + + # Has duplicates + index = [0, 1, 1, 2, 3, 4] + with pytest.raises(IndexError, match="Provided index is not monotonic increasing"): + make_model(index) + - assert TIME_DIM in model_1.coords - np.testing.assert_allclose(model_1.coords[TIME_DIM], a.index) +def test_multiindex_raises(): + index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)]) + with pytest.raises( + NotImplementedError, match="MultiIndex panel data is not currently supported" + ): + make_model(index) From 17c2b9350a006fba6ed0bf5b134e1b3a5f835c56 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 19 Mar 2025 22:21:21 +0800 Subject: [PATCH 3/3] Remove redundant check --- pymc_extras/statespace/utils/data_tools.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index b85d7e67..8c8284c5 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -114,9 +114,6 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals if not np.issubdtype(index.dtype, np.integer): raise IndexError("Provided index is not an integer index.") - if not index.is_monotonic_increasing: - raise IndexError("Provided index is not monotonic increasing.") - index_diff = index.to_series().diff().dropna().values if not (index_diff == 1).all(): raise IndexError("Provided index is not monotonic increasing.")