Skip to content

Commit 225ae82

Browse files
michaelosthegeColCarroll
authored andcommitted
DifferentialEquation Op refactor (#3634)
* addition of test for equality checking of ODE Ops (not yet implemented) * WIP: refactoring the DifferentialEquation Op + full support for test_values + explicit input/output types + 2D return shape + optional return of sensitivities + gradient without helper Op * fully replace DifferentialEquation Op with the refactored implementation * align tests with refactored API + whitespace & condensed formatting + test for equality of identical Ops * use tt.stack as suggested by DeprecationWarning * always cast y0 and theta to floatX * allow some tests to fail on float32 (due to downcast exception) * don't use f-strings to maintain 3.5 support * link ODE refactor PR * renamed ODE notebooks + add notebooks to examples index * use (custom) errors instead of asserts * move ShapeError to exceptions.py
1 parent cc55279 commit 225ae82

11 files changed

+979
-814
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### New features
66
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
7-
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
7+
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590) and [#3634](https://github.com/pymc-devs/pymc3/pull/3634).
88
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
99
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
1010
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.

docs/source/notebooks/ODE_API_introduction.ipynb

Lines changed: 410 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/ODE_API_parameter_estimation.ipynb

Lines changed: 0 additions & 570 deletions
This file was deleted.

docs/source/notebooks/ODE_API_shapes_and_benchmarking.ipynb

Lines changed: 317 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/ODE_parameter_estimation.ipynb renamed to docs/source/notebooks/ODE_with_manual_gradients.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"cell_type": "markdown",
2121
"metadata": {},
2222
"source": [
23-
"# Bayesian inference in non-linear ODEs using PyMC3\n",
23+
"# Lotka-Volterra with manual gradients\n",
2424
"\n",
2525
"by [Sanmitra Ghosh](https://www.mrc-bsu.cam.ac.uk/people/in-alphabetical-order/a-to-g/sanmitra-ghosh/)"
2626
]

docs/source/notebooks/table_of_contents_examples.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Gallery.contents = {
5454
"normalizing_flows_overview": "Variational Inference",
5555
"gaussian-mixture-model-advi": "Variational Inference",
5656
"GLM-hierarchical-advi-minibatch": "Variational Inference",
57-
"ODE_parameter_estimation": "Inference in ODE models",
58-
"ODE_API_parameter_estimation": "Inference in ODE models"
57+
"ODE_with_manual_gradients": "Inference in ODE models",
58+
"ODE_API_introduction": "Inference in ODE models",
59+
"ODE_API_shapes_and_benchmarking": "Inference in ODE models"
5960
}

pymc3/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"IncorrectArgumentsError",
44
"TraceDirectoryError",
55
"ImputationWarning",
6+
"ShapeError"
67
]
78

89

@@ -24,3 +25,12 @@ class ImputationWarning(UserWarning):
2425
"""Warning that there are missing values that will be imputed."""
2526

2627
pass
28+
29+
30+
class ShapeError(Exception):
31+
"""Error that the shape of a variable is incorrect."""
32+
def __init__(self, message, actual=None, expected=None):
33+
if expected and actual:
34+
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
35+
else:
36+
super().__init__(message)

pymc3/ode/ode.py

Lines changed: 120 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import logging
12
import numpy as np
23
import scipy
34
import theano
45
import theano.tensor as tt
5-
from ..ode.utils import augment_system, ODEGradop
6+
from ..ode.utils import augment_system
7+
from ..exceptions import ShapeError
8+
9+
_log = logging.getLogger('pymc3')
610

711

