Skip to content

Commit 28ab6d9

Browse files
authored
Merge pull request #3057 from bwengals/maunaloa2
A second CO2 example for GPs
2 parents 559c17e + e438edd commit 28ab6d9

File tree

6 files changed

+2906
-1155
lines changed

6 files changed

+2906
-1155
lines changed

docs/source/notebooks/GP-MaunaLoa.ipynb

Lines changed: 530 additions & 1096 deletions
Large diffs are not rendered by default.

docs/source/notebooks/GP-MaunaLoa2.ipynb

Lines changed: 1871 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/GP-MeansAndCovs.ipynb

Lines changed: 229 additions & 51 deletions
Large diffs are not rendered by default.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"-------------------------------------------------------------------------------------------"
2+
" Atmospheric CO2 record based on ice core data before 1958, "
3+
" (Ethridge et. al., 1996; MacFarling Meure et al., 2006) and "
4+
" yearly averages of direct observations from Mauna Loa and the South Pole after and "
5+
" including 1958 (from Scripps CO2 Program). "
6+
" "
7+
" in situ data is based on simple average of Mauna Loa and South Pole from values on Jan 1 "
8+
" "
9+
" For Mauna Loa Observatory, Hawaii: Latitude 19.5�N Longitude 155.6�W Elevation 3397m "
10+
" For South Pole: Latitude 90.0�S Elevation 2810m "
11+
" "
12+
" Scripps CO2 Source: R. F. Keeling, S. C. Piper, A. F. Bollenbacher and S. J. Walker "
13+
" Scripps CO2 Program ( http://scrippsco2.ucsd.edu ) "
14+
" Scripps Institution of Oceanography (SIO) "
15+
" University of California "
16+
" La Jolla, California USA 92093-0244 "
17+
" "
18+
" Status of Scripps CO2 data and correspondence: "
19+
" "
20+
" These data are subject to revision based on recalibration of standard gases. Questions "
21+
" about the data should be directed to Dr. Ralph Keeling ([email protected]) and "
22+
" and Dr. Stephen Piper ([email protected]), Scripps CO2 Program. "
23+
" "
24+
" "
25+
"-------------------------------------------------------------------------------------------"
26+
" "
27+
Sample date (year), CO2 (ppm)
28+
13.3, 276.750
29+
29.5, 277.880
30+
56.0, 277.380
31+
104.5, 277.510
32+
136.0, 278.130
33+
168.2, 280.050
34+
202.5, 280.720
35+
227.9, 281.490
36+
274.2, 280.130
37+
302.3, 279.830
38+
329.2, 278.920
39+
364.6, 277.050
40+
428.4, 276.910
41+
461.2, 276.720
42+
499.8, 276.360
43+
536.7, 276.000
44+
572.0, 277.560
45+
595.6, 276.900
46+
632.0, 278.250
47+
667.9, 279.390
48+
698.4, 279.700
49+
729.7, 278.530
50+
764.5, 278.530
51+
799.2, 278.550
52+
857.3, 279.340
53+
897.4, 278.910
54+
944.2, 279.120
55+
968.2, 278.460
56+
1005.0, 280.500
57+
1025.2, 280.830
58+
1058.0, 282.760
59+
1105.4, 282.750
60+
1159.6, 283.880
61+
1207.4, 283.600
62+
1257.6, 282.110
63+
1275.8, 281.130
64+
1306.5, 281.500
65+
1349.7, 280.060
66+
1411.3, 279.620
67+
1429.3, 279.540
68+
1431.0, 282.520
69+
1560.4, 281.750
70+
1588.3, 281.030
71+
1610.4, 271.830
72+
1628.9, 274.500
73+
1640.1, 276.620
74+
1689.6, 276.250
75+
1722.8, 276.940
76+
1734.1, 278.230
77+
1742.7, 276.740
78+
1752.0, 276.390
79+
1763.5, 276.320
80+
1773.7, 277.780
81+
1780.6, 276.780
82+
1794.4, 281.540
83+
1799.6, 281.150
84+
1814.2, 284.340
85+
1826.2, 281.280
86+
1834.5, 283.730
87+
1838.0, 284.070
88+
1844.0, 286.500
89+
1846.0, 284.130
90+
1849.0, 287.730
91+
1852.3, 288.570
92+
1854.0, 287.030
93+
1859.0, 286.480
94+
1864.0, 285.410
95+
1867.0, 285.220
96+
1869.1, 287.690
97+
1873.0, 287.170
98+
1874.0, 290.520
99+
1884.0, 289.810
100+
1884.4, 289.010
101+
1886.0, 290.620
102+
1889.0, 291.870
103+
1894.0, 293.840
104+
1896.0, 298.160
105+
1899.0, 296.100
106+
1902.0, 295.320
107+
1904.0, 295.120
108+
1909.0, 300.450
109+
1911.5, 298.390
110+
1914.0, 300.350
111+
1918.6, 303.270
112+
1919.0, 303.550
113+
1923.0, 303.200
114+
1923.6, 305.200
115+
1928.8, 307.770
116+
1929.0, 305.720
117+
1933.0, 307.210
118+
1934.5, 307.820
119+
1936.5, 308.990
120+
1938.0, 309.560
121+
1939.0, 310.940
122+
1940.0, 311.900
123+
1941.0, 310.700
124+
1941.5, 310.300
125+
1942.0, 311.290
126+
1943.0, 310.770
127+
1944.0, 311.590
128+
1945.0, 309.750
129+
1946.0, 311.460
130+
1947.0, 310.750
131+
1948.0, 310.470
132+
1949.0, 311.150
133+
1950.0, 312.550
134+
1953.0, 312.060
135+
1954.0, 311.680
136+
1955.0, 313.640
137+
1957.0, 314.040
138+
1958.0, 314.600
139+
1959.0, 315.370
140+
1960.0, 316.390
141+
1961.0, 317.490
142+
1962.0, 318.070
143+
1963.0, 318.470
144+
1964.0, 318.880
145+
1965.0, 319.290
146+
1966.0, 320.450
147+
1967.0, 321.430
148+
1968.0, 322.060
149+
1969.0, 323.150
150+
1970.0, 324.480
151+
1971.0, 325.400
152+
1972.0, 326.090
153+
1973.0, 327.680
154+
1974.0, 329.190
155+
1975.0, 329.760
156+
1976.0, 330.890
157+
1977.0, 332.040
158+
1978.0, 333.880
159+
1979.0, 335.200
160+
1980.0, 336.900
161+
1981.0, 338.580
162+
1982.0, 339.660
163+
1983.0, 340.890
164+
1984.0, 342.890
165+
1985.0, 344.140
166+
1986.0, 345.540
167+
1987.0, 346.950
168+
1988.0, 349.190
169+
1989.0, 351.060
170+
1990.0, 352.280
171+
1991.0, 353.780
172+
1992.0, 354.860
173+
1993.0, 355.690
174+
1994.0, 356.770
175+
1995.0, 358.680
176+
1996.0, 360.580
177+
1997.0, 361.830
178+
1998.0, 363.630
179+
1999.0, 366.360
180+
2000.0, 367.610
181+
2001.0, 368.910
182+
2002.0, 370.690
183+
2003.0, 373.100
184+
2004.0, 375.190
185+
2005.0, 376.940
186+
2006.0, 379.290
187+
2007.0, 381.010
188+
2008.0, 383.080
189+
2009.0, 384.820
190+
2010.0, 386.720
191+
2011.0, 388.930
192+
2012.0, 390.710
193+
2013.0, 393.310
194+
2014.0, 395.880
195+
2015.0, 397.840
196+
2016.0, 400.660

