-
-
Notifications
You must be signed in to change notification settings - Fork 61
Remove mode
argument passed from Statespace.build_statespace_graph to scan
#482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove mode
argument passed from Statespace.build_statespace_graph to scan
#482
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR removes the now-unnecessary mode argument from the scan calls and related bookkeeping in the statespace models. Key changes include:
- Updates in tests and models to remove mode argument passed to build_graph and related functions.
- Elimination of mode-based branching in several files, including Kalman filter and smoother implementations.
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
tests/statespace/utilities/test_helpers.py | Removed mode parameter from initialize_filter and kfilter.build_graph calls. |
tests/statespace/test_statespace_JAX.py | Removed mode parameter from build_statespace_graph call and updated corresponding tests. |
tests/statespace/test_kalman_filter.py | Removed mode parameter from initialize_filter call. |
tests/statespace/test_coord_assignment.py | Removed mode parameter from build_statespace_graph call. |
pymc_extras/statespace/models/structural.py | Removed mode parameter from build_statespace_graph calls. |
pymc_extras/statespace/models/SARIMAX.py | Removed mode parameter and simplified the conditional logic in _stationary_initialization. |
pymc_extras/statespace/filters/* | Removed mode argument and related get_mode calls from kalman filter, smoother, and distributions. |
pymc_extras/statespace/core/statespace.py | Removed mode handling from graph building and sampling functions. |
Comments suppressed due to low confidence (1)
pymc_extras/statespace/models/SARIMAX.py:369
- The conditional logic for selecting the method for solve_discrete_lyapunov has been removed and now always uses 'bilinear'. Please verify that always using 'bilinear' is appropriate, especially for models with fewer states where the 'direct' method was previously used.
def _stationary_initialization(self):
Do we want a transition period with future warning to aid users? I feel people may actually be using the module these days |
Yes that's more professional. I was waffling on this point. I guess I can raise a warning if you pass mode to Alternatively the warning could say "this should now be set when you create the model", e.g. I think that's convenient, but if a user wants that he can also change the pytensor config flag. So I'm somewhat torn. |
Both together sound fine |
e2b5dd0
to
c062b19
Compare
mode
argument from Statespace modelsmode
argument passed from Statespace.build_statespace_graph to scan
@ricardoV94 review plz :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR removes the now-unnecessary mode
argument from internal build_graph
and scan
calls, deprecates passing mode
to build_statespace_graph
, and centralizes compile-mode handling via model constructors and sampling compile_kwargs
.
- Removed
mode
parameters from low-level filter/smoother/distribution APIs and tests - Added
mode
to model constructors andbuild()
methods; emit deprecation warning if passed tobuild_statespace_graph
- Updated sampling methods to default to model-level
mode
via new_set_default_mode
helper
Reviewed Changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.
Show a summary per file
File | Description |
---|---|
tests/statespace/utils/test_coord_assignment.py | Drop obsolete mode="JAX" in test |
tests/statespace/test_utilities.py | Remove mode from initialize_filter and its call |
tests/statespace/models/test_structural.py | Pass mode into mod.build() and assert built_model.mode |
tests/statespace/models/test_VARMAX.py | Add (misnamed) test for mode in BayesianVARMAX ctor |
tests/statespace/models/test_SARIMAX.py | Add (misnamed) test for mode in BayesianSARIMA ctor |
tests/statespace/models/test_ETS.py | Add (misnamed) test for mode in BayesianETS ctor |
tests/statespace/filters/test_kalman_filter.py | Remove mode="JAX" from initialize_filter usage |
tests/statespace/core/test_statespace_JAX.py | Remove mode="JAX" calls and related assertions |
pymc_extras/statespace/models/structural.py | Add mode param to __init__ , build() , propagate to core |
pymc_extras/statespace/models/VARMAX.py | Add mode param to BayesianVARMAX ctor and docstrings |
pymc_extras/statespace/models/SARIMAX.py | Add mode param to BayesianSARIMA ctor and docstrings |
pymc_extras/statespace/models/ETS.py | Add mode param to BayesianETS ctor and docstrings |
pymc_extras/statespace/filters/kalman_smoother.py | Remove mode attribute and import; simplify build_graph |
pymc_extras/statespace/filters/kalman_filter.py | Remove mode import/usage from BaseFilter and scans |
pymc_extras/statespace/filters/distributions.py | Strip mode from distribution wrappers |
pymc_extras/statespace/core/statespace.py | Remove get_mode , deprecate mode arg in build_statespace_graph , add _set_default_mode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot is probably right about the test name
Add mode argument to statespace constructors
32ffde0
to
81b4b7b
Compare
Closes #478
We were previously passing the
mode
argument to scan everywhere in statespace. This isn't strictly necessary, and required a fair amount of bookkeeping. This PR removes the mode argument from all the scans in the repo, and also removes all that bookkeeping.This will slightly break the user-facing API, because now
build_statespace_graph(data, mode='JAX')
will now raise an error (because you don't need to pass anything).It will also now require users to explicitly pass compile_kwargs to all of the sampling functions (
sample_conditional_posterior
,forecast
,impulse_response_function
, etc). That's consistent with all PyMC apis, though.I still need to go through and adjust all the notebooks, but it looks like tests are passing.