Skip to content

Commit 9743786

Browse files
authored
Add Spline Interpolation using basis functions (#52)
* add spline basis * refactor the spline utils * fix spline initialization * add docs * add docs * reference Adrian in Notes * black * refactor bspline basis calls * add type errors explicit * add custom eval points API * test against eval_points api * fix docstring * fix docs * fix exception condition
1 parent 44505f6 commit 9743786

File tree

8 files changed

+187
-3
lines changed

8 files changed

+187
-3
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
# The ">-" in the next line replaces newlines with spaces (see https://stackoverflow.com/a/66809682).
132132
run: >-
133133
conda activate pymc-test-py37 &&
134-
python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
134+
python -m pytest -vv --cov=pymc_experimental --doctest-modules pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
135135
- name: Upload coverage to Codecov
136136
uses: codecov/codecov-action@v2
137137
with:

docs/api_reference.rst

+7
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,10 @@ methods in the current release of PyMC experimental.
2222
.. automodule:: pymc_experimental.distributions.histogram_utils
2323
:members: histogram_approximation
2424

25+
26+
:mod:`pymc_experimental.utils`
27+
=============================
28+
29+
.. automodule:: pymc_experimental.utils.spline
30+
:members: bspline_interpolation
31+

pymc_experimental/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313

1414
from pymc_experimental.bart import *
1515
from pymc_experimental import distributions
16+
from pymc_experimental import gp
17+
from pymc_experimental import utils

pymc_experimental/gp/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from pymc_experimental.gp.latent_approx import HSGP, ProjectedProcess, KarhunenLoeveExpansion
1+
from pymc_experimental.gp.latent_approx import HSGP, ProjectedProcess, KarhunenLoeveExpansion
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import aesara
2+
import pymc_experimental as pmx
3+
import aesara.tensor as at
4+
import numpy as np
5+
import pytest
6+
7+
8+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("sparse", [True, False])
def test_spline_construction(dtype, sparse):
    """Check the symbolic basis Op against the plain numpy implementation."""
    n_points, n_basis, degree = 20, 10, 3
    grid = np.linspace(0, 1, n_points, dtype=dtype)

    # Reference result computed directly with numpy/scipy.
    expected = pmx.utils.spline.numpy_bspline_basis(grid, n_basis, degree)
    assert expected.shape == (n_points, n_basis)
    assert expected.dtype == dtype

    # The symbolic output must carry the type matching the requested storage.
    basis_op = pmx.utils.spline.BSplineBasis(sparse=sparse)
    symbolic = basis_op(grid, at.constant(n_basis), at.constant(degree))
    expected_type = aesara.sparse.SparseTensorType if sparse else at.TensorType
    assert isinstance(symbolic.type, expected_type)

    # Evaluated values must agree with the reference (densified when sparse).
    value = symbolic.eval()
    comparable = value.todense() if sparse else value
    np.testing.assert_allclose(comparable, expected)
    assert value.shape == (n_points, n_basis)
27+
28+
29+
@pytest.mark.parametrize("shape", [(100,), (100, 5)])
@pytest.mark.parametrize("sparse", [True, False])
@pytest.mark.parametrize("points", [dict(n=1001), dict(eval_points=np.linspace(0, 1, 1001))])
def test_interpolation_api(shape, sparse, points):
    """Interpolation replaces the leading axis by the number of eval points."""
    knot_values = np.random.randn(*shape)
    interpolated = pmx.utils.spline.bspline_interpolation(knot_values, **points, sparse=sparse)
    result = interpolated.eval()
    # Only the 0th axis is resampled; trailing axes are carried through.
    trailing_axes = shape[1:]
    assert result.shape == (1001, *trailing_axes)
37+
38+
39+
@pytest.mark.parametrize(
    "params",
    [
        # Non-boolean ``sparse`` flag.
        (dict(sparse="foo", n=100, degree=1), TypeError, "sparse should be True or False"),
        # Non-integer spline degree.
        (dict(n=100, degree=0.5), TypeError, "degree should be integer"),
        # Both ``n`` and ``eval_points`` supplied.
        (
            dict(n=100, eval_points=np.linspace(0, 1), degree=1),
            ValueError,
            "Please provide one of n or eval_points",
        ),
        # Neither ``n`` nor ``eval_points`` supplied.
        (
            dict(degree=1),
            ValueError,
            "Please provide one of n or eval_points",
        ),
    ],
)
def test_bad_calls(params):
    """Invalid argument combinations raise the documented exception."""
    kwargs, expected_exception, message_pattern = params
    data = np.random.randn(10)
    with pytest.raises(expected_exception, match=message_pattern):
        pmx.utils.spline.bspline_interpolation(data, **kwargs)

pymc_experimental/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pymc_experimental.utils import spline

pymc_experimental/utils/spline.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import aesara
2+
import numpy as np
3+
import scipy.interpolate
4+
from aesara.graph.op import Op, Apply
5+
import aesara.tensor as at
6+
import aesara.sparse
7+
8+
9+
def numpy_bspline_basis(eval_points: np.ndarray, k: int, degree=3):
    """Evaluate a clamped B-spline basis with uniform knots on ``[0, 1]``.

    Parameters
    ----------
    eval_points : array-like
        Points at which every basis function is evaluated.
    k : int
        Number of basis functions (size of the last axis of the result).
    degree : int, optional
        Polynomial degree of the spline, by default 3.

    Returns
    -------
    np.ndarray
        Array of shape ``eval_points.shape + (k,)``. Floating-point input keeps
        its dtype; any other input is evaluated and returned as ``float64``.

    Raises
    ------
    ValueError
        If ``degree`` is negative or ``k`` is not greater than ``degree``.
    """
    # Accept lists/tuples as well as arrays; ``.dtype`` is needed further down.
    eval_points = np.asarray(eval_points)
    if degree < 0:
        raise ValueError(f"degree should be non-negative, got {degree}")
    if k <= degree:
        raise ValueError(
            f"k should be greater than degree, got k={k} and degree={degree}"
        )
    # A spline with k coefficients needs k + degree + 1 knots in total;
    # ``degree`` of them are repeated at each boundary to clamp the basis.
    k_knots = k + degree + 1
    knots = np.linspace(0, 1, k_knots - 2 * degree)
    knots = np.r_[[0] * degree, knots, [1] * degree]
    # Identity coefficients make column j of the output the j-th basis function.
    basis_funcs = scipy.interpolate.BSpline(knots, np.eye(k), k=degree)
    Bx = basis_funcs(eval_points)
    # Only cast back for floating input: casting to an integer dtype would
    # truncate the basis values.
    if np.issubdtype(eval_points.dtype, np.floating):
        Bx = Bx.astype(eval_points.dtype)
    return Bx
16+
17+
18+
class BSplineBasis(Op):
    """Aesara ``Op`` that builds a B-spline basis matrix.

    The Op takes three inputs, ``(eval_points, k, degree)``, and produces one
    matrix with one row per evaluation point and ``k`` columns. The values are
    computed by ``numpy_bspline_basis``.

    Parameters
    ----------
    sparse : bool, optional
        When True (default) the output is a CSR sparse matrix, otherwise a
        dense matrix.

    Raises
    ------
    TypeError
        If ``sparse`` is not a ``bool``.
    """

    # Ops with the same ``sparse`` flag compare and hash as equal.
    __props__ = ("sparse",)

    def __init__(self, sparse=True) -> None:
        super().__init__()
        if not isinstance(sparse, bool):
            raise TypeError("sparse should be True or False")
        self.sparse = sparse

    def make_node(self, *inputs) -> Apply:
        """Validate the symbolic inputs and create the output variable.

        Raises
        ------
        TypeError
            If ``eval_points`` is not a float vector, or if ``k`` or ``degree``
            do not have one of the types listed in ``at.int_types``.
        """
        eval_points, k, d = map(at.as_tensor, inputs)
        if not (eval_points.ndim == 1 and np.issubdtype(eval_points.dtype, np.floating)):
            raise TypeError("eval_points should be a vector of floats")
        if k.type not in at.int_types:
            raise TypeError("k should be integer")
        if d.type not in at.int_types:
            raise TypeError("degree should be integer")
        # The output inherits the dtype of the evaluation points.
        if self.sparse:
            out_var = aesara.sparse.SparseTensorType("csr", eval_points.dtype)()
        else:
            out_var = aesara.tensor.matrix(dtype=eval_points.dtype)
        return Apply(self, [eval_points, k, d], [out_var])

    def perform(self, node, inputs, output_storage, params=None) -> None:
        """Compute the basis numerically and store it in ``output_storage``."""
        # Imported explicitly: the module header only imports ``scipy.interpolate``.
        import scipy.sparse

        eval_points, k, d = inputs
        Bx = numpy_bspline_basis(eval_points, int(k), int(d))
        if self.sparse:
            Bx = scipy.sparse.csr_matrix(Bx, dtype=eval_points.dtype)
        output_storage[0][0] = Bx

    def infer_shape(self, fgraph, node, ins_shapes):
        """Return the symbolic output shape ``(len(eval_points), k)``."""
        return [(node.inputs[0].shape[0], node.inputs[1])]
50+
51+
52+
def bspline_basis(n, k, degree=3, dtype=None, sparse=True):
    """Build a symbolic B-spline basis on ``n`` evenly spaced points in ``[0, 1]``.

    Parameters
    ----------
    n : int
        Number of evaluation points.
    k : int
        Number of basis functions, forwarded to ``BSplineBasis``.
    degree : int, optional
        Spline degree, by default 3.
    dtype : str or dtype, optional
        dtype of the evaluation grid; ``aesara.config.floatX`` when None.
    sparse : bool, optional
        Whether the basis is produced as a sparse matrix, by default True.

    Returns
    -------
    Variable
        The output of ``BSplineBasis(sparse=sparse)`` applied to the grid.
    """
    # Compare against None explicitly: unstructured ``np.dtype`` instances are
    # falsy, so ``dtype or floatX`` would silently discard them.
    if dtype is None:
        dtype = aesara.config.floatX
    eval_points = np.linspace(0, 1, n, dtype=dtype)
    return BSplineBasis(sparse=sparse)(eval_points, k, degree)
56+
57+
58+
def bspline_interpolation(x, *, n=None, eval_points=None, degree=3, sparse=True):
    """Interpolate sparse grid to dense grid using bsplines.

    Parameters
    ----------
    x : Variable
        Input Variable to interpolate.
        0th coordinate assumed to be mapped regularly on [0, 1] interval
    n : int, optional
        Resolution of interpolation. Exactly one of ``n`` and ``eval_points``
        has to be provided.
    eval_points : vector, optional
        Custom eval points in [0, 1] interval (or scaled properly using min/max scaling).
        Exactly one of ``n`` and ``eval_points`` has to be provided.
    degree : int, optional
        BSpline degree, by default 3
    sparse : bool, optional
        Use sparse operation, by default True

    Returns
    -------
    Variable
        The interpolated variable, interpolation is across 0th axis

    Raises
    ------
    ValueError
        If both or neither of ``n`` and ``eval_points`` are provided.
    TypeError
        Raised by ``BSplineBasis`` when ``sparse`` is not a bool, ``degree`` is
        not an integer or the eval points are not a vector of floats.

    Examples
    --------
    >>> import pymc as pm
    >>> import numpy as np
    >>> half_months = np.linspace(0, 365, 12*2)
    >>> with pm.Model(coords=dict(knots_time=half_months, time=np.arange(365))) as model:
    ...     kernel = pm.gp.cov.ExpQuad(1, ls=365/12)
    ...     # ready to define gp (a latent process over parameters)
    ...     gp = pm.gp.gp.Latent(
    ...         cov_func=kernel
    ...     )
    ...     y_knots = gp.prior("y_knots", half_months[:, None], dims="knots_time")
    ...     y = pm.Deterministic(
    ...         "y",
    ...         bspline_interpolation(y_knots, n=365, degree=3),
    ...         dims="time"
    ...     )
    ...     trace = pm.sample_prior_predictive(1)

    Notes
    -----
    Adapted from `BayesAlpha <https://github.com/quantopian/bayesalpha/blob/676f4f194ad20211fd040d3b0c6e82969aafb87e/bayesalpha/dists.py#L97>`_
    where it was written by @aseyboldt
    """
    x = at.as_tensor(x)
    # Exactly one of the two ways to specify the evaluation grid is allowed.
    if (n is None) == (eval_points is None):
        raise ValueError("Please provide one of n or eval_points")
    if n is not None:
        # Regular grid on [0, 1] with the same dtype as the interpolated values.
        eval_points = np.linspace(0, 1, n, dtype=x.dtype)
    # One basis function per entry along the 0th axis of ``x``.
    basis = BSplineBasis(sparse=sparse)(eval_points, x.shape[0], degree)
    if sparse:
        return aesara.sparse.dot(basis, x)
    return aesara.tensor.dot(basis, x)

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[tool.pytest.ini_options]
22
minversion = "6.0"
33
xfail_strict=true
4-
addopts = "--doctest-modules pymc_experimental"
54

65
[tool.black]
76
line-length = 100

0 commit comments

Comments
 (0)