812
class DifferentialEquation(theano.Op):
@@ -17,16 +21,16 @@ class DifferentialEquation(theano.Op):
1721
1822
func : callable
1923
Function specifying the differential equation
20-
t0 : float
21-
Time corresponding to the initial condition
2224
times : array
2325
Array of times at which to evaluate the solution of the differential equation.
2426
n_states : int
2527
Dimension of the differential equation. For scalar differential equations, n_states=1.
2628
For vector valued differential equations, n_states = number of differential equations in the system.
27-
n_odeparams : int
29+
n_theta : int
2830
Number of parameters in the differential equation.
29-
31+
t0 : float
32+
Time corresponding to the initial condition
33+
3034
.. code-block:: python
3135
3236
def odefunc(y, t, p):
@@ -35,45 +39,49 @@ def odefunc(y, t, p):
3539
3640
times = np.arange(0.5, 5, 0.5)
3741
38-
ode_model = DifferentialEquation(func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1)
42+
ode_model = DifferentialEquation(func=odefunc, times=times, n_states=1, n_theta=1, t0=0)
3943
"""
40-
41-
__props__ = ("func", "t0", "times", "n_states", "n_odeparams")
42-
43-
def __init__(self, func, times, n_states, n_odeparams, t0=0):
44+
_itypes = [
45+
tt.TensorType(theano.config.floatX, (False,)), # y0 as 1D floatX vector
46+
tt.TensorType(theano.config.floatX, (False,)) # theta as 1D floatX vector
47+
]
48+
_otypes = [
49+
tt.TensorType(theano.config.floatX, (False, False)), # model states as floatX of shape (T, S)
50+
tt.TensorType(theano.config.floatX, (False, False, False)), # sensitivities as floatX of shape (T, S, len(y0) + len(theta))
51+
]
52+
__props__ = ("func", "times", "n_states", "n_theta", "t0")
53+
54+
def __init__(self, func, times, *, n_states, n_theta, t0=0):
4455
if not callable(func):
4556
raise ValueError("Argument func must be callable.")
4657
if n_states < 1:
4758
raise ValueError("Argument n_states must be at least 1.")
48-
if n_odeparams <= 0:
49-
raise ValueError("Argument n_odeparams must be positive.")
59+
if n_theta <= 0:
60+
raise ValueError("Argument n_theta must be positive.")
5061

5162
# Public
5263
self.func = func
5364
self.t0 = t0
5465
self.times = tuple(times)
66+
self.n_times = len(times)
5567
self.n_states = n_states
56-
self.n_odeparams = n_odeparams
68+
self.n_theta = n_theta
69+
self.n_p = n_states + n_theta
5770

5871
# Private
59-
self._n = n_states
60-
self._m = n_odeparams + n_states
61-
6272
self._augmented_times = np.insert(times, 0, t0)
63-
self._augmented_func = augment_system(func, self._n, self._m)
73+
self._augmented_func = augment_system(func, self.n_states, self.n_p)
6474
self._sens_ic = self._make_sens_ic()
6575

66-
self._cached_y = None
67-
self._cached_sens = None
68-
self._cached_parameters = None
69-
70-
self._grad_op = ODEGradop(self._numpy_vsp)
71-
76+
# Cache symbolic sensitivities by the hash of inputs
77+
self._apply_nodes = {}
78+
self._output_sensitivities = {}
79+
7280
def _make_sens_ic(self):
7381
"""
7482
The sensitivity matrix will always have consistent form.
75-
If the first n_odeparams entries of the parameters vector in the simulate call
76-
correspond to ode paramaters, then the first n_odeparams columns in
83+
If the first n_theta entries of the parameters vector in the simulate call
84+
correspond to ode paramaters, then the first n_theta columns in
7785
the sensitivity matrix will be 0
7886
7987
If the last n_states entries of the paramters vector in the simulate call
@@ -83,7 +91,7 @@ def _make_sens_ic(self):
8391
"""
8492

8593
# Initialize the sensitivity matrix to be 0 everywhere
86-
sens_matrix = np.zeros((self._n, self._m))
94+
sens_matrix = np.zeros((self.n_states, self.n_p))
8795

8896
# Slip in the identity matrix in the appropirate place
8997
sens_matrix[:, -self.n_states :] = np.eye(self.n_states)
@@ -95,89 +103,109 @@ def _make_sens_ic(self):
95103
return dydp
96104

97105
def _system(self, Y, t, p):
98-
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities
99-
106+
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities.
100107
"""
101-
102-
dydt, ddt_dydp = self._augmented_func(Y[: self._n], t, p, Y[self._n :])
108+
dydt, ddt_dydp = self._augmented_func(Y[:self.n_states], t, p, Y[self.n_states:])
103109
derivatives = np.concatenate([dydt, ddt_dydp])
104110
return derivatives
105111

106-
def _simulate(self, parameters):
107-
# Initial condition comprised of state initial conditions and raveled
108-
# sensitivity matrix
109-
y0 = np.concatenate([parameters[self.n_odeparams :], self._sens_ic])
112+
def _simulate(self, y0, theta):
113+
# Initial condition comprised of state initial conditions and raveled sensitivity matrix
114+
s0 = np.concatenate([y0, self._sens_ic])
110115

