Skip to content

Remove mode argument passed from Statespace.build_statespace_graph to scan #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2025

Conversation

jessegrabowski
Copy link
Member

Closes #478

We were previously passing the mode argument to scan everywhere in statespace. This isn't strictly necessary, and required a fair amount of bookkeeping. This PR removes the mode argument from all the scans in the repo, and also removes all that bookkeeping.

This will slightly break the user-facing API, because now build_statespace_graph(data, mode='JAX') will now raise an error (because you don't need to pass anything).

It will also now require users to explicitly pass compile_kwargs to all of the sampling functions (sample_conditional_posterior, forecast, impulse_response_function, etc). That's consistent with all PyMC apis, though.

I still need to go through and adjust all the notebooks, but it looks like tests are passing.

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes the now-unnecessary mode argument from the scan calls and related bookkeeping in the statespace models. Key changes include:

  • Updates in tests and models to remove mode argument passed to build_graph and related functions.
  • Elimination of mode-based branching in several files, including Kalman filter and smoother implementations.

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/statespace/utilities/test_helpers.py Removed mode parameter from initialize_filter and kfilter.build_graph calls.
tests/statespace/test_statespace_JAX.py Removed mode parameter from build_statespace_graph call and updated corresponding tests.
tests/statespace/test_kalman_filter.py Removed mode parameter from initialize_filter call.
tests/statespace/test_coord_assignment.py Removed mode parameter from build_statespace_graph call.
pymc_extras/statespace/models/structural.py Removed mode parameter from build_statespace_graph calls.
pymc_extras/statespace/models/SARIMAX.py Removed mode parameter and simplified the conditional logic in _stationary_initialization.
pymc_extras/statespace/filters/* Removed mode argument and related get_mode calls from kalman filter, smoother, and distributions.
pymc_extras/statespace/core/statespace.py Removed mode handling from graph building and sampling functions.
Comments suppressed due to low confidence (1)

pymc_extras/statespace/models/SARIMAX.py:369

  • The conditional logic for selecting the method for solve_discrete_lyapunov has been removed and now always uses 'bilinear'. Please verify that always using 'bilinear' is appropriate, especially for models with fewer states where the 'direct' method was previously used.
def _stationary_initialization(self):

@ricardoV94
Copy link
Member

ricardoV94 commented May 22, 2025

Do we want a transition period with future warning to aid users?

I feel people may actually be using the module these days

@jessegrabowski
Copy link
Member Author

Do we want a transition period with future warning to aid users?

I feel people may actually be using the module these days

Yes that's more professional. I was waffling on this point.

I guess I can raise a warning if you pass mode to build_statespace_graph and store it in the model, but not pass it to the scans. I could keep automatically setting it in the helper functions, but again with a warning that in the future we're not going to do this for you?

Alternatively the warning could say "this should now be set when you create the model", e.g. ss_mod = BayesianSARIMA(p=p, q=q, mode='JAX') would request that all sampling methods use compile_kwargs={'mode':'JAX'}, but it wouldn't do anything else.

I think that's convenient, but if a user wants that he can also change the pytensor config flag. So I'm somewhat torn.

@ricardoV94
Copy link
Member

I guess I can raise a warning if you pass mode to build_statespace_graph and store it in the model, but not pass it to the scans. I could keep automatically setting it in the helper functions, but again with a warning that in the future we're not going to do this for you?

Alternatively the warning could say ...

Both together sound fine

@jessegrabowski jessegrabowski changed the title Remove mode argument from Statespace models Remove mode argument passed from Statespace.build_statespace_graph to scan May 26, 2025
@jessegrabowski
Copy link
Member Author

@ricardoV94 review plz :)

@ricardoV94 ricardoV94 requested a review from Copilot May 27, 2025 13:03
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes the now-unnecessary mode argument from internal build_graph and scan calls, deprecates passing mode to build_statespace_graph, and centralizes compile-mode handling via model constructors and sampling compile_kwargs.

  • Removed mode parameters from low-level filter/smoother/distribution APIs and tests
  • Added mode to model constructors and build() methods; emit deprecation warning if passed to build_statespace_graph
  • Updated sampling methods to default to model-level mode via new _set_default_mode helper

Reviewed Changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
tests/statespace/utils/test_coord_assignment.py Drop obsolete mode="JAX" in test
tests/statespace/test_utilities.py Remove mode from initialize_filter and its call
tests/statespace/models/test_structural.py Pass mode into mod.build() and assert built_model.mode
tests/statespace/models/test_VARMAX.py Add (misnamed) test for mode in BayesianVARMAX ctor
tests/statespace/models/test_SARIMAX.py Add (misnamed) test for mode in BayesianSARIMA ctor
tests/statespace/models/test_ETS.py Add (misnamed) test for mode in BayesianETS ctor
tests/statespace/filters/test_kalman_filter.py Remove mode="JAX" from initialize_filter usage
tests/statespace/core/test_statespace_JAX.py Remove mode="JAX" calls and related assertions
pymc_extras/statespace/models/structural.py Add mode param to __init__, build(), propagate to core
pymc_extras/statespace/models/VARMAX.py Add mode param to BayesianVARMAX ctor and docstrings
pymc_extras/statespace/models/SARIMAX.py Add mode param to BayesianSARIMA ctor and docstrings
pymc_extras/statespace/models/ETS.py Add mode param to BayesianETS ctor and docstrings
pymc_extras/statespace/filters/kalman_smoother.py Remove mode attribute and import; simplify build_graph
pymc_extras/statespace/filters/kalman_filter.py Remove mode import/usage from BaseFilter and scans
pymc_extras/statespace/filters/distributions.py Strip mode from distribution wrappers
pymc_extras/statespace/core/statespace.py Remove get_mode, deprecate mode arg in build_statespace_graph, add _set_default_mode

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot is probably right about the test name

@jessegrabowski jessegrabowski requested a review from ricardoV94 May 27, 2025 15:45
@jessegrabowski jessegrabowski merged commit eab41fa into pymc-devs:main May 28, 2025
17 checks passed
@jessegrabowski jessegrabowski deleted the remove-scan-mode branch May 28, 2025 08:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove mode from statespace module functions
2 participants