From 2c6ef1ffe00799e6438953c4b3488cc4496f8bd3 Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Fri, 4 Apr 2025 15:08:10 -0400 Subject: [PATCH 1/4] Set up working infrastructure for batched KF --- conda-envs/environment-test.yml | 3 + notebooks/batch-examples.ipynb | 388 ++++++++++++++++++ .../statespace/filters/distributions.py | 51 ++- 3 files changed, 422 insertions(+), 20 deletions(-) create mode 100644 notebooks/batch-examples.ipynb diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 450b46e3..8234b954 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -3,6 +3,7 @@ channels: - conda-forge - nodefaults dependencies: +- ipywidgets - pymc>=5.21 - pytest-cov>=2.5 - pytest>=3.0 @@ -10,8 +11,10 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 +- nutpie - pip - pip: - blackjax - scikit-learn - better_optimize + - -e . diff --git a/notebooks/batch-examples.ipynb b/notebooks/batch-examples.ipynb new file mode 100644 index 00000000..6139992b --- /dev/null +++ b/notebooks/batch-examples.ipynb @@ -0,0 +1,388 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "0a5841d3", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pymc_extras.statespace.filters import StandardFilter\n", + "from tests.statespace.utilities.test_helpers import make_test_inputs\n", + "from pytensor.graph.replace import vectorize_graph\n", + "from importlib import reload\n", + "import pymc_extras.statespace.filters.distributions as pmss_dist\n", + "from pymc_extras.statespace.filters.distributions import SequenceMvNormal\n", + "import pymc as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "14299e50", + "metadata": {}, + "outputs": [], + "source": [ + "seed = sum(map(ord, \"batched-kf\"))\n", + "rng = np.random.default_rng(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "71bc513e", + "metadata": {}, + "outputs": [], + "source": [ + "def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):\n", + " \"\"\"\n", + " Create batched inputs for testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch_size : int\n", + " Number of batches to create\n", + " p : int\n", + " First dimension parameter\n", + " m : int\n", + " Second dimension parameter\n", + " r : int\n", + " Third dimension parameter\n", + " n : int\n", + " Fourth dimension parameter\n", + " rng : numpy.random.Generator\n", + " Random number generator\n", + "\n", + " Returns\n", + " -------\n", + " list\n", + " List of stacked inputs for each batch\n", + " \"\"\"\n", + " # Create individual inputs for each batch\n", + " np_batch_inputs = []\n", + " for i in range(batch_size):\n", + " inputs = make_test_inputs(p, m, r, n, rng)\n", + " np_batch_inputs.append(inputs)\n", + "\n", + " return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0c1824cf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10, 1)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create batch inputs with batch size 3\n", + "np_batch_inputs = create_batch_inputs(3)\n", + "np_batch_inputs[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "773d4cb4", + "metadata": {}, + "outputs": [], + "source": [ + "p, m, r, n = 1, 5, 1, 10\n", + "inputs = [pt.as_tensor(x).type() for x in make_test_inputs(p, m, r, n, rng)]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "511de29f", + "metadata": {}, + "outputs": [], + "source": [ + "kf = StandardFilter()\n", + "kf_outputs = kf.build_graph(*inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "33006d8e", + "metadata": {}, + "outputs": [], + "source": [ + "batched_inputs = [pt.tensor(shape=(None, *x.type.shape)) for x in inputs]\n", + "vec_subs = dict(zip(inputs, batched_inputs))\n", + "bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "987a4647", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[filtered_states,\n", + " predicted_states,\n", + " observed_states,\n", + " filtered_covariances,\n", + " predicted_covariances,\n", + " observed_covariances,\n", + " loglike_obs]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kf_outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4b8be0f9", + "metadata": {}, + "outputs": [], + "source": [ + "mu = bacthed_kf_outputs[1]\n", + "cov = bacthed_kf_outputs[4]\n", + "logp = bacthed_kf_outputs[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1dc80f94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(None, 10, 5)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mu.type.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1262c7d4", + "metadata": {}, + "outputs": [], + "source": [ + "pmss_dist = reload(pmss_dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "2dcd3958", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n", + "mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n", + "mvn_seq.type.shape: (None, None, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n", + "mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n", + "mvn_seq.type.shape: (None, None, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n" + ] + } + ], + "source": [ + "mv_outputs = pmss_dist.SequenceMvNormal.dist(mus=mu, covs=cov, logp=logp)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6f41344f", + "metadata": {}, + "outputs": [], + "source": [ + "np_batch_inputs = create_batch_inputs(3)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "44905b8a", + "metadata": {}, + "outputs": [], + "source": [ + "np_batch_inputs[0] = rng.normal(size=(3, 10, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "34fe01b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10, 5)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_test = pytensor.function(batched_inputs, mv_outputs)\n", + "f_test(*np_batch_inputs).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f37efe79", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)\n" + ] + } + ], + "source": [ + "f_mv = pytensor.function(batched_inputs, pm.logp(mv_outputs, batched_inputs[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7b45de74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_mv(*np_batch_inputs).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f14596aa", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "69519822", + "metadata": {}, + "outputs": [], + "source": [ + "f = pytensor.function(batched_inputs, bacthed_kf_outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "3f745449", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "for s in [1, 3, 10]:\n", + " np_batch_inputs = create_batch_inputs(s)\n", + " %timeit outputs = f(*np_batch_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5fcadef", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c479ff22", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index 1e4f2b15..60b74c99 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -374,44 +374,55 @@ def dist(cls, mus, covs, logp, **kwargs): @classmethod def rv_op(cls, mus, covs, logp, size=None): # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead - if mus.ndim > 2: - mus = pt.moveaxis(mus, -2, 0) - if covs.ndim > 3: - covs = pt.moveaxis(covs, -3, 0) - mus_, covs_ = mus.type(), covs.type() + print(f"mus_.type.shape: {mus_.type.shape}, covs_.type.shape: {covs_.type.shape}") logp_ = logp.type() rng = pytensor.shared(np.random.default_rng()) - def step(mu, cov, rng): - new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs - return mvn, {rng: new_rng} + def recursion(mus, covs, rng): + if mus.ndim > 2: + mus = pt.moveaxis(mus, -2, 0) + if covs.ndim > 3: + covs = pt.moveaxis(covs, -3, 0) + print(f"mus.type.shape: {mus.type.shape}, covs.type.shape: {covs.type.shape}") - mvn_seq, updates = pytensor.scan( - step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0] - ) - mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape) + def step(mu, cov, rng): + new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs + return mvn, {rng: new_rng} + + mvn_seq, updates = pytensor.scan( + step, sequences=[mus, covs], non_sequences=[rng], strict=True, n_steps=mus.shape[0] + ) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") + mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape) + + # Move time axis back to position -2 so batches are on the left + if mvn_seq.ndim > 2: + mvn_seq = pt.moveaxis(mvn_seq, 0, -2) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") + + (seq_mvn_rng,) = tuple(updates.values()) - # Move time axis back to position -2 so batches are on the left - if mvn_seq.ndim > 2: - mvn_seq = pt.moveaxis(mvn_seq, 0, -2) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") - (seq_mvn_rng,) = tuple(updates.values()) + return [seq_mvn_rng, mvn_seq] mvn_seq_op = KalmanFilterRV( - inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2 + inputs=[mus_, covs_, logp_, rng], outputs=recursion(mus_, covs_, rng), ndim_supp=2 ) mvn_seq = mvn_seq_op(mus, covs, logp, rng) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") return mvn_seq @_logprob.register(KalmanFilterRV) def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs): + print(values[0].type.shape, mus.type.shape, covs.type.shape) return check_parameters( logp, - pt.eq(values[0].shape[0], mus.shape[0]), - pt.eq(covs.shape[0], mus.shape[0]), - msg="Observed data and parameters must have the same number of timesteps (dimension 0)", + pt.eq(values[0].shape[-2], mus.shape[-2]), + pt.eq(covs.shape[-3], mus.shape[-2]), + msg="Observed data and parameters must have the same number of timesteps", ) From 7f6845eace50412e0d257030cd5a96436d38c8ec Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Tue, 8 Apr 2025 16:11:35 -0400 Subject: [PATCH 2/4] Update conda env file --- conda-envs/environment-test.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 8234b954..16cfcc3b 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -1,20 +1,20 @@ name: pymc-extras-test channels: - conda-forge -- nodefaults dependencies: +- blackjax - ipywidgets -- pymc>=5.21 -- pytest-cov>=2.5 -- pytest>=3.0 +- ipython +- pymc +- pytest-cov +- pytest - dask - xhistogram - statsmodels -- numba<=0.60.0 +- numba - nutpie - pip +- scikit-learn - pip: - - blackjax - - scikit-learn - better_optimize - -e . From cc9f7509a46496f190ba170664f7635d4e837b71 Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Tue, 8 Apr 2025 16:12:21 -0400 Subject: [PATCH 3/4] Working with Filter, not with Smoother --- notebooks/batch-examples.ipynb | 28 ++++---- .../statespace/filters/kalman_filter.py | 29 +++++++++ .../statespace/filters/kalman_smoother.py | 39 ++++++++++- tests/statespace/test_kalman_filter.py | 65 ++++++++++--------- tests/statespace/utilities/test_helpers.py | 54 ++++++++++----- 5 files changed, 153 insertions(+), 62 deletions(-) diff --git a/notebooks/batch-examples.ipynb b/notebooks/batch-examples.ipynb index 6139992b..72b1ca92 100644 --- a/notebooks/batch-examples.ipynb +++ b/notebooks/batch-examples.ipynb @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "id": "1262c7d4", "metadata": {}, "outputs": [], @@ -199,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 12, "id": "2dcd3958", "metadata": {}, "outputs": [ @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 13, "id": "6f41344f", "metadata": {}, "outputs": [], @@ -238,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 14, "id": "44905b8a", "metadata": {}, "outputs": [], @@ -248,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 15, "id": "34fe01b8", "metadata": {}, "outputs": [ @@ -258,7 +258,7 @@ "(3, 10, 5)" ] }, - "execution_count": 24, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -270,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "id": "f37efe79", "metadata": {}, "outputs": [ @@ -288,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 17, "id": "7b45de74", "metadata": {}, "outputs": [ @@ -298,7 +298,7 @@ "(3, 10)" ] }, - "execution_count": 26, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -317,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 18, "id": "69519822", "metadata": {}, "outputs": [], @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 19, "id": "3f745449", "metadata": {}, "outputs": [ @@ -335,9 +335,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", - "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", - "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "5.28 ms ± 424 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 0ca47b50..887130f4 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -10,6 +10,7 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.slinalg import solve_triangular +from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, @@ -20,6 +21,7 @@ MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] +CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2) assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " @@ -73,6 +75,23 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): """ return data, a0, P0, c, d, T, Z, R, H, Q + def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q): + """ + Check if any of the inputs are batched. + """ + return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q])) + + def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM): + out.append( + pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) + ) + return out + @staticmethod def add_check_on_time_varying_shapes( data: TensorVariable, sequence_params: list[TensorVariable] @@ -202,6 +221,7 @@ def build_graph( self.mode = mode self.missing_fill_value = missing_fill_value self.cov_jitter = cov_jitter + is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q) [R_shape] = constant_fold([R.shape], raise_not_constant=False) [Z_shape] = constant_fold([Z.shape], raise_not_constant=False) @@ -209,6 +229,10 @@ def build_graph( self.n_states, self.n_shocks = R_shape[-2:] self.n_endog = Z_shape[-2] + if is_batched: + batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q] + data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs) + data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( @@ -233,8 +257,13 @@ def build_graph( filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0]) + if is_batched: + vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs)) + filter_results = vectorize_graph(filter_results, vec_subs) + if return_updates: return filter_results, updates + return filter_results def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: diff --git a/pymc_extras/statespace/filters/kalman_smoother.py b/pymc_extras/statespace/filters/kalman_smoother.py index f15913b8..671d9366 100644 --- a/pymc_extras/statespace/filters/kalman_smoother.py +++ b/pymc_extras/statespace/filters/kalman_smoother.py @@ -3,7 +3,7 @@ from pytensor.compile import get_mode from pytensor.tensor.nlinalg import matrix_dot - +from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, split_vars_into_seq_and_nonseq, @@ -11,6 +11,8 @@ ) from pymc_extras.statespace.utils.constants import JITTER_DEFAULT +SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3) + class KalmanSmoother: """ @@ -63,12 +65,41 @@ def unpack_args(self, args): return a, P, a_smooth, P_smooth, T, R, Q + def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances): + """ + Check if any of the inputs are batched. + """ + return any( + x.ndim > SMOOTHER_CORE_NDIM[i] + for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances]) + ) + + def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip( + [T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM + ): + out.append( + pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) + ) + return out + def build_graph( self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT ): self.mode = mode self.cov_jitter = cov_jitter + is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances) + if is_batched: + batched_inputs = [T, R, Q, filtered_states, filtered_covariances] + T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs( + *batched_inputs + ) + n, k = filtered_states.type.shape a_last = pt.specify_shape(filtered_states[-1], (k,)) @@ -98,6 +129,12 @@ def build_graph( smoothed_covariances = pt.concatenate( [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0 ) + smoothed_states.dprint() + if is_batched: + vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs)) + smoothed_states, smoothed_covariances = vectorize_graph( + [smoothed_states, smoothed_covariances], vec_subs + ) smoothed_states.name = "smoothed_states" smoothed_covariances.name = "smoothed_covariances" diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 6c0bc18c..3cdfa569 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -31,19 +31,22 @@ RTOL = 1e-6 if floatX.endswith("64") else 1e-3 standard_inout = initialize_filter(StandardFilter()) +standard_inout_batched = initialize_filter(StandardFilter(), batched=True) cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") +f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore") f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_cholesky, f_univariate] +filter_funcs = [f_standard, f_standard_batched] # , f_cholesky, f_univariate] filter_names = [ "StandardFilter", - "CholeskyFilter", - "UnivariateFilter", + "StandardFilterBatched", + # "CholeskyFilter", + # "UnivariateFilter", ] output_names = [ @@ -65,17 +68,21 @@ def test_base_class_update_raises(): filter.update(*inputs) -@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) -def test_output_shapes_one_state_one_observed(filter_func, rng): +@pytest.mark.parametrize( + "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names +) +def test_output_shapes_one_state_one_observed(filter_func, filter_name, rng): + batch_size = 3 if "batched" in filter_name.lower() else 0 p, m, r, n = 1, 1, 1, 10 - inputs = make_test_inputs(p, m, r, n, rng) - outputs = filter_func(*inputs) + inputs = make_test_inputs(p, m, r, n, rng, batch_size=batch_size) + assert 0 + # outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): - expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + expected_shape = get_expected_shape(name, p, m, r, n, batch_size) + # assert outputs[output_idx].shape == expected_shape, ( + # f"Shape of {name} does not match expected" + # ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -86,9 +93,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -99,9 +106,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.fixture @@ -161,9 +168,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng): for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -175,9 +182,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng): for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize( @@ -190,9 +197,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize( @@ -206,9 +213,9 @@ def test_missing_data(filter_func, filter_name, p, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index c6170f88..fa970f14 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -34,18 +34,18 @@ def load_nile_test_data(): return nile -def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): +def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None, batched=False): ksmoother = KalmanSmoother() - data = pt.tensor(name="data", dtype=floatX, shape=(n, p)) - a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,)) - P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m)) - c = pt.tensor(name="c", dtype=floatX, shape=(m,)) - d = pt.tensor(name="d", dtype=floatX, shape=(p,)) - Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r)) - H = pt.tensor(name="H", dtype=floatX, shape=(p, p)) - T = pt.tensor(name="T", dtype=floatX, shape=(m, m)) - R = pt.tensor(name="R", dtype=floatX, shape=(m, r)) - Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m)) + data = pt.tensor(name="data", dtype=floatX, shape=(None, n, p) if batched else (n, p)) + a0 = pt.tensor(name="x0", dtype=floatX, shape=(None, m) if batched else (m,)) + P0 = pt.tensor(name="P0", dtype=floatX, shape=(None, m, m) if batched else (m, m)) + c = pt.tensor(name="c", dtype=floatX, shape=(None, m) if batched else (m,)) + d = pt.tensor(name="d", dtype=floatX, shape=(None, p) if batched else (p,)) + Q = pt.tensor(name="Q", dtype=floatX, shape=(None, r, r) if batched else (r, r)) + H = pt.tensor(name="H", dtype=floatX, shape=(None, p, p) if batched else (p, p)) + T = pt.tensor(name="T", dtype=floatX, shape=(None, m, m) if batched else (m, m)) + R = pt.tensor(name="R", dtype=floatX, shape=(None, m, r) if batched else (m, r)) + Z = pt.tensor(name="Z", dtype=floatX, shape=(None, p, m) if batched else (p, m)) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] @@ -68,7 +68,7 @@ def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): filtered_covs, predicted_covs, smoothed_covs, - ll_obs.sum(), + ll_obs.sum(axis=-1), ll_obs, ] @@ -83,7 +83,7 @@ def add_missing_data(data, n_missing, rng): return data -def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): +def make_1d_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): data = np.arange(n * p, dtype=floatX).reshape(-1, p) if missing_data is not None: data = add_missing_data(data, missing_data, rng) @@ -106,16 +106,34 @@ def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): return data, a0, P0, c, d, T, Z, R, H, Q -def get_expected_shape(name, p, m, r, n): +def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False, batch_size=0): + if batch_size == 0: + return make_1d_test_inputs(p, m, r, n, rng, missing_data, H_is_zero) + + # Create individual inputs for each batch + np_batch_inputs = [] + for i in range(batch_size): + inputs = make_1d_test_inputs(p, m, r, n, rng, missing_data, H_is_zero) + np_batch_inputs.append(inputs) + + return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)] + + +def get_expected_shape(name, p, m, r, n, batch_size=0): if name == "log_likelihood": - return () + shape = () elif name == "ll_obs": - return (n,) + shape = (n,) filter_type, variable = name.split("_") if variable == "states": - return n, m + shape = n, m if variable == "covs": - return n, m, m + shape = n, m, m + + if batch_size != 0: + shape = (batch_size, *shape) + + return shape def get_sm_state_from_output_name(res, name): From c333b69f763bcffd45397d04059aaaf0031190ee Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Thu, 24 Apr 2025 18:29:00 -0400 Subject: [PATCH 4/4] Working batched Kalman filter and smoother --- notebooks/batch-examples.ipynb | 856 +++++++++++++++++- .../statespace/filters/kalman_filter.py | 117 ++- .../statespace/filters/kalman_smoother.py | 102 ++- pymc_extras/statespace/filters/utilities.py | 46 +- pymc_extras/statespace/utils/constants.py | 10 + tests/statespace/test_kalman_filter.py | 48 +- 6 files changed, 1100 insertions(+), 79 deletions(-) diff --git a/notebooks/batch-examples.ipynb b/notebooks/batch-examples.ipynb index 72b1ca92..512d09da 100644 --- a/notebooks/batch-examples.ipynb +++ b/notebooks/batch-examples.ipynb @@ -349,18 +349,868 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, + "id": "42366399", + "metadata": {}, + "outputs": [], + "source": [ + "from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "id": "d5fcadef", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "def build_fk(data, a0, P0, c, d, T, Z, R, H, Q):\n", + " kf = StandardFilter()\n", + " kf_outputs = kf.build_graph(data, a0, P0, c, d, T, Z, R, H, Q)\n", + "\n", + " ks = KalmanSmoother()\n", + " ks_outputs = ks.build_graph(T, R, Q, kf_outputs[0], kf_outputs[3])\n", + "\n", + " return (*kf_outputs, *ks_outputs)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "c479ff22", "metadata": {}, "outputs": [], + "source": [ + "signature = \"(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)\"" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d3a403e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Join [id A]\n", + " ├─ 0 [id B]\n", + " ├─ Subtensor{::step} [id C]\n", + " │ ├─ Subtensor{start:} [id D]\n", + " │ │ ├─ Scan{kalman_smoother, while_loop=False, inplace=none}.0 [id E]\n", + " │ │ │ ├─ Minimum [id F]\n", + " │ │ │ │ ├─ Subtensor{i} [id G]\n", + " │ │ │ │ │ ├─ Shape [id H]\n", + " │ │ │ │ │ │ └─ Subtensor{::step} [id I]\n", + " │ │ │ │ │ │ ├─ Subtensor{start:} [id J]\n", + " │ │ │ │ │ │ │ ├─ Subtensor{:stop} [id K]\n", + " │ │ │ │ │ │ │ │ ├─ SpecifyShape [id L] 'filtered_states'\n", + " │ │ │ │ │ │ │ │ │ ├─ Scan{forward_kalman_pass, while_loop=False, inplace=none}.2 [id M]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id O]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{start:} [id P]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ [id Q]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ 0 [id R]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ 0 [id S]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{:stop} [id T]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{start:} [id P]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ScalarFromTensor [id U]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ SetSubtensor{:stop} [id V]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id W]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ Add [id X]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id Y]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id Z]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BA]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ [id BB]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ 0 [id BC]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id BD]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id Z]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ 1 [id BE]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id BA]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ScalarFromTensor [id BF]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id Y]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ SetSubtensor{:stop} [id BG]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ Add [id BI]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id BJ]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id BK]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BL]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ [id BM]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ 0 [id BN]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id BO]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id BK]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ 1 [id BP]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id BQ]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ ├─ Shape [id BK]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ 2 [id BR]\n", + " │ │ │ │ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id BL]\n", + " │ │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ScalarFromTensor [id BS]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ Subtensor{i} [id BJ]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ Subtensor{i} [id N]\n", + " │ │ │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BT]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BU]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BV]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BW]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BX]\n", + " │ │ │ │ │ │ │ │ │ │ ├─ [id BY]\n", + " │ │ │ │ │ │ │ │ │ │ └─ [id BZ]\n", + " │ │ │ │ │ │ │ │ │ ├─ 10 [id CA]\n", + " │ │ │ │ │ │ │ │ │ └─ 5 [id CB]\n", + " │ │ │ │ │ │ │ │ └─ -1 [id CC]\n", + " │ │ │ │ │ │ │ └─ 0 [id CD]\n", + " │ │ │ │ │ │ └─ -1 [id CE]\n", + " │ │ │ │ │ └─ 0 [id CF]\n", + " │ │ │ │ └─ Subtensor{i} [id CG]\n", + " │ │ │ │ ├─ Shape [id CH]\n", + " │ │ │ │ │ └─ Subtensor{::step} [id CI]\n", + " │ │ │ │ │ ├─ Subtensor{start:} [id CJ]\n", + " │ │ │ │ │ │ ├─ Subtensor{:stop} [id CK]\n", + " │ │ │ │ │ │ │ ├─ SpecifyShape [id CL] 'filtered_covariances'\n", + " │ │ │ │ │ │ │ │ ├─ Scan{forward_kalman_pass, while_loop=False, inplace=none}.4 [id M]\n", + " │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ ├─ 10 [id CM]\n", + " │ │ │ │ │ │ │ │ ├─ 5 [id CN]\n", + " │ │ │ │ │ │ │ │ └─ 5 [id CO]\n", + " │ │ │ │ │ │ │ └─ -1 [id CP]\n", + " │ │ │ │ │ │ └─ 0 [id CQ]\n", + " │ │ │ │ │ └─ -1 [id CR]\n", + " │ │ │ │ └─ 0 [id CS]\n", + " │ │ │ ├─ Subtensor{:stop} [id CT]\n", + " │ │ │ │ ├─ Subtensor{::step} [id I]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ ScalarFromTensor [id CU]\n", + " │ │ │ │ └─ Minimum [id F]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ ├─ Subtensor{:stop} [id CV]\n", + " │ │ │ │ ├─ Subtensor{::step} [id CI]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ ScalarFromTensor [id CW]\n", + " │ │ │ │ └─ Minimum [id F]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ ├─ SetSubtensor{:stop} [id CX]\n", + " │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id CY]\n", + " │ │ │ │ │ ├─ Add [id CZ]\n", + " │ │ │ │ │ │ ├─ Minimum [id F]\n", + " │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ └─ Subtensor{i} [id DA]\n", + " │ │ │ │ │ │ ├─ Shape [id DB]\n", + " │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id DC]\n", + " │ │ │ │ │ │ │ └─ Subtensor{i} [id DD]\n", + " │ │ │ │ │ │ │ ├─ SpecifyShape [id L] 'filtered_states'\n", + " │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ └─ -1 [id DE]\n", + " │ │ │ │ │ │ └─ 0 [id DF]\n", + " │ │ │ │ │ └─ Subtensor{i} [id DG]\n", + " │ │ │ │ │ ├─ Shape [id DB]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 1 [id DH]\n", + " │ │ │ │ ├─ ExpandDims{axis=0} [id DC]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ ScalarFromTensor [id DI]\n", + " │ │ │ │ └─ Subtensor{i} [id DA]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ ├─ SetSubtensor{:stop} [id DJ]\n", + " │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id DK]\n", + " │ │ │ │ │ ├─ Add [id DL]\n", + " │ │ │ │ │ │ ├─ Minimum [id F]\n", + " │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ └─ Subtensor{i} [id DM]\n", + " │ │ │ │ │ │ ├─ Shape [id DN]\n", + " │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id DO]\n", + " │ │ │ │ │ │ │ └─ Subtensor{i} [id DP]\n", + " │ │ │ │ │ │ │ ├─ SpecifyShape [id CL] 'filtered_covariances'\n", + " │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ └─ -1 [id DQ]\n", + " │ │ │ │ │ │ └─ 0 [id DR]\n", + " │ │ │ │ │ ├─ Subtensor{i} [id DS]\n", + " │ │ │ │ │ │ ├─ Shape [id DN]\n", + " │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ └─ 1 [id DT]\n", + " │ │ │ │ │ └─ Subtensor{i} [id DU]\n", + " │ │ │ │ │ ├─ Shape [id DN]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 2 [id DV]\n", + " │ │ │ │ ├─ ExpandDims{axis=0} [id DO]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ ScalarFromTensor [id DW]\n", + " │ │ │ │ └─ Subtensor{i} [id DM]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ ├─ [id BV]\n", + " │ │ │ ├─ [id BX]\n", + " │ │ │ └─ [id BZ]\n", + " │ │ └─ 1 [id DX]\n", + " │ └─ -1 [id DY]\n", + " └─ ExpandDims{axis=0} [id DZ]\n", + " └─ Subtensor{i} [id DD]\n", + " └─ ···\n", + "\n", + "Inner graphs:\n", + "\n", + "Scan{kalman_smoother, while_loop=False, inplace=none} [id E]\n", + " ← Add [id EA]\n", + " ├─ *0- [id EB] -> [id CT]\n", + " └─ Squeeze{axis=1} [id EC]\n", + " └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id ED]\n", + " ├─ Transpose{axes=[1, 0]} [id EE]\n", + " │ └─ dot [id EF]\n", + " │ ├─ dot [id EG]\n", + " │ │ ├─ Blockwise{MatrixPinv{hermitian=False}, (m,n)->(n,m)} [id EH]\n", + " │ │ │ └─ Add [id EI]\n", + " │ │ │ ├─ Add [id EJ]\n", + " │ │ │ │ ├─ Mul [id EK]\n", + " │ │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EL]\n", + " │ │ │ │ │ │ └─ 0.5 [id EM]\n", + " │ │ │ │ │ └─ Add [id EN]\n", + " │ │ │ │ │ ├─ dot [id EO]\n", + " │ │ │ │ │ │ ├─ dot [id EP]\n", + " │ │ │ │ │ │ │ ├─ *4- [id EQ] -> [id BV]\n", + " │ │ │ │ │ │ │ └─ *1- [id ER] -> [id CV]\n", + " │ │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id ES]\n", + " │ │ │ │ │ │ └─ *4- [id EQ] -> [id BV]\n", + " │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id ET]\n", + " │ │ │ │ │ └─ dot [id EO]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ Mul [id EU]\n", + " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EV]\n", + " │ │ │ │ │ └─ 0.5 [id EW]\n", + " │ │ │ │ └─ Add [id EX]\n", + " │ │ │ │ ├─ dot [id EY]\n", + " │ │ │ │ │ ├─ dot [id EZ]\n", + " │ │ │ │ │ │ ├─ *5- [id FA] -> [id BX]\n", + " │ │ │ │ │ │ └─ *6- [id FB] -> [id BZ]\n", + " │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id FC]\n", + " │ │ │ │ │ └─ *5- [id FA] -> [id BX]\n", + " │ │ │ │ └─ Transpose{axes=[1, 0]} [id FD]\n", + " │ │ │ │ └─ dot [id EY]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ Mul [id FE]\n", + " │ │ │ ├─ Eye{dtype='float64'} [id FF]\n", + " │ │ │ │ ├─ Subtensor{i} [id FG]\n", + " │ │ │ │ │ ├─ Shape [id FH]\n", + " │ │ │ │ │ │ └─ Add [id EJ]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 0 [id FI]\n", + " │ │ │ │ ├─ Subtensor{i} [id FJ]\n", + " │ │ │ │ │ ├─ Shape [id FK]\n", + " │ │ │ │ │ │ └─ Add [id EJ]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 1 [id FL]\n", + " │ │ │ │ └─ 0 [id FM]\n", + " │ │ │ └─ ExpandDims{axes=[0, 1]} [id FN]\n", + " │ │ │ └─ 1e-08 [id FO]\n", + " │ │ └─ *4- [id EQ] -> [id BV]\n", + " │ └─ *1- [id ER] -> [id CV]\n", + " └─ ExpandDims{axis=1} [id FP]\n", + " └─ Sub [id FQ]\n", + " ├─ *2- [id FR] -> [id CX]\n", + " └─ dot [id FS]\n", + " ├─ *4- [id EQ] -> [id BV]\n", + " └─ *0- [id EB] -> [id CT]\n", + " ← Add [id FT]\n", + " ├─ Add [id FU]\n", + " │ ├─ Add [id FV]\n", + " │ │ ├─ *1- [id ER] -> [id CV]\n", + " │ │ └─ Mul [id FW]\n", + " │ │ ├─ ExpandDims{axes=[0, 1]} [id FX]\n", + " │ │ │ └─ 0.5 [id FY]\n", + " │ │ └─ Add [id FZ]\n", + " │ │ ├─ dot [id GA]\n", + " │ │ │ ├─ dot [id GB]\n", + " │ │ │ │ ├─ Transpose{axes=[1, 0]} [id EE]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ Sub [id GC]\n", + " │ │ │ │ ├─ *3- [id GD] -> [id DJ]\n", + " │ │ │ │ └─ Add [id EI]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ Transpose{axes=[1, 0]} [id GE]\n", + " │ │ │ └─ Transpose{axes=[1, 0]} [id EE]\n", + " │ │ │ └─ ···\n", + " │ │ └─ Transpose{axes=[1, 0]} [id GF]\n", + " │ │ └─ dot [id GA]\n", + " │ │ └─ ···\n", + " │ └─ Mul [id GG]\n", + " │ ├─ Eye{dtype='float64'} [id GH]\n", + " │ │ ├─ Subtensor{i} [id GI]\n", + " │ │ │ ├─ Shape [id GJ]\n", + " │ │ │ │ └─ Add [id FV]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ 0 [id GK]\n", + " │ │ ├─ Subtensor{i} [id GL]\n", + " │ │ │ ├─ Shape [id GM]\n", + " │ │ │ │ └─ Add [id FV]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ 1 [id GN]\n", + " │ │ └─ 0 [id GO]\n", + " │ └─ ExpandDims{axes=[0, 1]} [id GP]\n", + " │ └─ 1e-08 [id GQ]\n", + " └─ Mul [id GR]\n", + " ├─ Eye{dtype='float64'} [id GS]\n", + " │ ├─ Subtensor{i} [id GT]\n", + " │ │ ├─ Shape [id GU]\n", + " │ │ │ └─ Add [id FU]\n", + " │ │ │ └─ ···\n", + " │ │ └─ 0 [id GV]\n", + " │ ├─ Subtensor{i} [id GW]\n", + " │ │ ├─ Shape [id GX]\n", + " │ │ │ └─ Add [id FU]\n", + " │ │ │ └─ ···\n", + " │ │ └─ 1 [id GY]\n", + " │ └─ 0 [id GZ]\n", + " └─ ExpandDims{axes=[0, 1]} [id HA]\n", + " └─ 1e-08 [id HB]\n", + "\n", + "Scan{forward_kalman_pass, while_loop=False, inplace=none} [id M]\n", + " ← Add [id HC]\n", + " ├─ dot [id HD]\n", + " │ ├─ *5- [id HE] -> [id BV]\n", + " │ └─ Add [id HF]\n", + " │ ├─ *1- [id HG] -> [id V]\n", + " │ └─ dot [id HH]\n", + " │ ├─ Transpose{axes=[1, 0]} [id HI]\n", + " │ │ └─ Blockwise{Solve{assume_a='pos', lower=False, check_finite=False, b_ndim=2, overwrite_a=False, overwrite_b=False}, (m,m),(m,n)->(m,n)} [id HJ]\n", + " │ │ ├─ Transpose{axes=[1, 0]} [id HK]\n", + " │ │ │ └─ Add [id HL]\n", + " │ │ │ ├─ dot [id HM]\n", + " │ │ │ │ ├─ dot [id HN]\n", + " │ │ │ │ │ ├─ AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id HO]\n", + " │ │ │ │ │ │ └─ Cast{float64} [id HP]\n", + " │ │ │ │ │ │ └─ Invert [id HQ]\n", + " │ │ │ │ │ │ └─ Or [id HR]\n", + " │ │ │ │ │ │ ├─ Isnan [id HS]\n", + " │ │ │ │ │ │ │ └─ *0- [id HT] -> [id T]\n", + " │ │ │ │ │ │ └─ Eq [id HU]\n", + " │ │ │ │ │ │ ├─ *0- [id HT] -> [id T]\n", + " │ │ │ │ │ │ └─ ExpandDims{axis=0} [id HV]\n", + " │ │ │ │ │ │ └─ -9999.0 [id HW]\n", + " │ │ │ │ │ └─ *6- [id HX] -> [id BW]\n", + " │ │ │ │ └─ dot [id HY]\n", + " │ │ │ │ ├─ *2- [id HZ] -> [id BG]\n", + " │ │ │ │ └─ Transpose{axes=[1, 0]} [id IA]\n", + " │ │ │ │ └─ dot [id HN]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ Add [id IB]\n", + " │ │ │ ├─ dot [id IC]\n", + " │ │ │ │ ├─ AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id HO]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ *8- [id ID] -> [id BY]\n", + " │ │ │ └─ Mul [id IE]\n", + " │ │ │ ├─ Eye{dtype='float64'} [id IF]\n", + " │ │ │ │ ├─ Subtensor{i} [id IG]\n", + " │ │ │ │ │ ├─ Shape [id IH]\n", + " │ │ │ │ │ │ └─ dot [id IC]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 0 [id II]\n", + " │ │ │ │ ├─ Subtensor{i} [id IJ]\n", + " │ │ │ │ │ ├─ Shape [id IK]\n", + " │ │ │ │ │ │ └─ dot [id IC]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 1 [id IL]\n", + " │ │ │ │ └─ 0 [id IM]\n", + " │ │ │ └─ ExpandDims{axes=[0, 1]} [id IN]\n", + " │ │ │ └─ 1e-08 [id IO]\n", + " │ │ └─ Transpose{axes=[1, 0]} [id IP]\n", + " │ │ └─ dot [id HY]\n", + " │ │ └─ ···\n", + " │ └─ Sub [id IQ]\n", + " │ ├─ AdvancedSetSubtensor [id IR]\n", + " │ │ ├─ *0- [id HT] -> [id T]\n", + " │ │ ├─ 0.0 [id IS]\n", + " │ │ └─ Or [id HR]\n", + " │ │ └─ ···\n", + " │ └─ Add [id IT]\n", + " │ ├─ *4- [id IU] -> [id BU]\n", + " │ └─ dot [id IV]\n", + " │ ├─ dot [id HN]\n", + " │ │ └─ ···\n", + " │ └─ *1- [id HG] -> [id V]\n", + " └─ *3- [id IW] -> [id BT]\n", + " ← Add [id IX]\n", + " ├─ Mul [id IY]\n", + " │ ├─ ExpandDims{axes=[0, 1]} [id IZ]\n", + " │ │ └─ 0.5 [id JA]\n", + " │ └─ Add [id JB]\n", + " │ ├─ dot [id JC]\n", + " │ │ ├─ dot [id JD]\n", + " │ │ │ ├─ *5- [id HE] -> [id BV]\n", + " │ │ │ └─ Add [id JE]\n", + " │ │ │ ├─ Add [id JF]\n", + " │ │ │ │ ├─ Mul [id JG]\n", + " │ │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id JH]\n", + " │ │ │ │ │ │ └─ 0.5 [id JI]\n", + " │ │ │ │ │ └─ Add [id JJ]\n", + " │ │ │ │ │ ├─ dot [id JK]\n", + " │ │ │ │ │ │ ├─ dot [id JL]\n", + " │ │ │ │ │ │ │ ├─ Sub [id JM]\n", + " │ │ │ │ │ │ │ │ ├─ Eye{dtype='float64'} [id JN]\n", + " │ │ │ │ │ │ │ │ │ ├─ 5 [id JO]\n", + " │ │ │ │ │ │ │ │ │ ├─ 5 [id JP]\n", + " │ │ │ │ │ │ │ │ │ └─ 0 [id JQ]\n", + " │ │ │ │ │ │ │ │ └─ dot [id JR]\n", + " │ │ │ │ │ │ │ │ ├─ Transpose{axes=[1, 0]} [id HI]\n", + " │ │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ │ └─ dot [id HN]\n", + " │ │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ │ └─ *2- [id HZ] -> [id BG]\n", + " │ │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id JS]\n", + " │ │ │ │ │ │ └─ Sub [id JM]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id JT]\n", + " │ │ │ │ │ └─ dot [id JK]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ Mul [id JU]\n", + " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id JV]\n", + " │ │ │ │ │ └─ 0.5 [id JW]\n", + " │ │ │ │ └─ Add [id JX]\n", + " │ │ │ │ ├─ dot [id JY]\n", + " │ │ │ │ │ ├─ dot [id JZ]\n", + " │ │ │ │ │ │ ├─ Transpose{axes=[1, 0]} [id HI]\n", + " │ │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ │ └─ dot [id IC]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id KA]\n", + " │ │ │ │ │ └─ Transpose{axes=[1, 0]} [id HI]\n", + " │ │ │ │ │ └─ ···\n", + " │ │ │ │ └─ Transpose{axes=[1, 0]} [id KB]\n", + " │ │ │ │ └─ dot [id JY]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ Mul [id KC]\n", + " │ │ │ ├─ Eye{dtype='float64'} [id KD]\n", + " │ │ │ │ ├─ Subtensor{i} [id KE]\n", + " │ │ │ │ │ ├─ Shape [id KF]\n", + " │ │ │ │ │ │ └─ Add [id JF]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 0 [id KG]\n", + " │ │ │ │ ├─ Subtensor{i} [id KH]\n", + " │ │ │ │ │ ├─ Shape [id KI]\n", + " │ │ │ │ │ │ └─ Add [id JF]\n", + " │ │ │ │ │ │ └─ ···\n", + " │ │ │ │ │ └─ 1 [id KJ]\n", + " │ │ │ │ └─ 0 [id KK]\n", + " │ │ │ └─ ExpandDims{axes=[0, 1]} [id KL]\n", + " │ │ │ └─ 1e-08 [id KM]\n", + " │ │ └─ Transpose{axes=[1, 0]} [id KN]\n", + " │ │ └─ *5- [id HE] -> [id BV]\n", + " │ └─ Transpose{axes=[1, 0]} [id KO]\n", + " │ └─ dot [id JC]\n", + " │ └─ ···\n", + " └─ Mul [id KP]\n", + " ├─ ExpandDims{axes=[0, 1]} [id KQ]\n", + " │ └─ 0.5 [id KR]\n", + " └─ Add [id KS]\n", + " ├─ dot [id KT]\n", + " │ ├─ dot [id KU]\n", + " │ │ ├─ *7- [id KV] -> [id BX]\n", + " │ │ └─ *9- [id KW] -> [id BZ]\n", + " │ └─ Transpose{axes=[1, 0]} [id KX]\n", + " │ └─ *7- [id KV] -> [id BX]\n", + " └─ Transpose{axes=[1, 0]} [id KY]\n", + " └─ dot [id KT]\n", + " └─ ···\n", + " ← Add [id HF]\n", + " └─ ···\n", + " ← Add [id IT]\n", + " └─ ···\n", + " ← Add [id JE]\n", + " └─ ···\n", + " ← Add [id HL]\n", + " └─ ···\n", + " ← Switch [id KZ]\n", + " ├─ Cast{float64} [id LA]\n", + " │ └─ All{axes=None} [id LB]\n", + " │ └─ Or [id HR]\n", + " │ └─ ···\n", + " ├─ 0.0 [id LC]\n", + " └─ Mul [id LD]\n", + " ├─ -0.5 [id LE]\n", + " └─ Subtensor{i} [id LF]\n", + " ├─ Reshape{1} [id LG]\n", + " │ ├─ Add [id LH]\n", + " │ │ ├─ Add [id LI]\n", + " │ │ │ ├─ Log [id LJ]\n", + " │ │ │ │ └─ Mul [id LK]\n", + " │ │ │ │ ├─ 2 [id LL]\n", + " │ │ │ │ └─ 3.141592653589793 [id LM]\n", + " │ │ │ └─ Log [id LN]\n", + " │ │ │ └─ Blockwise{Det, (m,m)->()} [id LO]\n", + " │ │ │ └─ Add [id HL]\n", + " │ │ │ └─ ···\n", + " │ │ └─ dot [id LP]\n", + " │ │ ├─ Sub [id IQ]\n", + " │ │ │ └─ ···\n", + " │ │ └─ Blockwise{Solve{assume_a='pos', lower=False, check_finite=False, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id LQ]\n", + " │ │ ├─ Add [id HL]\n", + " │ │ │ └─ ···\n", + " │ │ └─ Sub [id IQ]\n", + " │ │ └─ ···\n", + " │ └─ [-1] [id LR]\n", + " └─ 0 [id LS]\n", + "\n", + "AllocDiag{self.axis1=0, self.axis2=1, self.offset=0} [id HO]\n", + " ← AdvancedSetSubtensor [id LT]\n", + " ├─ Alloc [id LU]\n", + " │ ├─ 0.0 [id LV]\n", + " │ ├─ Add [id LW]\n", + " │ │ ├─ Subtensor{i} [id LX]\n", + " │ │ │ ├─ Shape [id LY]\n", + " │ │ │ │ └─ *0- [id HT]\n", + " │ │ │ └─ -1 [id LZ]\n", + " │ │ └─ 0 [id MA]\n", + " │ └─ Add [id LW]\n", + " │ └─ ···\n", + " ├─ *0- [id HT]\n", + " ├─ Add [id MB]\n", + " │ ├─ ARange{dtype='int64'} [id MC]\n", + " │ │ ├─ 0 [id MD]\n", + " │ │ ├─ Subtensor{i} [id ME]\n", + " │ │ │ ├─ Shape [id MF]\n", + " │ │ │ │ └─ *0- [id HT]\n", + " │ │ │ └─ -1 [id MG]\n", + " │ │ └─ 1 [id MH]\n", + " │ └─ ExpandDims{axis=0} [id MI]\n", + " │ └─ 0 [id MJ]\n", + " └─ Add [id MK]\n", + " ├─ ARange{dtype='int64'} [id MC]\n", + " │ └─ ···\n", + " └─ ExpandDims{axis=0} [id ML]\n", + " └─ 0 [id MM]\n" + ] + }, + { + "data": { + "text/plain": [ + "(3, 10, 5)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pt.vectorize(build_fk, signature=signature)(*[pt.as_tensor(x) for x in np_batch_inputs])[\n", + " 0\n", + "].eval().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a8c1a39f", + "metadata": {}, + "outputs": [], + "source": [ + "def make_signature(inputs, outputs):\n", + " states = \"s\"\n", + " obs = \"p\"\n", + " exog = \"r\"\n", + " time = \"t\"\n", + "\n", + " matrix_to_shape = {\n", + " \"data\": (time, obs),\n", + " \"a0\": (states,),\n", + " \"P0\": (states, states),\n", + " \"c\": (states,),\n", + " \"d\": (obs,),\n", + " \"T\": (states, states),\n", + " \"Z\": (obs, states),\n", + " \"R\": (states, exog),\n", + " \"H\": (obs, obs),\n", + " \"Q\": (exog, exog),\n", + " \"filtered_states\": (time, states),\n", + " \"filtered_covariances\": (time, states, states),\n", + " \"predicted_states\": (time, states),\n", + " \"predicted_covariances\": (time, states, states),\n", + " \"observed_states\": (time, obs),\n", + " \"observed_covariances\": (time, obs, obs),\n", + " \"smoothed_states\": (time, states),\n", + " \"smoothed_covariances\": (time, states, states),\n", + " \"loglike_obs\": (time,),\n", + " }\n", + " input_shapes = []\n", + " output_shapes = []\n", + "\n", + " for matrix in inputs:\n", + " name = matrix.name\n", + " input_shapes.append(matrix_to_shape[name])\n", + "\n", + " for matrix in outputs:\n", + " print(matrix, matrix.name)\n", + " name = matrix.name\n", + " output_shapes.append(matrix_to_shape[name])\n", + "\n", + " input_signature = \",\".join([\"(\" + \",\".join(shapes) + \")\" for shapes in input_shapes])\n", + " output_signature = \",\".join([\"(\" + \",\".join(shapes) + \")\" for shapes in output_shapes])\n", + "\n", + " return f\"{input_signature} -> {output_signature}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "828ce742", + "metadata": {}, + "outputs": [], + "source": [ + "floatX = \"float64\"\n", + "data = pt.tensor(name=\"data\", dtype=floatX, shape=(None, None))\n", + "a0 = pt.vector(name=\"a0\", dtype=floatX)\n", + "P0 = pt.matrix(name=\"P0\", dtype=floatX)\n", + "c = pt.vector(name=\"c\", dtype=floatX)\n", + "d = pt.vector(name=\"d\", dtype=floatX)\n", + "Q = pt.tensor(name=\"Q\", dtype=floatX, shape=(None, None, None))\n", + "H = pt.tensor(name=\"H\", dtype=floatX, shape=(None, None, None))\n", + "T = pt.tensor(name=\"T\", dtype=floatX, shape=(None, None, None))\n", + "R = pt.tensor(name=\"R\", dtype=floatX, shape=(None, None, None))\n", + "Z = pt.tensor(name=\"Z\", dtype=floatX, shape=(None, None, None))\n", + "\n", + "inputs = [data, a0, P0, c, d, T, Z, R, H, Q]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "52e105e3", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = build_fk(*inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "9b5d94ab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "filtered_states filtered_states\n", + "predicted_states predicted_states\n", + "observed_states observed_states\n", + "filtered_covariances filtered_covariances\n", + "predicted_covariances predicted_covariances\n", + "observed_covariances observed_covariances\n", + "loglike_obs loglike_obs\n", + "smoothed_states smoothed_states\n", + "smoothed_covariances smoothed_covariances\n" + ] + }, + { + "data": { + "text/plain": [ + "'(t,p),(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r) -> (t,s),(t,s),(t,p),(t,s,s),(t,s,s),(t,p,p),(t),(t,s),(t,s,s)'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "make_signature(inputs, outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce8e287b", + "metadata": {}, + "outputs": [], + "source": [ + "signature = \"(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)\"" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8a2632bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "filtered_states filtered_states\n", + "predicted_states predicted_states\n", + "observed_states observed_states\n", + "filtered_covariances filtered_covariances\n", + "predicted_covariances predicted_covariances\n", + "observed_covariances observed_covariances\n", + "loglike_obs loglike_obs\n", + "smoothed_states smoothed_states\n", + "smoothed_covariances smoothed_covariances\n" + ] + }, + { + "data": { + "text/plain": [ + "(3, 10, 5)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pt.vectorize(build_fk, signature=make_signature(inputs, outputs))(\n", + " *[pt.as_tensor(x) for x in np_batch_inputs]\n", + ")[0].eval().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "bd771148", + "metadata": {}, + "outputs": [], + "source": [ + "kf = StandardFilter()\n", + "ks = KalmanSmoother()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e523bf60", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "filtered_states filtered_states\n", + "predicted_states predicted_states\n", + "observed_states observed_states\n", + "filtered_covariances filtered_covariances\n", + "predicted_covariances predicted_covariances\n", + "observed_covariances observed_covariances\n", + "loglike_obs loglike_obs\n" + ] + } + ], + "source": [ + "kf_outputs = kf.build_graph(*inputs)\n", + "kf_signature = make_signature(inputs, kf_outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "e651e683", + "metadata": {}, + "outputs": [], + "source": [ + "make_batched_kf = pt.vectorize(kf.build_graph, signature=kf_signature)\n", + "ks_inputs = [T, R, Q, kf_outputs[0], kf_outputs[3]]\n", + "ks_outputs = ks.build_graph(*ks_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "272c49f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "smoothed_states smoothed_states\n", + "smoothed_covariances smoothed_covariances\n" + ] + } + ], + "source": [ + "ks_signature = make_signature(ks_inputs, ks_outputs)\n", + "make_batched_ks = pt.vectorize(ks.build_graph, signature=ks_signature)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "0f9c8aa8", + "metadata": {}, + "outputs": [], + "source": [ + "batched_kf_outputs = make_batched_kf(*[pt.as_tensor(x) for x in np_batch_inputs])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "32c57d14", + "metadata": {}, + "outputs": [], + "source": [ + "data, a0, P0, c, d, T, Z, R, H, Q = np_batch_inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "9c7f5519", + "metadata": {}, + "outputs": [], + "source": [ + "batched_ks_outputs = make_batched_ks(\n", + " *[pt.as_tensor_variable(x) for x in [T, R, Q, batched_kf_outputs[0], batched_kf_outputs[3]]]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "875ea24a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10, 5)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batched_ks_outputs[0].eval().shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08329289", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 887130f4..a856d65d 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -1,4 +1,5 @@ from abc import ABC +from functools import partial import numpy as np import pytensor @@ -10,14 +11,13 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.slinalg import solve_triangular -from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, split_vars_into_seq_and_nonseq, stabilize, ) -from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL +from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] @@ -75,22 +75,56 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): """ return data, a0, P0, c, d, T, Z, R, H, Q - def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q): - """ - Check if any of the inputs are batched. - """ - return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q])) - - def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q): - """ - Get dummy inputs for the core parameters. - """ - out = [] - for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM): - out.append( - pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) - ) - return out + def _make_gufunc_signature(self, inputs): + states = "s" + obs = "p" + exog = "r" + time = "t" + + matrix_to_shape = { + "data": (time, obs), + "a0": (states,), + "x0": (states,), + "P0": (states, states), + "c": (states,), + "d": (obs,), + "T": (states, states), + "Z": (obs, states), + "R": (states, exog), + "H": (obs, obs), + "Q": (exog, exog), + "filtered_states": (time, states), + "filtered_covariances": (time, states, states), + "predicted_states": (time, states), + "predicted_covariances": (time, states, states), + "observed_states": (time, obs), + "observed_covariances": (time, obs, obs), + "smoothed_states": (time, states), + "smoothed_covariances": (time, states, states), + "loglike_obs": (time,), + } + input_shapes = [] + output_shapes = [] + + for matrix in inputs: + name = matrix.name + input_shapes.append(matrix_to_shape[name]) + + for name in [ + "filtered_states", + "predicted_states", + "smoothed_states", + "filtered_covariances", + "predicted_covariances", + "smoothed_covariances", + "loglike_obs", + ]: + output_shapes.append(matrix_to_shape[name]) + + input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes]) + output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes]) + + return f"{input_signature} -> {output_signature}" @staticmethod def add_check_on_time_varying_shapes( @@ -160,7 +194,7 @@ def unpack_args(self, args) -> tuple: return y, a0, P0, c, d, T, Z, R, H, Q - def build_graph( + def _build_graph( self, data, a0, @@ -221,7 +255,6 @@ def build_graph( self.mode = mode self.missing_fill_value = missing_fill_value self.cov_jitter = cov_jitter - is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q) [R_shape] = constant_fold([R.shape], raise_not_constant=False) [Z_shape] = constant_fold([Z.shape], raise_not_constant=False) @@ -229,10 +262,6 @@ def build_graph( self.n_states, self.n_shocks = R_shape[-2:] self.n_endog = Z_shape[-2] - if is_batched: - batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q] - data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs) - data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( @@ -257,15 +286,47 @@ def build_graph( filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0]) - if is_batched: - vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs)) - filter_results = vectorize_graph(filter_results, vec_subs) - if return_updates: return filter_results, updates return filter_results + def build_graph( + self, + data, + a0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + mode=None, + return_updates=False, + missing_fill_value=None, + cov_jitter=None, + ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]: + """ + Build the vectorized computation graph for the Kalman filter. + """ + signature = self._make_gufunc_signature( + [data, a0, P0, c, d, T, Z, R, H, Q], + ) + fn = partial( + self._build_graph, + mode=mode, + return_updates=return_updates, + missing_fill_value=missing_fill_value, + cov_jitter=cov_jitter, + ) + filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q) + for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES): + output.name = name + + return filter_outputs + def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: """ Transform the values returned by the Kalman Filter scan into a form expected by users. In particular: diff --git a/pymc_extras/statespace/filters/kalman_smoother.py b/pymc_extras/statespace/filters/kalman_smoother.py index 671d9366..3d656b39 100644 --- a/pymc_extras/statespace/filters/kalman_smoother.py +++ b/pymc_extras/statespace/filters/kalman_smoother.py @@ -1,9 +1,8 @@ import pytensor import pytensor.tensor as pt - +from functools import partial from pytensor.compile import get_mode from pytensor.tensor.nlinalg import matrix_dot -from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, split_vars_into_seq_and_nonseq, @@ -65,41 +64,58 @@ def unpack_args(self, args): return a, P, a_smooth, P_smooth, T, R, Q - def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances): - """ - Check if any of the inputs are batched. - """ - return any( - x.ndim > SMOOTHER_CORE_NDIM[i] - for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances]) - ) - - def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances): - """ - Get dummy inputs for the core parameters. - """ - out = [] - for x, core_ndim in zip( - [T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM - ): - out.append( - pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) - ) - return out - - def build_graph( + def _make_gufunc_signature(self, inputs): + states = "s" + obs = "p" + exog = "r" + time = "t" + + matrix_to_shape = { + "data": (time, obs), + "a0": (states,), + "x0": (states,), + "P0": (states, states), + "c": (states,), + "d": (obs,), + "T": (states, states), + "Z": (obs, states), + "R": (states, exog), + "H": (obs, obs), + "Q": (exog, exog), + "filtered_states": (time, states), + "filtered_covariances": (time, states, states), + "predicted_states": (time, states), + "predicted_covariances": (time, states, states), + "observed_states": (time, obs), + "observed_covariances": (time, obs, obs), + "smoothed_states": (time, states), + "smoothed_covariances": (time, states, states), + "loglike_obs": (time,), + } + input_shapes = [] + output_shapes = [] + + for matrix in inputs: + name = matrix.name + input_shapes.append(matrix_to_shape[name]) + + for name in [ + "smoothed_states", + "smoothed_covariances", + ]: + output_shapes.append(matrix_to_shape[name]) + + input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes]) + output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes]) + + return f"{input_signature} -> {output_signature}" + + def _build_graph( self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT ): self.mode = mode self.cov_jitter = cov_jitter - is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances) - if is_batched: - batched_inputs = [T, R, Q, filtered_states, filtered_covariances] - T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs( - *batched_inputs - ) - n, k = filtered_states.type.shape a_last = pt.specify_shape(filtered_states[-1], (k,)) @@ -129,18 +145,28 @@ def build_graph( smoothed_covariances = pt.concatenate( [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0 ) - smoothed_states.dprint() - if is_batched: - vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs)) - smoothed_states, smoothed_covariances = vectorize_graph( - [smoothed_states, smoothed_covariances], vec_subs - ) smoothed_states.name = "smoothed_states" smoothed_covariances.name = "smoothed_covariances" return smoothed_states, smoothed_covariances + def build_graph( + self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT + ): + """ + Build the vectorized computation graph for the Kalman smoother. + """ + signature = self._make_gufunc_signature( + [T, R, Q, filtered_states, filtered_covariances], + ) + fn = partial( + self._build_graph, + mode=mode, + cov_jitter=cov_jitter, + ) + return pt.vectorize(fn, signature=signature)(T, R, Q, filtered_states, filtered_covariances) + def smoother_step(self, *args): a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args) a_hat, P_hat = self.predict(a, P, T, R, Q) diff --git a/pymc_extras/statespace/filters/utilities.py b/pymc_extras/statespace/filters/utilities.py index d61537b6..ef254df5 100644 --- a/pymc_extras/statespace/filters/utilities.py +++ b/pymc_extras/statespace/filters/utilities.py @@ -2,7 +2,14 @@ from pytensor.tensor.nlinalg import matrix_dot -from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED +from pymc_extras.statespace.utils.constants import ( + JITTER_DEFAULT, + NEVER_TIME_VARYING, + VECTOR_VALUED, +) + +CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2) +SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3) def decide_if_x_time_varies(x, name): @@ -57,3 +64,40 @@ def stabilize(cov, jitter=JITTER_DEFAULT): def quad_form_sym(A, B): out = matrix_dot(A, B, A.T) return 0.5 * (out + out.T) + + +def has_batched_input_smoother(T, R, Q, filtered_states, filtered_covariances): + """ + Check if any of the inputs are batched. + """ + return any( + x.ndim > SMOOTHER_CORE_NDIM[i] + for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances]) + ) + + +def get_dummy_core_inputs_smoother(T, R, Q, filtered_states, filtered_covariances): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip([T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM): + out.append(pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])) + return out + + +def has_batched_input_filter(data, a0, P0, c, d, T, Z, R, H, Q): + """ + Check if any of the inputs are batched. + """ + return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q])) + + +def get_dummy_core_inputs_filter(data, a0, P0, c, d, T, Z, R, H, Q): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM): + out.append(pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])) + return out diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index c4064858..20f92d15 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -47,6 +47,16 @@ SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"] OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"] +ALL_KF_OUTPUT_NAMES = [ + "filtered_states", + "predicted_states", + "observed_states", + "filtered_covariances", + "predicted_covariances", + "observed_covariances", + "loglike_obs", +] + MATRIX_DIMS = { "x0": (ALL_STATE_DIM,), "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM), diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 3cdfa569..c0427814 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -31,20 +31,20 @@ RTOL = 1e-6 if floatX.endswith("64") else 1e-3 standard_inout = initialize_filter(StandardFilter()) -standard_inout_batched = initialize_filter(StandardFilter(), batched=True) +# standard_inout_batched = initialize_filter(StandardFilter(), batched=True) cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") -f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore") +# f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore") f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_standard_batched] # , f_cholesky, f_univariate] +filter_funcs = [f_standard] # , f_cholesky, f_univariate] filter_names = [ "StandardFilter", - "StandardFilterBatched", + # "StandardFilterBatched", # "CholeskyFilter", # "UnivariateFilter", ] @@ -75,14 +75,13 @@ def test_output_shapes_one_state_one_observed(filter_func, filter_name, rng): batch_size = 3 if "batched" in filter_name.lower() else 0 p, m, r, n = 1, 1, 1, 10 inputs = make_test_inputs(p, m, r, n, rng, batch_size=batch_size) - assert 0 - # outputs = filter_func(*inputs) + outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_shape = get_expected_shape(name, p, m, r, n, batch_size) - # assert outputs[output_idx].shape == expected_shape, ( - # f"Shape of {name} does not match expected" - # ) + assert outputs[output_idx].shape == expected_shape, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -331,3 +330,34 @@ def test_kalman_filter_jax(filter): for name, jax_res, pt_res in zip(output_names, jax_outputs, pt_outputs): assert_allclose(jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!") + + +def test_batched_standard_filter(): + p, m, r, n = 1, 5, 1, 10 + input_names = ["data", "x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] + inputs = [ + pt.as_tensor(x, name=name) + for x, name in zip(make_test_inputs(p, m, r, n, rng, batch_size=8), input_names) + ] + kf = StandardFilter() + outputs = kf.build_graph(*inputs) + np.testing.assert_equal(outputs[0].shape.eval(), (8, n, m)) + + +def test_batched_kalman_smoother(): + p, m, r, n = 1, 5, 1, 10 + filter_input_names = ["data", "x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] + smoother_input_names = ["T", "R", "Q", "filtered_states", "filtered_covs"] + + kf_inputs = data, x0, P0, c, d, T, Z, R, H, Q = [ + pt.as_tensor(x, name=name) + for x, name in zip(make_test_inputs(p, m, r, n, rng, batch_size=8), filter_input_names) + ] + kf = StandardFilter() + kf_outputs = kf.build_graph(*kf_inputs) + + ks = KalmanSmoother() + ks_inputs = T, R, Q, kf_outputs[0], kf_outputs[3] + ks_outputs = ks.build_graph(*ks_inputs) + + np.testing.assert_equal(ks_outputs[0].shape.eval(), (8, n, m))