111116
# perform the integration
112117
sol = scipy.integrate.odeint(
113-
func=self._system, y0=y0, t=self._augmented_times, args=(parameters,)
118+
func=self._system, y0=s0, t=self._augmented_times, args=(np.concatenate([theta, y0]),)
114119
)
115120
# The solution
116-
y = sol[1:, : self.n_states]
121+
y = sol[1:, :self.n_states]
117122

118123
# The sensitivities, reshaped to be a sequence of matrices
119-
sens = sol[1:, self.n_states :].reshape(len(self.times), self._n, self._m)
124+
sens = sol[1:, self.n_states:].reshape(self.n_times, self.n_states, self.n_p)
120125

121126
return y, sens
122127

123-
def _cached_simulate(self, parameters):
124-
if np.array_equal(np.array(parameters), self._cached_parameters):
125-
126-
return self._cached_y, self._cached_sens
127-
128-
return self._simulate(np.array(parameters))
129-
130-
def _state(self, parameters):
131-
y, sens = self._cached_simulate(np.array(parameters))
132-
self._cached_y, self._cached_sens, self._cached_parameters = y, sens, parameters
133-
return y.ravel()
134-
135-
def _numpy_vsp(self, parameters, g):
136-
_, sens = self._cached_simulate(np.array(parameters))
137-
138-
# Each element of sens is an nxm sensitivity matrix
139-
# There is one sensitivity matrix per time step, making sens a (len(times), n_states, len(parameter))
140-
# dimensional array. Reshaping the sens array in this way is like stacking each of the elements of sens on top
141-
# of one another.
142-
numpy_sens = sens.reshape((self.n_states * len(self.times), len(parameters)))
143-
# The dot product here is equivalent to np.einsum('ijk,jk', sens, g)
144-
# if sens was not reshaped and if g had the same shape as yobs
145-
return numpy_sens.T.dot(g)
146-
147-
def make_node(self, odeparams, y0):
148-
if len(odeparams) != self.n_odeparams:
149-
raise ValueError(
150-
"odeparams has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
151-
a=self.n_odeparams, b=len(odeparams)
152-
)
153-
)
154-
if len(y0) != self.n_states:
155-
raise ValueError(
156-
"y0 has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
157-
a=self.n_states, b=len(y0)
158-
)
128+
def make_node(self, y0, theta):
129+
inputs = (y0, theta)
130+
_log.debug('make_node for inputs {}'.format(hash(inputs)))
131+
states = self._otypes[0]()
132+
sens = self._otypes[1]()
133+
134+
# store symbolic output in dictionary such that it can be accessed in the grad method
135+
self._output_sensitivities[hash(inputs)] = sens
136+
return theano.Apply(self, inputs, (states, sens))
137+
138+
def __call__(self, y0, theta, return_sens=False, **kwargs):
139+
# convert inputs to tensors (and check their types)
140+
y0 = tt.cast(tt.unbroadcast(tt.as_tensor_variable(y0), 0), theano.config.floatX)
141+
theta = tt.cast(tt.unbroadcast(tt.as_tensor_variable(theta), 0), theano.config.floatX)
142+
inputs = [y0, theta]
143+
for i, (input, itype) in enumerate(zip(inputs, self._itypes)):
144+
if not input.type == itype:
145+
raise ValueError('Input {} of type {} does not have the expected type of {}'.format(i, input.type, itype))
146+
147+
# use default implementation to prepare symbolic outputs (via make_node)
148+
states, sens = super(theano.Op, self).__call__(y0, theta, **kwargs)
149+
150+
if theano.config.compute_test_value != 'off':
151+
# compute test values from input test values
152+
test_states, test_sens = self._simulate(
153+
y0=self._get_test_value(y0),
154+
theta=self._get_test_value(theta)
159155
)
160156

161-
if np.ndim(odeparams) > 1:
162-
odeparams = np.ravel(odeparams)
163-
if np.ndim(y0) > 1:
164-
y0 = np.ravel(y0)
165-
166-
odeparams = tt.as_tensor_variable(odeparams)
167-
y0 = tt.as_tensor_variable(y0)
168-
parameters = tt.concatenate([odeparams, y0])
169-
return theano.Apply(self, [parameters], [parameters.type()])
157+
# check types of simulation result
158+
if not test_states.dtype == self._otypes[0].dtype:
159+
raise TypeError('Simulated states have the wrong type')
160+
if not test_sens.dtype == self._otypes[1].dtype:
161+
raise TypeError('Simulated sensitivities have the wrong type')
162+
163+
# check shapes of simulation result
164+
expected_states_shape = (self.n_times, self.n_states)
165+
expected_sens_shape = (self.n_times, self.n_states, self.n_p)
166+
if not test_states.shape == expected_states_shape:
167+
raise ShapeError('Simulated states have the wrong shape.', test_states.shape, expected_states_shape)
168+
if not test_sens.shape == expected_sens_shape:
169+
raise ShapeError('Simulated sensitivities have the wrong shape.', test_sens.shape, expected_sens_shape)
170+
171+
# attach results as test values to the outputs
172+
states.tag.test_value = test_states
173+
sens.tag.test_value = test_sens
174+
175+
if return_sens:
176+
return states, sens
177+
return states
170178

