Statespace: Don't automatically save statespace matrices as Deterministic variables #302
For the memory question, matrices that don't change could be defined as ConstantData or MutableData.
They're stored as TensorVariables in the statespace representation. The new approach uses those directly, once, instead of copying them over and over into the idata.
ConstantData and MutableData are also stored only once in the idata, as opposed to Deterministics, which are stored per draw. Seems like that's what you wanted?
Something like that, but I don't want to have a bunch of logic to decide if a matrix is static or contains parameters. If the user really wants to inspect matrices, he can ask for them manually and save them however he wants. In general, I think the important outputs for most users are going to be the parameters and the states. The rest can be more hidden away.
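To put rough numbers on the per-draw versus stored-once distinction discussed above, here is a minimal back-of-the-envelope sketch. The matrix size, chain count, and draw count are made-up values, and `stored_bytes` is a hypothetical helper, not part of PyMC or this PR.

```python
def stored_bytes(shape, chains, draws, per_draw, itemsize=8):
    """Bytes needed to keep one float64 array in the inference data.

    per_draw=True mimics a Deterministic (one copy per chain and draw);
    per_draw=False mimics constant data (a single copy).
    """
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    copies = chains * draws if per_draw else 1
    return n_elements * itemsize * copies


# Hypothetical sizes: a 10x10 matrix, 4 chains, 1000 draws per chain
shape = (10, 10)
as_deterministic = stored_bytes(shape, chains=4, draws=1000, per_draw=True)
as_constant = stored_bytes(shape, chains=4, draws=1000, per_draw=False)

print(as_deterministic, as_constant, as_deterministic // as_constant)
```

Under these assumptions the per-draw copy is larger by a factor of `chains * draws`, regardless of whether the matrix actually varies between draws.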
0c806d5 to 6ac1d0d
Just a tiny type-hint thing
@@ -291,7 +314,7 @@ def _unpack_statespace_with_placeholders(self) -> Tuple:
         return a0, P0, c, d, T, Z, R, H, Q

-    def unpack_statespace(self) -> Tuple:
+    def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
List doesn't need ellipsis
-    def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
+    def unpack_statespace(self) -> list[pt.TensorVariable]:
PyCharm actually complains when I make this change; it wants me to literally list out `tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]`, which seems silly. `list[TensorVariable, ...]` was my hack-around.
Ignore pycharm
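For reference, a small sketch of the typing convention behind this thread: `list[X]` takes a single element type, while the ellipsis form belongs to `tuple[X, ...]`, which means a homogeneous tuple of any length. `float` stands in for `pt.TensorVariable` here so the snippet runs without PyTensor, and `unpack` is a hypothetical placeholder, not the method from the PR.

```python
from typing import get_args

# A list is parameterized by one element type; no ellipsis needed.
Matrices = list[float]

# The ellipsis form is for tuples: "any number of floats".
MatricesTuple = tuple[float, ...]


def unpack() -> list[float]:
    """Hypothetical stand-in that returns nine values, one per matrix."""
    return [0.0] * 9


print(get_args(Matrices))
print(get_args(MatricesTuple))
print(len(unpack()))
```

If the function returned a tuple instead, `tuple[float, ...]` would avoid spelling out all nine entries.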
What's up with these JAX test failures on the Ubuntu CI?
#305 ?
93265aa to 97c3ab8
97c3ab8 to eb44c83
I had originally done this to facilitate out-of-sample sampling tasks, so I could do e.g. `pm.Flat('T', T)`. The result was that all matrices were saved to the `idata`, creating pretty horrible `pm.model_to_graphviz` outputs like this:

[image]

This was also extremely memory wasteful. Many of the matrices are not random at all, and they were being saved `(chain, draw)` times.

After this refactor, the matrices are dynamically rebuilt as needed from the parameter samples. The new graphs look like this:

[image]

There are also some en-passant changes to how exogenous data are handled that break the Structural example notebook. I will open a new PR after this one to address that, because I think I finally have it set up to handle forecasting with exogenous data. Basically, I was previously treating exogenous data like a type of "parameter". It's been upgraded to a first-class object, and custom models subclassing `PyMCStateSpace` that use exogenous data will now need to implement `data_names` and `data_info` properties.
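The idea of rebuilding matrices from parameter samples, rather than storing a copy per draw, can be illustrated with a plain-Python sketch. This is not the PR's implementation (which works on the model graph); `build_transition`, `matrices_on_demand`, the 2x2 matrix layout, and the sample values are all hypothetical.

```python
def build_transition(rho):
    """Hypothetical 2x2 transition matrix with a single free parameter."""
    return [[rho, 1.0], [0.0, 1.0]]


# Posterior draws of the parameter, indexed [chain][draw] (made-up values).
# Only these are stored; the matrices themselves are not.
rho_samples = [[0.5, 0.6], [0.7, 0.8]]


def matrices_on_demand(samples):
    """Lazily yield (chain, draw, matrix), rebuilding each matrix from its draw."""
    for chain_idx, chain in enumerate(samples):
        for draw_idx, rho in enumerate(chain):
            yield chain_idx, draw_idx, build_transition(rho)


rebuilt = list(matrices_on_demand(rho_samples))
print(rebuilt[-1])
```

The trade-off in this sketch is some recomputation when a matrix is requested, in exchange for storing one scalar per draw instead of a full matrix.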