pymc3/gp/cov.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
'Periodic',
1717
'WarpedInput',
1818
'Gibbs',
19+
'ScaledCov',
1920
'Coregion']
2021

2122

@@ -542,15 +543,49 @@ def full(self, X, Xs=None):
542543
def diag(self, X):
543544
return tt.alloc(1.0, X.shape[0])
544545

545-
def handle_args(func, args):
546-
def f(x, args):
547-
if args is None:
548-
return func(x)
546+
547+
class ScaledCov(Covariance):
548+
R"""
549+
Construct a kernel by multiplying a base kernel with a scaling
550+
function defined using Theano. The scaling function is
551+
non-negative, and can be parameterized.
552+
553+
.. math::
554+
k(x, x') = \phi(x) k_{\text{base}}(x, x') \phi(x')
555+
556+
Parameters
557+
----------
558+
cov_func: Covariance
559+
Base kernel or covariance function
560+
scaling_func : callable
561+
Theano function of X and additional optional arguments.
562+
args : optional, tuple or list of scalars or PyMC3 variables
563+
Additional inputs (besides X or Xs) to lengthscale_func.
564+
"""
565+
def __init__(self, input_dim, cov_func, scaling_func, args=None, active_dims=None):
566+
super(ScaledCov, self).__init__(input_dim, active_dims)
567+
if not callable(scaling_func):
568+
raise TypeError("scaling_func must be callable")
569+
if not isinstance(cov_func, Covariance):
570+
raise TypeError("Must be or inherit from the Covariance class")
571+
self.cov_func = cov_func
572+
self.scaling_func = handle_args(scaling_func, args)
573+
self.args = args
574+
575+
def diag(self, X):
576+
X, _ = self._slice(X, None)
577+
cov_diag = self.cov_func(X, diag=True)
578+
scf_diag = tt.square(tt.flatten(self.scaling_func(X, self.args)))
579+
return cov_diag * scf_diag
580+
581+
def full(self, X, Xs=None):
582+
X, Xs = self._slice(X, Xs)
583+
scf_x = self.scaling_func(X, self.args)
584+
if Xs is None:
585+
return tt.outer(scf_x, scf_x) * self.cov_func(X)
549586
else:
550-
if not isinstance(args, tuple):
551-
args = (args,)
552-
return func(x, *args)
553-
return f
587+
scf_xs = self.scaling_func(Xs, self.args)
588+
return tt.outer(scf_x, scf_xs) * self.cov_func(X, Xs)
554589