171179
def perform(self, node, inputs_storage, output_storage):
172-
parameters = inputs_storage[0]
173-
out = output_storage[0]
174-
# get the numerical solution of ODE states
175-
out[0] = self._state(parameters)
180+
y0, theta = inputs_storage[0], inputs_storage[1]
181+
# simulate states and sensitivities in one forward pass
182+
output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)
176183

177-
def grad(self, inputs, output_grads):
178-
x = inputs[0]
179-
g = output_grads[0]
180-
# pass the VSP when asked for gradient
181-
grad_op_apply = self._grad_op(x, g)
184+
def infer_shape(self, node, input_shapes):
185+
s_y0, s_theta = input_shapes
186+
output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
187+
return output_shapes
182188

183-
return [grad_op_apply]
189+
def grad(self, inputs, output_grads):
190+
_log.debug('grad w.r.t. inputs {}'.format(hash(tuple(inputs))))
191+
192+
# fetch symbolic sensitivity output node from cache
193+
ihash = hash(tuple(inputs))
194+
if ihash in self._output_sensitivities:
195+
sens = self._output_sensitivities[ihash]
196+
else:
197+
_log.debug('No cached sensitivities found!')
198+
_, sens = self.__call__(*inputs, return_sens=True)
199+
ograds = output_grads[0]
200+
201+
# for each parameter, multiply sensitivities with the output gradient and sum the result
202+
# sens is (n_times, n_states, n_p)
203+
# ograds is (n_times, n_states)
204+
grads = [
205+
tt.sum(sens[:,:,p] * ograds)
206+
for p in range(self.n_p)
207+
]
208+
209+
# return separate gradient tensors for y0 and theta inputs
210+
result = tt.stack(grads[:self.n_states]), tt.stack(grads[self.n_states:])
211+
return result

pymc3/ode/utils.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,41 +45,26 @@ def augment_system(ode_func, n, m):
4545

4646
dydp = dydp_vec.reshape((n, m))
4747

48-
# Stack the results of the ode_func
49-
f_tensor = tt.stack(ode_func(t_y, t_t, t_p))
48+
# Stack the results of the ode_func into a single tensor variable
49+
yhat = ode_func(t_y, t_t, t_p)
50+
if not isinstance(yhat, (list, tuple)):
51+
yhat = (yhat,)
52+
t_yhat = tt.stack(yhat, axis=0)
5053

5154
# Now compute gradients
52-
J = tt.jacobian(f_tensor, t_y)
55+
J = tt.jacobian(t_yhat, t_y)
5356

5457
Jdfdy = tt.dot(J, dydp)
5558

56-
grad_f = tt.jacobian(f_tensor, t_p)
59+
grad_f = tt.jacobian(t_yhat, t_p)
5760

5861
# This is the time derivative of dydp
5962
ddt_dydp = (Jdfdy + grad_f).flatten()
6063

6164
system = theano.function(
6265
inputs=[t_y, t_t, t_p, dydp_vec],
63-
outputs=[f_tensor, ddt_dydp],
66+
outputs=[t_yhat, ddt_dydp],
6467
on_unused_input="ignore",
6568
)
6669

6770
return system
68-
69-
70-
class ODEGradop(theano.Op):
71-
def __init__(self, numpy_vsp):
72-
self._numpy_vsp = numpy_vsp
73-
74-
def make_node(self, x, g):
75-
76-
x = theano.tensor.as_tensor_variable(x)
77-
g = theano.tensor.as_tensor_variable(g)
78-
node = theano.Apply(self, [x, g], [g.type()])
79-
return node
80-
81-
def perform(self, node, inputs_storage, output_storage):
82-
x = inputs_storage[0]
83-
g = inputs_storage[1]
84-
out = output_storage[0]
85-
out[0] = self._numpy_vsp(x, g) # get the numerical VSP

0 commit comments

Comments
 (0)