Skip to content

Statespace: Don't automatically save statespace matrices as Deterministic variables #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2024

Conversation

jessegrabowski
Copy link
Member

I had originally done this to facilitate out-of-sample sampling tasks, so I could do e.g. pm.Flat('T', T). The result was that all matrices were saved to the idata, creating pretty horrible pm.model_to_graphviz outputs like this:

image

This also was extremely memory wasteful. Many of the matrices are not random at all, and they were being saved (chain, draw) times.

After this refactor, the matrices are dynamically rebuilt as needed from the parameter samples. The new graphs look like this:

image

There are also some en-passant changes to how exogenous data are handled that breaks the Structural example notebook. I will open a new PR after this one to address that, because I think I finally have it set up to handle forecasting with exogenous data. Basically, I was previously treating exogenous data like a type of "parameter". It's been upgraded to a first-class object, and custom models subclassing PyMCStateSpace that use exogenous data will now need to implement data_names and data_info properties.

@ricardoV94
Copy link
Member

ricardoV94 commented Feb 5, 2024

For the memory question, matrices that don't change could be defined as constant or mutabledata

@jessegrabowski
Copy link
Member Author

They're stored as TensorVariables in the statespace representation. This new way of doing this will directly use those once instead of copying them over and over into the idata.

@ricardoV94
Copy link
Member

ricardoV94 commented Feb 5, 2024

They're stored as TensorVariables in the statespace representation. This new way of doing this will directly use those once instead of copying them over and over into the idata.

ConstantData and MutableData are also stored only once in the idata as opposed to Deterministics which are stored per draw. Seems like that's what you wanted?

@jessegrabowski
Copy link
Member Author

Something like that, but I don't want to have a bunch of logic to decide if a matrix is static or contains parameters. If the user really wants to inspect matrices, he can ask for them manually and save them however he wants. In general, I think the important outputs for most users is going to be the parameters and the states. The rest can be more hidden away.

@jessegrabowski jessegrabowski force-pushed the clean-graph branch 2 times, most recently from 0c806d5 to 6ac1d0d Compare February 5, 2024 23:46
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a tiny type-hint thing

@@ -291,7 +314,7 @@ def _unpack_statespace_with_placeholders(self) -> Tuple:

return a0, P0, c, d, T, Z, R, H, Q

def unpack_statespace(self) -> Tuple:
def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List doesn't need ellipsis

Suggested change
def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
def unpack_statespace(self) -> list[pt.TensorVariable]:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pycharm actually complains when I make this change, it wants me to literally list out tuple[ TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable], which seems silly. list[TensorVariable, ...] was my hack-around.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore pycharm

@jessegrabowski
Copy link
Member Author

What's up with these jax test failures on the Ubuntu CI?

@ricardoV94
Copy link
Member

What's up with these jax test failures on the Ubuntu CI?

#305 ?

@ricardoV94 ricardoV94 merged commit 015ba1f into pymc-devs:main Feb 13, 2024
@jessegrabowski jessegrabowski deleted the clean-graph branch February 20, 2024 16:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants