From 1389a8be70027db60da3d69b56d9030d4b40e0bb Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 4 Dec 2024 19:59:15 +0800
Subject: [PATCH 1/7] initial commit

---
 pymc_experimental/gp/pytensor_gp.py | 92 +++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 pymc_experimental/gp/pytensor_gp.py

diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py
new file mode 100644
index 00000000..c5c26052
--- /dev/null
+++ b/pymc_experimental/gp/pytensor_gp.py
@@ -0,0 +1,92 @@
+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+
+from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs
+from pytensor.graph.op import Apply, Op
+
+
+class Cov(Op):
+    __props__ = ("fn",)
+
+    def __init__(self, fn):
+        self.fn = fn
+
+    def make_node(self, ls):
+        ls = pt.as_tensor(ls)
+        out = pt.matrix(shape=(None, None))
+
+        return Apply(self, [ls], [out])
+
+    def __call__(self, ls=1.0):
+        return super().__call__(ls)
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You should convert Cov into a TensorVariable expression!")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
+class GP(Op):
+    __props__ = ("approx",)
+
+    def __init__(self, approx):
+        self.approx = approx
+
+    def make_node(self, mean, cov):
+        mean = pt.as_tensor(mean)
+        cov = pt.as_tensor(cov)
+
+        if not (cov.owner and isinstance(cov.owner.op, Cov)):
+            raise ValueError("Second argument should be a Cov output.")
+
+        out = pt.vector(shape=(None,))
+
+        return Apply(self, [mean, cov], [out])
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
+class PriorFromGP(Op):
+    """This Op will be replaced by the right MvNormal."""
+
+    def make_node(self, gp, x, rng):
+        gp = pt.as_tensor(gp)
+        if not (gp.owner and isinstance(gp.owner.op, GP)):
+            raise ValueError("First argument should be a GP output.")
+
+        # TODO: Assert RNG has the right type
+        x = pt.as_tensor(x)
+        out = x.type()
+
+        return Apply(self, [gp, x, rng], [out])
+
+    def __call__(self, gp, x, rng=None):
+        if rng is None:
+            rng = pytensor.shared(np.random.default_rng())
+        return super().__call__(gp, x, rng)
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
+cov_op = Cov(fn=pm.gp.cov.ExpQuad)
+gp_op = GP("vanilla")
+# SymbolicRandomVariable.register(type(gp_op))
+prior_from_gp = PriorFromGP()
+
+MeasurableVariable.register(type(prior_from_gp))
+
+
+@_get_measurable_outputs.register(type(prior_from_gp))
+def gp_measurable_outputs(op, node):
+    return node.outputs

From 3feaf2e48c572a23f97bdd31fb740b416759ecad Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 4 Dec 2024 23:16:40 +0800
Subject: [PATCH 2/7] Rough implementation of new API

---
 pymc_experimental/gp/pytensor_gp.py | 206 +++++++++++++++++++---------
 tests/test_gp.py                    |  60 ++++++++
 2 files changed, 201 insertions(+), 65 deletions(-)
 create mode 100644 tests/test_gp.py

diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py
index c5c26052..269e01f7 100644
--- a/pymc_experimental/gp/pytensor_gp.py
+++ b/pymc_experimental/gp/pytensor_gp.py
@@ -1,92 +1,168 @@
-import numpy as np
 import pymc as pm
-import pytensor
 import pytensor.tensor as pt
-from
pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs -from pytensor.graph.op import Apply, Op +from numpy.core.numeric import normalize_axis_tuple +from pymc.distributions.distribution import Continuous +from pytensor.compile.builders import OpFromGraph +from pytensor.tensor.einsum import _delta +# from pymc.logprob.abstract import MeasurableOp -class Cov(Op): - __props__ = ("fn",) - def __init__(self, fn): - self.fn = fn +class GPCovariance(OpFromGraph): + """OFG representing a GP covariance""" - def make_node(self, ls): - ls = pt.as_tensor(ls) - out = pt.matrix(shape=(None, None)) - - return Apply(self, [ls], [out]) - - def __call__(self, ls=1.0): - return super().__call__(ls) - - def perform(self, node, inputs, output_storage): - raise NotImplementedError("You should convert Cov into a TensorVariable expression!") - - def do_constant_folding(self, fgraph, node): - return False + @staticmethod + def square_dist(X, ls): + X = X / ls + X2 = pt.sum(pt.square(X), axis=-1) + sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :]) + return sqd -class GP(Op): - __props__ = ("approx",) - def __init__(self, approx): - self.approx = approx +class ExpQuadCov(GPCovariance): + """ + ExpQuad covariance function + """ - def make_node(self, mean, cov): - mean = pt.as_tensor(mean) - cov = pt.as_tensor(cov) - - if not (cov.owner and isinstance(cov.owner.op, Cov)): - raise ValueError("Second argument should be a Cov output.") - - out = pt.vector(shape=(None,)) + @classmethod + def exp_quad_full(cls, X, ls): + return pt.exp(-0.5 * cls.square_dist(X, ls)) - return Apply(self, [mean, cov], [out]) + @classmethod + def build_covariance(cls, X, ls): + X = pt.as_tensor(X) + ls = pt.as_tensor(ls) - def perform(self, node, inputs, output_storage): - raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.") + ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)]) + return ofg(X, ls) - def do_constant_folding(self, fgraph, node): - return False +def ExpQuad(X, ls): + return ExpQuadCov.build_covariance(X, ls) -class PriorFromGP(Op): - """This Op will be replaced by the right MvNormal.""" - def make_node(self, gp, x, rng): - gp = pt.as_tensor(gp) - if not (gp.owner and isinstance(gp.owner.op, GP)): - raise ValueError("First argument should be a GP output.") +class WhiteNoiseCov(GPCovariance): + @classmethod + def white_noise_full(cls, X, sigma): + X_shape = tuple(X.shape) + shape = X_shape[:-1] + (X_shape[-2],) - # TODO: Assert RNG has the right type - x = pt.as_tensor(x) - out = x.type() + return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2 - return Apply(self, [gp, x, rng], [out]) + @classmethod + def build_covariance(cls, X, sigma): + X = pt.as_tensor(X) + sigma = pt.as_tensor(sigma) - def __call__(self, gp, x, rng=None): - if rng is None: - rng = pytensor.shared(np.random.default_rng()) - return super().__call__(gp, x, rng) + ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)]) + return ofg(X, sigma) - def perform(self, node, inputs, output_storage): - raise NotImplementedError("You should convert PriorFromGP into a MvNormal!") - def do_constant_folding(self, fgraph, node): - return False +def WhiteNoise(X, sigma): + return WhiteNoiseCov.build_covariance(X, sigma) -cov_op = Cov(fn=pm.gp.cov.ExpQuad) -gp_op = GP("vanilla") -# SymbolicRandomVariable.register(type(gp_op)) -prior_from_gp = PriorFromGP() +class GP_RV(pm.MvNormal.rv_type): + name = "gaussian_process" + signature = "(n),(n,n)->(n)" + dtype = "floatX" + 
_print_name = ("GP", "\\operatorname{GP}") -MeasurableVariable.register(type(prior_from_gp)) +class GP(Continuous): + rv_type = GP_RV + rv_op = GP_RV() -@_get_measurable_outputs.register(type(prior_from_gp)) -def gp_measurable_outputs(op, node): - return node.outputs + @classmethod + def dist(cls, cov, **kwargs): + cov = pt.as_tensor(cov) + mu = pt.zeros(cov.shape[-1]) + return super().dist([mu, cov], **kwargs) + + +# @register_canonicalize +# @node_rewriter(tracks=[pm.MvNormal.rv_type]) +# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node): +# # TODO: Should this alert users that it can't be applied when the GP is in a deterministic? +# gp_rng, gp_size, mu, cov = node.inputs +# next_gp_rng, gp_rv = node.outputs +# +# if not isinstance(cov.owner.op, GPCovariance): +# return +# +# for client, input_index in fgraph.clients[gp_rv]: +# # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu +# # to be the GP we're looking +# if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2): +# next_normal_rng, normal_rv = client.outputs +# normal_rng, normal_size, mu, sigma = client.inputs +# +# if normal_rv.ndim != gp_rv.ndim: +# return +# +# X = cov.owner.inputs[0] +# +# white_noise = WhiteNoiseCov.build_covariance(X, sigma) +# white_noise.name = 'WhiteNoiseCov' +# cov = cov + white_noise +# +# if not rv_size_is_none(normal_size): +# normal_size = tuple(normal_size) +# new_gp_size = normal_size[:-1] +# core_shape = normal_size[-1] +# +# cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape) +# cov = pt.specify_shape(cov, cov_shape) +# +# else: +# new_gp_size = None +# +# next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs +# new_gp_mvn.name = 'NewGPMvn' +# +# # Check that the new shape is at least as specific as the shape we are replacing +# for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True): +# if new_shape is None: +# assert old_shape is None +# +# return { +# next_normal_rng: next_new_gp_rng, +# normal_rv: new_gp_mvn, +# next_gp_rng: next_new_gp_rng +# } +# +# else: +# return None +# +# #TODO: Why do I need to register this twice? +# specialization_ir_rewrites_db.register( +# GP_normal_mvnormal_conjugacy.__name__, +# GP_normal_mvnormal_conjugacy, +# "basic", +# ) + +# @node_rewriter(tracks=[pm.MvNormal.rv_type]) +# def GP_normal_marginal_logp(fgraph: FunctionGraph, node): +# """ +# Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)). 
+# """ +# rng, size, mu, cov = node.inputs +# if cov.owner and cov.owner.op == matrix_inverse: +# tau = cov.owner.inputs[0] +# return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs +# return None +# + +# cov_op = GPCovariance() +# gp_op = GP("vanilla") +# # SymbolicRandomVariable.register(type(gp_op)) +# prior_from_gp = PriorFromGP() +# +# MeasurableVariable.register(type(prior_from_gp)) +# +# +# @_get_measurable_outputs.register(type(prior_from_gp)) +# def gp_measurable_outputs(op, node): +# return node.outputs diff --git a/tests/test_gp.py b/tests/test_gp.py new file mode 100644 index 00000000..e0b5c580 --- /dev/null +++ b/tests/test_gp.py @@ -0,0 +1,60 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest + +from pymc_experimental.gp.pytensor_gp import GP, ExpQuad + + +def test_exp_quad(): + x = pt.arange(3)[:, None] + ls = pt.ones(()) + cov = ExpQuad.build_covariance(x, ls).eval() + expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) + + np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance)) + + +@pytest.fixture(scope="session") +def marginal_model(): + with pm.Model() as m: + X = pm.Data("X", np.arange(3)[:, None]) + y = np.full(3, np.pi) + ls = 1.0 + cov = ExpQuad(X, ls) + gp = GP("gp", cov=cov) + + sigma = 1.0 + obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y) + + return m + + +def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model): + obs = marginal_model["obs"] + + # TODO: Bring these checks back after we implement marginalization of the GP RV + # + # assert sum(isinstance(var.owner.op, pm.Normal.rv_type) + # for var in ancestors([obs]) + # if var.owner is not None) == 1 + # + f = pm.compile_pymc([], obs) + # + # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes) + + draws = np.stack([f() for _ in range(10_000)]) + empirical_cov = np.cov(draws.T) + + expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) + + np.testing.assert_allclose( + empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1 + ) + + +def test_marginal_gp_logp(marginal_model): + expected_logps = {"obs": -8.8778} + point_logps = marginal_model.point_logps(round_vals=4) + for v1, v2 in zip(point_logps.values(), expected_logps.values()): + np.testing.assert_allclose(v1, v2, atol=1e-6) From f58327af0c5a55455a8b9edc32aa2c533af72dcc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Jan 2025 12:16:08 +0100 Subject: [PATCH 3/7] .refactor --- pymc_experimental/gp/pytensor_gp.py | 40 +++++-- tests/test_gp.py | 165 +++++++++++++++++++++++----- 2 files changed, 173 insertions(+), 32 deletions(-) diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py index 269e01f7..57f5c3d3 100644 --- a/pymc_experimental/gp/pytensor_gp.py +++ b/pymc_experimental/gp/pytensor_gp.py @@ -12,6 +12,24 @@ class GPCovariance(OpFromGraph): """OFG representing a GP covariance""" + @staticmethod + def square_dist_Xs(X, Xs, ls): + assert X.ndim == 2, "Complain to Bill about it" + assert Xs.ndim == 2, "Complain to Bill about it" + + X = X / ls + Xs = Xs / ls + + X2 = pt.sum(pt.square(X), axis=-1) + Xs2 = pt.sum(pt.square(Xs), axis=-1) + + sqd = -2.0 * X @ X.mT + (X2[..., :, None] + Xs2[..., None, :]) + # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( + # pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) + # ) + + return pt.clip(sqd, 0, pt.inf) + @staticmethod def square_dist(X, ls): X = X / ls @@ -27,20 +45,27 
@@ class ExpQuadCov(GPCovariance): """ @classmethod - def exp_quad_full(cls, X, ls): - return pt.exp(-0.5 * cls.square_dist(X, ls)) + def exp_quad_full(cls, X, Xs, ls): + return pt.exp(-0.5 * cls.square_dist_Xs(X, Xs, ls)) @classmethod - def build_covariance(cls, X, ls): + def build_covariance(cls, X, Xs=None, *, ls): X = pt.as_tensor(X) + if Xs is None: + Xs = X + else: + Xs = pt.as_tensor(Xs) ls = pt.as_tensor(ls) - ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)]) - return ofg(X, ls) + out = cls.exp_quad_full(X, Xs, ls) + if Xs is X: + return cls(inputs=[X, ls], outputs=[out])(X, ls) + else: + return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls) -def ExpQuad(X, ls): - return ExpQuadCov.build_covariance(X, ls) +def ExpQuad(X, X_new=None, *, ls): + return ExpQuadCov.build_covariance(X, X_new, ls=ls) class WhiteNoiseCov(GPCovariance): @@ -77,6 +102,7 @@ class GP(Continuous): @classmethod def dist(cls, cov, **kwargs): + # return Assert(msg="Don't know what a GP_RV is")(False) cov = pt.as_tensor(cov) mu = pt.zeros(cov.shape[-1]) return super().dist([mu, cov], **kwargs) diff --git a/tests/test_gp.py b/tests/test_gp.py index e0b5c580..9fdb699f 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -1,7 +1,6 @@ import numpy as np import pymc as pm import pytensor.tensor as pt -import pytest from pymc_experimental.gp.pytensor_gp import GP, ExpQuad @@ -9,19 +8,19 @@ def test_exp_quad(): x = pt.arange(3)[:, None] ls = pt.ones(()) - cov = ExpQuad.build_covariance(x, ls).eval() + cov = ExpQuad(x, ls=ls).eval() expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance)) -@pytest.fixture(scope="session") -def marginal_model(): +# @pytest.fixture(scope="session") +def latent_model(): with pm.Model() as m: X = pm.Data("X", np.arange(3)[:, None]) y = np.full(3, np.pi) ls = 1.0 - cov = ExpQuad(X, ls) + cov = ExpQuad(X, ls=ls) gp = GP("gp", cov=cov) sigma = 1.0 @@ -30,31 +29,147 @@ def marginal_model(): return m -def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model): - obs = marginal_model["obs"] +def latent_model_old_API(): + with pm.Model() as m: + X = pm.Data("X", np.arange(3)[:, None]) + y = np.full(3, np.pi) + ls = 1.0 + cov = pm.gp.cov.ExpQuad(1, ls) + gp_class = pm.gp.Latent(cov_func=cov) + gp = gp_class.prior("gp", X, reparameterize=False) + + sigma = 1.0 + obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y) - # TODO: Bring these checks back after we implement marginalization of the GP RV - # - # assert sum(isinstance(var.owner.op, pm.Normal.rv_type) - # for var in ancestors([obs]) - # if var.owner is not None) == 1 - # - f = pm.compile_pymc([], obs) - # - # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes) + return m, gp_class - draws = np.stack([f() for _ in range(10_000)]) - empirical_cov = np.cov(draws.T) - expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) +def test_latent_model_prior(): + m = latent_model() + ref_m, _ = latent_model_old_API() + + prior = pm.draw(m["gp"], draws=1000) + prior_ref = pm.draw(ref_m["gp"], draws=1000) + + np.testing.assert_allclose( + prior.mean(), + prior_ref.mean(), + atol=0.1, + ) + + np.testing.assert_allclose( + prior.std(), + prior_ref.std(), + rtol=0.1, + ) + + +def test_latent_model_logp(): + m = latent_model() + ip = m.initial_point() + + ref_m, _ = latent_model_old_API() + + np.testing.assert_allclose( + m.compile_logp()(ip), + ref_m.compile_logp()(ip), + 
rtol=1e-6, + ) + + +import arviz as az + + +def gp_conditional(model, gp, Xnew, jitter=1e-6): + def _build_conditional(self, Xnew, f, cov, jitter): + X, ls = cov.owner.inputs + + Kxx = cov + Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls) + Kss = cov.owner.op.build_covariance(Xnew, ls=ls) + + L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter) + # TODO: Use cho_solve + A = pt.linalg.solve_triangular(L, Kxs, lower=True) + v = pt.linalg.solve_triangular(L, f, lower=True) + + mu = (A.mT @ v).T # Vector? + cov = Kss - (A.mT @ A) + + return mu, cov + + with model.copy() as new_m: + gp = new_m[gp.name] + _, cov = gp.owner.op.dist_params(gp.owner) + mu_star, cov_star = _build_conditional(None, Xnew, gp, cov, jitter) + gp_star = pm.MvNormal("gp_star", mu_star, cov_star) + return new_m + + +def test_latent_model_predict_new_x(): + rng = np.random.default_rng(0) + new_x = np.array([3, 4])[:, None] + + m = latent_model() + ref_m, ref_gp_class = latent_model_old_API() + + posterior_idata = az.from_dict({"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 2))}) + + # with gp_extend_to_new_x(m): + with gp_conditional(m, m["gp"], new_x): + pred = ( + pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"]) + .posterior_predictiev["gp"] + .values + ) + + with ref_m: + gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x) + pred_ref = ( + pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"]) + .posterior_predictive["gp"] + .values + ) + + np.testing.assert_allclose( + pred.mean(), + pred_ref.mean(), + atol=0.1, + ) np.testing.assert_allclose( - empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1 + pred.std(), + pred_ref.std(), + rtol=0.1, ) -def test_marginal_gp_logp(marginal_model): - expected_logps = {"obs": -8.8778} - point_logps = marginal_model.point_logps(round_vals=4) - for v1, v2 in zip(point_logps.values(), expected_logps.values()): - np.testing.assert_allclose(v1, v2, atol=1e-6) +# +# def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ): +# obs = marginal_model["obs"] +# +# # TODO: Bring these checks back after we implement marginalization of the GP RV +# # +# # assert sum(isinstance(var.owner.op, pm.Normal.rv_type) +# # for var in ancestors([obs]) +# # if var.owner is not None) == 1 +# # +# f = pm.compile_pymc([], obs) +# # +# # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes) +# +# draws = np.stack([f() for _ in range(10_000)]) +# empirical_cov = np.cov(draws.T) +# +# expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) +# +# np.testing.assert_allclose( +# empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1 +# ) +# +# +# def test_marginal_gp_logp(marginal_model): +# expected_logps = {"obs": -8.8778} +# point_logps = marginal_model.point_logps(round_vals=4) +# for v1, v2 in zip(point_logps.values(), expected_logps.values()): +# np.testing.assert_allclose(v1, v2, atol=1e-6) From e9c1a9c0e3872e1f42aba3894198589ba5f80744 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Jan 2025 13:25:45 +0100 Subject: [PATCH 4/7] Add inline to conditional transform --- pymc_experimental/gp/pytensor_gp.py | 130 +++++++++++++++++++++++----- tests/test_gp.py | 89 +++++++++---------- 2 files changed, 146 insertions(+), 73 deletions(-) diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py index 57f5c3d3..7f92a3b8 100644 --- a/pymc_experimental/gp/pytensor_gp.py +++ 
b/pymc_experimental/gp/pytensor_gp.py @@ -1,12 +1,12 @@ +from collections.abc import Sequence + import pymc as pm import pytensor.tensor as pt -from numpy.core.numeric import normalize_axis_tuple from pymc.distributions.distribution import Continuous +from pymc.model.fgraph import fgraph_from_model, model_free_rv, model_from_fgraph +from pytensor import Variable from pytensor.compile.builders import OpFromGraph -from pytensor.tensor.einsum import _delta - -# from pymc.logprob.abstract import MeasurableOp class GPCovariance(OpFromGraph): @@ -23,7 +23,7 @@ def square_dist_Xs(X, Xs, ls): X2 = pt.sum(pt.square(X), axis=-1) Xs2 = pt.sum(pt.square(Xs), axis=-1) - sqd = -2.0 * X @ X.mT + (X2[..., :, None] + Xs2[..., None, :]) + sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :]) # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( # pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) # ) @@ -68,25 +68,26 @@ def ExpQuad(X, X_new=None, *, ls): return ExpQuadCov.build_covariance(X, X_new, ls=ls) -class WhiteNoiseCov(GPCovariance): - @classmethod - def white_noise_full(cls, X, sigma): - X_shape = tuple(X.shape) - shape = X_shape[:-1] + (X_shape[-2],) - - return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2 - - @classmethod - def build_covariance(cls, X, sigma): - X = pt.as_tensor(X) - sigma = pt.as_tensor(sigma) - - ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)]) - return ofg(X, sigma) - +# class WhiteNoiseCov(GPCovariance): +# @classmethod +# def white_noise_full(cls, X, sigma): +# X_shape = tuple(X.shape) +# shape = X_shape[:-1] + (X_shape[-2],) +# +# return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2 +# +# @classmethod +# def build_covariance(cls, X, sigma): +# X = pt.as_tensor(X) +# sigma = pt.as_tensor(sigma) +# +# ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)]) +# return ofg(X, sigma) -def WhiteNoise(X, sigma): - return WhiteNoiseCov.build_covariance(X, sigma) +# +# def WhiteNoise(X, sigma): +# return WhiteNoiseCov.build_covariance(X, sigma) +# class GP_RV(pm.MvNormal.rv_type): @@ -108,6 +109,89 @@ def dist(cls, cov, **kwargs): return super().dist([mu, cov], **kwargs) +def conditional_gp( + model, + gp: Variable | str, + Xnew, + *, + jitter=1e-6, + dims: Sequence[str] = (), + inline: bool = False, +): + """ + Condition a GP on new data. + + Parameters + ---------- + model: Model + gp: Variable | str + The GP to condition on. + Xnew: Tensor-like + New data to condition the GP on. + jitter: float, default=1e-6 + Jitter to add to the new GP covariance matrix. + dims: Sequence[str], default=() + Dimensions of the new GP. + inline: bool, default=False + Whether to inline the new GP in place of the old one. This is not always a safe operation. + If True, any variables that depend on the GP will be updated to depend on the new GP. + + Returns + ------- + Conditional model: Model + A new model with a GP free RV named f"{gp.name}_star" conditioned on the new data. + + """ + + def _build_conditional(Xnew, f, cov, jitter): + if not isinstance(cov.owner.op, GPCovariance): + raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation") + X, ls = cov.owner.inputs + + Kxx = cov + Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls) + Kss = cov.owner.op.build_covariance(Xnew, ls=ls) + + L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter) + # TODO: Use cho_solve + A = pt.linalg.solve_triangular(L, Kxs, lower=True) + v = pt.linalg.solve_triangular(L, f, lower=True) + + mu = (A.mT @ v).T # Vector? 
+ cov = Kss - (A.mT @ A) + + return mu, cov + + if isinstance(gp, Variable): + assert model[gp.name] is gp + else: + gp = model[gp.name] + + fgraph, memo = fgraph_from_model(model) + gp_model_var = memo[gp] + gp_rv = gp_model_var.owner.inputs[0] + + if isinstance(gp_rv.owner.op, pm.MvNormal.rv_type): + _, cov = gp_rv.owner.op.dist_params(gp.owner) + else: + raise NotImplementedError("Can only condition on pure GPs") + + # TODO: We should write the naive conditional covariance, and then have rewrites that lift it through kernels + mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter) + gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star") + + value = gp_rv_star.clone() + transform = None + gp_model_var_star = model_free_rv(gp_rv_star, value, transform, *dims) + + if inline: + fgraph.replace(gp_model_var, gp_model_var_star, import_missing=True) + else: + fgraph.add_output(gp_model_var_star, import_missing=True) + + return model_from_fgraph(fgraph, mutate_fgraph=True) + + # @register_canonicalize # @node_rewriter(tracks=[pm.MvNormal.rv_type]) # def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node): diff --git a/tests/test_gp.py b/tests/test_gp.py index 9fdb699f..cc9900dd 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -1,8 +1,10 @@ +import arviz as az import numpy as np import pymc as pm import pytensor.tensor as pt +import pytest -from pymc_experimental.gp.pytensor_gp import GP, ExpQuad +from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp def test_exp_quad(): @@ -77,72 +79,59 @@ def test_latent_model_logp(): ) -import arviz as az - - -def gp_conditional(model, gp, Xnew, jitter=1e-6): - def _build_conditional(self, Xnew, f, cov, jitter): - X, ls = cov.owner.inputs - - Kxx = cov - Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls) - Kss = cov.owner.op.build_covariance(Xnew, ls=ls) - - L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter) - # TODO: Use cho_solve - A = pt.linalg.solve_triangular(L, Kxs, lower=True) - v = pt.linalg.solve_triangular(L, f, lower=True) - - mu = (A.mT @ v).T # Vector? 
- cov = Kss - (A.mT @ A) - - return mu, cov - - with model.copy() as new_m: - gp = new_m[gp.name] - _, cov = gp.owner.op.dist_params(gp.owner) - mu_star, cov_star = _build_conditional(None, Xnew, gp, cov, jitter) - gp_star = pm.MvNormal("gp_star", mu_star, cov_star) - return new_m - - -def test_latent_model_predict_new_x(): +@pytest.mark.parametrize("inline", (False, True)) +def test_latent_model_conditional(inline): rng = np.random.default_rng(0) + posterior = az.from_dict( + posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))}, + constant_data={"X": np.arange(3)[:, None]}, + ) + new_x = np.array([3, 4])[:, None] m = latent_model() - ref_m, ref_gp_class = latent_model_old_API() + with m: + pm.Deterministic("gp_exp", m["gp"].exp()) - posterior_idata = az.from_dict({"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 2))}) - - # with gp_extend_to_new_x(m): - with gp_conditional(m, m["gp"], new_x): - pred = ( - pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"]) - .posterior_predictiev["gp"] - .values - ) + with conditional_gp(m, m["gp"], new_x, inline=inline) as cgp: + pred = pm.sample_posterior_predictive( + posterior, + var_names=["gp_star", "gp_exp"], + progressbar=False, + ).posterior_predictive + ref_m, ref_gp_class = latent_model_old_API() with ref_m: gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x) - pred_ref = ( - pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"]) - .posterior_predictive["gp"] - .values - ) + pred_ref = pm.sample_posterior_predictive( + posterior, + var_names=["gp_star"], + progressbar=False, + ).posterior_predictive np.testing.assert_allclose( - pred.mean(), - pred_ref.mean(), + pred["gp_star"].mean(), + pred_ref["gp_star"].mean(), atol=0.1, ) np.testing.assert_allclose( - pred.std(), - pred_ref.std(), + pred["gp_star"].std(), + pred_ref["gp_star"].std(), rtol=0.1, ) + if inline: + assert np.testing.assert_allclose( + pred["gp_exp"], + np.exp(pred["gp_star"]), + ) + else: + np.testing.assert_allclose( + pred["gp_exp"], + np.exp(posterior.posterior["gp"]), + ) + # # def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ): From c7e84fbe3cf70077d85a0e420d7858a31a382281 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Jan 2025 13:28:32 +0100 Subject: [PATCH 5/7] Remove dead code --- pymc_experimental/gp/pytensor_gp.py | 127 +--------------------------- tests/test_gp.py | 35 ++++---- 2 files changed, 20 insertions(+), 142 deletions(-) diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py index 7f92a3b8..7278f970 100644 --- a/pymc_experimental/gp/pytensor_gp.py +++ b/pymc_experimental/gp/pytensor_gp.py @@ -13,7 +13,7 @@ class GPCovariance(OpFromGraph): """OFG representing a GP covariance""" @staticmethod - def square_dist_Xs(X, Xs, ls): + def square_dist(X, Xs, ls): assert X.ndim == 2, "Complain to Bill about it" assert Xs.ndim == 2, "Complain to Bill about it" @@ -24,20 +24,8 @@ def square_dist_Xs(X, Xs, ls): Xs2 = pt.sum(pt.square(Xs), axis=-1) sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :]) - # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( - # pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) - # ) - return pt.clip(sqd, 0, pt.inf) - @staticmethod - def square_dist(X, ls): - X = X / ls - X2 = pt.sum(pt.square(X), axis=-1) - sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :]) - - return sqd - class ExpQuadCov(GPCovariance): """ @@ -46,7 +34,7 @@ class ExpQuadCov(GPCovariance): @classmethod def exp_quad_full(cls, X, Xs, ls): - return 
pt.exp(-0.5 * cls.square_dist_Xs(X, Xs, ls)) + return pt.exp(-0.5 * cls.square_dist(X, Xs, ls)) @classmethod def build_covariance(cls, X, Xs=None, *, ls): @@ -64,32 +52,10 @@ def build_covariance(cls, X, Xs=None, *, ls): return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls) -def ExpQuad(X, X_new=None, *, ls): +def ExpQuad(X, X_new=None, *, ls=1.0): return ExpQuadCov.build_covariance(X, X_new, ls=ls) -# class WhiteNoiseCov(GPCovariance): -# @classmethod -# def white_noise_full(cls, X, sigma): -# X_shape = tuple(X.shape) -# shape = X_shape[:-1] + (X_shape[-2],) -# -# return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2 -# -# @classmethod -# def build_covariance(cls, X, sigma): -# X = pt.as_tensor(X) -# sigma = pt.as_tensor(sigma) -# -# ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)]) -# return ofg(X, sigma) - -# -# def WhiteNoise(X, sigma): -# return WhiteNoiseCov.build_covariance(X, sigma) -# - - class GP_RV(pm.MvNormal.rv_type): name = "gaussian_process" signature = "(n),(n,n)->(n)" @@ -103,7 +69,6 @@ class GP(Continuous): @classmethod def dist(cls, cov, **kwargs): - # return Assert(msg="Don't know what a GP_RV is")(False) cov = pt.as_tensor(cov) mu = pt.zeros(cov.shape[-1]) return super().dist([mu, cov], **kwargs) @@ -190,89 +155,3 @@ def _build_conditional(Xnew, f, cov, jitter): fgraph.add_output(gp_model_var_star, import_missing=True) return model_from_fgraph(fgraph, mutate_fgraph=True) - - -# @register_canonicalize -# @node_rewriter(tracks=[pm.MvNormal.rv_type]) -# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node): -# # TODO: Should this alert users that it can't be applied when the GP is in a deterministic? -# gp_rng, gp_size, mu, cov = node.inputs -# next_gp_rng, gp_rv = node.outputs -# -# if not isinstance(cov.owner.op, GPCovariance): -# return -# -# for client, input_index in fgraph.clients[gp_rv]: -# # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu -# # to be the GP we're looking -# if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2): -# next_normal_rng, normal_rv = client.outputs -# normal_rng, normal_size, mu, sigma = client.inputs -# -# if normal_rv.ndim != gp_rv.ndim: -# return -# -# X = cov.owner.inputs[0] -# -# white_noise = WhiteNoiseCov.build_covariance(X, sigma) -# white_noise.name = 'WhiteNoiseCov' -# cov = cov + white_noise -# -# if not rv_size_is_none(normal_size): -# normal_size = tuple(normal_size) -# new_gp_size = normal_size[:-1] -# core_shape = normal_size[-1] -# -# cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape) -# cov = pt.specify_shape(cov, cov_shape) -# -# else: -# new_gp_size = None -# -# next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs -# new_gp_mvn.name = 'NewGPMvn' -# -# # Check that the new shape is at least as specific as the shape we are replacing -# for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True): -# if new_shape is None: -# assert old_shape is None -# -# return { -# next_normal_rng: next_new_gp_rng, -# normal_rv: new_gp_mvn, -# next_gp_rng: next_new_gp_rng -# } -# -# else: -# return None -# -# #TODO: Why do I need to register this twice? 
-# specialization_ir_rewrites_db.register( -# GP_normal_mvnormal_conjugacy.__name__, -# GP_normal_mvnormal_conjugacy, -# "basic", -# ) - -# @node_rewriter(tracks=[pm.MvNormal.rv_type]) -# def GP_normal_marginal_logp(fgraph: FunctionGraph, node): -# """ -# Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)). -# """ -# rng, size, mu, cov = node.inputs -# if cov.owner and cov.owner.op == matrix_inverse: -# tau = cov.owner.inputs[0] -# return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs -# return None -# - -# cov_op = GPCovariance() -# gp_op = GP("vanilla") -# # SymbolicRandomVariable.register(type(gp_op)) -# prior_from_gp = PriorFromGP() -# -# MeasurableVariable.register(type(prior_from_gp)) -# -# -# @_get_measurable_outputs.register(type(prior_from_gp)) -# def gp_measurable_outputs(op, node): -# return node.outputs diff --git a/tests/test_gp.py b/tests/test_gp.py index cc9900dd..e461ea90 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -7,17 +7,7 @@ from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp -def test_exp_quad(): - x = pt.arange(3)[:, None] - ls = pt.ones(()) - cov = ExpQuad(x, ls=ls).eval() - expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) - - np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance)) - - -# @pytest.fixture(scope="session") -def latent_model(): +def build_latent_model(): with pm.Model() as m: X = pm.Data("X", np.arange(3)[:, None]) y = np.full(3, np.pi) @@ -31,7 +21,7 @@ def latent_model(): return m -def latent_model_old_API(): +def build_latent_model_old_API(): with pm.Model() as m: X = pm.Data("X", np.arange(3)[:, None]) y = np.full(3, np.pi) @@ -46,9 +36,18 @@ def latent_model_old_API(): return m, gp_class +def test_exp_quad(): + x = pt.arange(3)[:, None] + ls = pt.ones(()) + cov = ExpQuad(x, ls=ls).eval() + expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) + + np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance)) + + def test_latent_model_prior(): - m = latent_model() - ref_m, _ = latent_model_old_API() + m = build_latent_model() + ref_m, _ = build_latent_model_old_API() prior = pm.draw(m["gp"], draws=1000) prior_ref = pm.draw(ref_m["gp"], draws=1000) @@ -67,10 +66,10 @@ def test_latent_model_prior(): def test_latent_model_logp(): - m = latent_model() + m = build_latent_model() ip = m.initial_point() - ref_m, _ = latent_model_old_API() + ref_m, _ = build_latent_model_old_API() np.testing.assert_allclose( m.compile_logp()(ip), @@ -89,7 +88,7 @@ def test_latent_model_conditional(inline): new_x = np.array([3, 4])[:, None] - m = latent_model() + m = build_latent_model() with m: pm.Deterministic("gp_exp", m["gp"].exp()) @@ -100,7 +99,7 @@ def test_latent_model_conditional(inline): progressbar=False, ).posterior_predictive - ref_m, ref_gp_class = latent_model_old_API() + ref_m, ref_gp_class = build_latent_model_old_API() with ref_m: gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x) pred_ref = pm.sample_posterior_predictive( From 6639dffe52ddbc00ef817c0216981064354273ff Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Jan 2025 13:46:47 +0100 Subject: [PATCH 6/7] Add some comments --- pymc_experimental/gp/pytensor_gp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py index 7278f970..26db8d2f 100644 --- a/pymc_experimental/gp/pytensor_gp.py +++ b/pymc_experimental/gp/pytensor_gp.py @@ -110,11 +110,15 @@ def 
conditional_gp(
 
     def _build_conditional(Xnew, f, cov, jitter):
         if not isinstance(cov.owner.op, GPCovariance):
+            # TODO: Look for xx kernels in the ancestors of f
             raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation")
+
         X, ls = cov.owner.inputs
 
         Kxx = cov
+        # Kxs = toposort_replace(cov, tuple(zip(xx_kernels, xs_kernels)))
         Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
+        # Kss = toposort_replace(cov, tuple(zip(xx_kernels, ss_kernels)))
         Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
 
         L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
@@ -141,7 +145,6 @@ def _build_conditional(Xnew, f, cov, jitter):
     else:
         raise NotImplementedError("Can only condition on pure GPs")
 
-    # TODO: We should write the naive conditional covariance, and then have rewrites that lift it through kernels
     mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter)
     gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star")

From 5d0255ff4ce6aaadb8e4a4cf5f2e11a11eb7f6cf Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 2 Jan 2025 13:52:27 +0100
Subject: [PATCH 7/7] Add some comments

---
 pymc_experimental/gp/pytensor_gp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py
index 26db8d2f..edcb1b14 100644
--- a/pymc_experimental/gp/pytensor_gp.py
+++ b/pymc_experimental/gp/pytensor_gp.py
@@ -116,9 +116,9 @@ def _build_conditional(Xnew, f, cov, jitter):
         X, ls = cov.owner.inputs
 
         Kxx = cov
-        # Kxs = toposort_replace(cov, tuple(zip(xx_kernels, xs_kernels)))
+        # Kxs = toposort_replace(cov, tuple(zip(xx_kernels, xs_kernels)), rebuild=True)
         Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
-        # Kss = toposort_replace(cov, tuple(zip(xx_kernels, ss_kernels)))
+        # Kss = toposort_replace(cov, tuple(zip(xx_kernels, ss_kernels)), rebuild=True)
         Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
 
         L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
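
A minimal end-to-end usage sketch of the API introduced by this series (ExpQuad, GP, and conditional_gp), mirroring tests/test_gp.py; the toy inputs, fixed lengthscale, stand-in posterior draws, and new prediction points are illustrative assumptions rather than part of the patches:

import arviz as az
import numpy as np
import pymc as pm

from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp

# Latent GP prior over three assumed input points, as in build_latent_model().
with pm.Model() as model:
    X = pm.Data("X", np.arange(3)[:, None])
    cov = ExpQuad(X, ls=1.0)   # ExpQuadCov OpFromGraph
    gp = GP("gp", cov=cov)     # zero-mean MvNormal with that covariance
    pm.Normal("obs", mu=gp, sigma=1.0, observed=np.full(3, np.pi))

# Stand-in posterior draws for "gp", as used in the tests
# (a real workflow would obtain these from pm.sample).
rng = np.random.default_rng(0)
posterior = az.from_dict(
    posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))},
    constant_data={"X": np.arange(3)[:, None]},
)

# conditional_gp returns a new model with a free RV named "gp_star";
# inline=True would instead replace the original GP wherever it is used.
X_new = np.array([3, 4])[:, None]
with conditional_gp(model, model["gp"], X_new) as conditional_model:
    pred = pm.sample_posterior_predictive(posterior, var_names=["gp_star"])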