555590

556591
class Coregion(Covariance):
@@ -615,3 +650,16 @@ def diag(self, X):
615650
X, _ = self._slice(X, None)
616651
index = tt.cast(X, 'int32')
617652
return tt.diag(self.B)[index.ravel()]
653+
654+
655+
def handle_args(func, args):
656+
def f(x, args):
657+
if args is None:
658+
return func(x)
659+
else:
660+
if not isinstance(args, tuple):
661+
args = (args,)
662+
return func(x, *args)
663+
return f
664+
665+

pymc3/tests/test_gp.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,30 @@ def test_raises(self):
552552
pm.gp.cov.Gibbs(3, lambda x: x, active_dims=[0,1])
553553

554554

555+
class TestScaledCov(object):
556+
def test_1d(self):
557+
X = np.linspace(0, 1, 10)[:, None]
558+
def scaling_func(x, a, b):
559+
return a + b*x
560+
with pm.Model() as model:
561+
cov_m52 = pm.gp.cov.Matern52(1, 0.2)
562+
cov = pm.gp.cov.ScaledCov(1, scaling_func=scaling_func, args=(2, -1), cov_func=cov_m52)
563+
K = theano.function([], cov(X))()
564+
npt.assert_allclose(K[0, 1], 3.00686, atol=1e-3)
565+
K = theano.function([], cov(X, X))()
566+
npt.assert_allclose(K[0, 1], 3.00686, atol=1e-3)
567+
# check diagonal
568+
Kd = theano.function([], cov(X, diag=True))()
569+
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
570+
571+
def test_raises(self):
572+
cov_m52 = pm.gp.cov.Matern52(1, 0.2)
573+
with pytest.raises(TypeError):
574+
pm.gp.cov.ScaledCov(1, cov_m52, "str is not callable")
575+
with pytest.raises(TypeError):
576+
pm.gp.cov.ScaledCov(1, "str is not Covariance object", lambda x: x)
577+
578+
555579
class TestHandleArgs(object):
556580
def test_handleargs(self):
557581
def func_noargs(x):

0 commit comments

Comments
 (0)