+ import logging
import numpy as np
import scipy
import theano
import theano.tensor as tt
- from ..ode.utils import augment_system, ODEGradop
+ from ..ode.utils import augment_system
+ from ..exceptions import ShapeError
+
+ _log = logging.getLogger('pymc3')


class DifferentialEquation(theano.Op):
@@ -17,16 +21,16 @@ class DifferentialEquation(theano.Op):

    func : callable
        Function specifying the differential equation
-     t0 : float
-         Time corresponding to the initial condition
    times : array
        Array of times at which to evaluate the solution of the differential equation.
    n_states : int
        Dimension of the differential equation. For scalar differential equations, n_states=1.
        For vector-valued differential equations, n_states = number of differential equations in the system.
-     n_odeparams : int
+     n_theta : int
        Number of parameters in the differential equation.
-
+     t0 : float
+         Time corresponding to the initial condition
+
    .. code-block:: python

        def odefunc(y, t, p):
@@ -35,45 +39,49 @@ def odefunc(y, t, p):

        times = np.arange(0.5, 5, 0.5)

-         ode_model = DifferentialEquation(func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1)
+         ode_model = DifferentialEquation(func=odefunc, times=times, n_states=1, n_theta=1, t0=0)
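+
+         # illustrative usage sketch: the op is then applied to y0/theta
+         # tensors (see __call__ below), e.g.
+         #   y_hat = ode_model(y0=[0], theta=[0.5])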
    """
-
-     __props__ = ("func", "t0", "times", "n_states", "n_odeparams")
-
-     def __init__(self, func, times, n_states, n_odeparams, t0=0):
+     _itypes = [
+         tt.TensorType(theano.config.floatX, (False,)),  # y0 as 1D floatX vector
+         tt.TensorType(theano.config.floatX, (False,))   # theta as 1D floatX vector
+     ]
+     _otypes = [
+         tt.TensorType(theano.config.floatX, (False, False)),          # model states as floatX of shape (T, S)
+         tt.TensorType(theano.config.floatX, (False, False, False)),   # sensitivities as floatX of shape (T, S, len(y0) + len(theta))
+     ]
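+     # ``__props__`` defines equality and hashing for the Op: instances with
+     # identical properties compare equal, letting Theano reuse equivalent nodes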
+     __props__ = ("func", "times", "n_states", "n_theta", "t0")
+
+     def __init__(self, func, times, *, n_states, n_theta, t0=0):
        if not callable(func):
            raise ValueError("Argument func must be callable.")
        if n_states < 1:
            raise ValueError("Argument n_states must be at least 1.")
-         if n_odeparams <= 0:
-             raise ValueError("Argument n_odeparams must be positive.")
+         if n_theta <= 0:
+             raise ValueError("Argument n_theta must be positive.")

        # Public
        self.func = func
        self.t0 = t0
        self.times = tuple(times)
+         self.n_times = len(times)
        self.n_states = n_states
-         self.n_odeparams = n_odeparams
+         self.n_theta = n_theta
+         self.n_p = n_states + n_theta

        # Private
-         self._n = n_states
-         self._m = n_odeparams + n_states
-
        self._augmented_times = np.insert(times, 0, t0)
-         self._augmented_func = augment_system(func, self._n, self._m)
+         self._augmented_func = augment_system(func, self.n_states, self.n_p)
        self._sens_ic = self._make_sens_ic()

-         self._cached_y = None
-         self._cached_sens = None
-         self._cached_parameters = None
-
-         self._grad_op = ODEGradop(self._numpy_vsp)
-
+         # Cache symbolic sensitivities by the hash of inputs
+         self._apply_nodes = {}
+         self._output_sensitivities = {}
+
    def _make_sens_ic(self):
        """
        The sensitivity matrix will always have consistent form.
-         If the first n_odeparams entries of the parameters vector in the simulate call
-         correspond to ode paramaters, then the first n_odeparams columns in
+         If the first n_theta entries of the parameters vector in the simulate call
+         correspond to ODE parameters, then the first n_theta columns in
        the sensitivity matrix will be 0

        If the last n_states entries of the parameters vector in the simulate call
@@ -83,7 +91,7 @@ def _make_sens_ic(self):
        """

        # Initialize the sensitivity matrix to be 0 everywhere
-         sens_matrix = np.zeros((self._n, self._m))
+         sens_matrix = np.zeros((self.n_states, self.n_p))

        # Slip in the identity matrix in the appropriate place
        sens_matrix[:, -self.n_states:] = np.eye(self.n_states)
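+         # e.g. with n_states=1 and n_theta=1 this yields [[0., 1.]]: zero
+         # initial sensitivity w.r.t. theta, identity w.r.t. the initial condition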
@@ -95,89 +103,109 @@ def _make_sens_ic(self):
        return dydp

    def _system(self, Y, t, p):
-         """This is the function that will be passed to odeint. Solves both ODE and sensitivities
-
+         """This is the function that will be passed to odeint. Solves both ODE and sensitivities.
        """
-
-         dydt, ddt_dydp = self._augmented_func(Y[:self._n], t, p, Y[self._n:])
+         dydt, ddt_dydp = self._augmented_func(Y[:self.n_states], t, p, Y[self.n_states:])
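+         # the augmented state Y packs the n_states ODE states first, followed
+         # by the raveled (n_states, n_p) sensitivity matrix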
        derivatives = np.concatenate([dydt, ddt_dydp])
        return derivatives

-     def _simulate(self, parameters):
-         # Initial condition comprised of state initial conditions and raveled
-         # sensitivity matrix
-         y0 = np.concatenate([parameters[self.n_odeparams:], self._sens_ic])
+     def _simulate(self, y0, theta):
+         # Initial condition comprising the state initial conditions and the raveled sensitivity matrix
+         s0 = np.concatenate([y0, self._sens_ic])

        # perform the integration
        sol = scipy.integrate.odeint(
-             func=self._system, y0=y0, t=self._augmented_times, args=(parameters,)
+             func=self._system, y0=s0, t=self._augmented_times, args=(np.concatenate([theta, y0]),)
        )
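+         # note: ``np.concatenate([theta, y0])`` orders the parameter vector to
+         # match the sensitivity layout set up in ``_make_sens_ic`` (the n_theta
+         # ODE parameters first, then the n_states initial conditions)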
        # The solution
-         y = sol[1 :, : self.n_states]
+         y = sol[1:, :self.n_states]

        # The sensitivities, reshaped to be a sequence of matrices
-         sens = sol[1:, self.n_states:].reshape(len(self.times), self._n, self._m)
+         sens = sol[1:, self.n_states:].reshape(self.n_times, self.n_states, self.n_p)

        return y, sens

-     def _cached_simulate(self, parameters):
-         if np.array_equal(np.array(parameters), self._cached_parameters):
-
-             return self._cached_y, self._cached_sens
-
-         return self._simulate(np.array(parameters))
-
-     def _state(self, parameters):
-         y, sens = self._cached_simulate(np.array(parameters))
-         self._cached_y, self._cached_sens, self._cached_parameters = y, sens, parameters
-         return y.ravel()
-
-     def _numpy_vsp(self, parameters, g):
-         _, sens = self._cached_simulate(np.array(parameters))
-
-         # Each element of sens is an nxm sensitivity matrix
-         # There is one sensitivity matrix per time step, making sens a (len(times), n_states, len(parameter))
-         # dimensional array. Reshaping the sens array in this way is like stacking each of the elements of sens on top
-         # of one another.
-         numpy_sens = sens.reshape((self.n_states * len(self.times), len(parameters)))
-         # The dot product here is equivalent to np.einsum('ijk,jk', sens, g)
-         # if sens was not reshaped and if g had the same shape as yobs
-         return numpy_sens.T.dot(g)
-
-     def make_node(self, odeparams, y0):
-         if len(odeparams) != self.n_odeparams:
-             raise ValueError(
-                 "odeparams has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
-                     a=self.n_odeparams, b=len(odeparams)
-                 )
-             )
-         if len(y0) != self.n_states:
-             raise ValueError(
-                 "y0 has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
-                     a=self.n_states, b=len(y0)
+     def make_node(self, y0, theta):
+         inputs = (y0, theta)
+         _log.debug('make_node for inputs {}'.format(hash(inputs)))
+         states = self._otypes[0]()
+         sens = self._otypes[1]()
+
+         # store symbolic output in dictionary such that it can be accessed in the grad method
+         self._output_sensitivities[hash(inputs)] = sens
+         return theano.Apply(self, inputs, (states, sens))
+
+     def __call__(self, y0, theta, return_sens=False, **kwargs):
+         # convert inputs to tensors (and check their types)
+         y0 = tt.cast(tt.unbroadcast(tt.as_tensor_variable(y0), 0), theano.config.floatX)
+         theta = tt.cast(tt.unbroadcast(tt.as_tensor_variable(theta), 0), theano.config.floatX)
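+         # ``tt.unbroadcast(..., 0)`` pins broadcastable=(False,) so that even
+         # length-1 inputs match the declared ``_itypes`` above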
+         inputs = [y0, theta]
+         for i, (input, itype) in enumerate(zip(inputs, self._itypes)):
+             if not input.type == itype:
+                 raise ValueError('Input {} of type {} does not have the expected type of {}'.format(i, input.type, itype))
+
+         # use default implementation to prepare symbolic outputs (via make_node)
+         states, sens = super(theano.Op, self).__call__(y0, theta, **kwargs)
+
+         if theano.config.compute_test_value != 'off':
+             # compute test values from input test values
+             test_states, test_sens = self._simulate(
+                 y0=self._get_test_value(y0),
+                 theta=self._get_test_value(theta)
            )

-         if np.ndim(odeparams) > 1:
-             odeparams = np.ravel(odeparams)
-         if np.ndim(y0) > 1:
-             y0 = np.ravel(y0)
-
-         odeparams = tt.as_tensor_variable(odeparams)
-         y0 = tt.as_tensor_variable(y0)
-         parameters = tt.concatenate([odeparams, y0])
-         return theano.Apply(self, [parameters], [parameters.type()])
+             # check types of simulation result
+             if not test_states.dtype == self._otypes[0].dtype:
+                 raise TypeError('Simulated states have the wrong type')
+             if not test_sens.dtype == self._otypes[1].dtype:
+                 raise TypeError('Simulated sensitivities have the wrong type')
+
+             # check shapes of simulation result
+             expected_states_shape = (self.n_times, self.n_states)
+             expected_sens_shape = (self.n_times, self.n_states, self.n_p)
+             if not test_states.shape == expected_states_shape:
+                 raise ShapeError('Simulated states have the wrong shape.', test_states.shape, expected_states_shape)
+             if not test_sens.shape == expected_sens_shape:
+                 raise ShapeError('Simulated sensitivities have the wrong shape.', test_sens.shape, expected_sens_shape)
+
+             # attach results as test values to the outputs
+             states.tag.test_value = test_states
+             sens.tag.test_value = test_sens
+
+         if return_sens:
+             return states, sens
+         return states

    def perform(self, node, inputs_storage, output_storage):
-         parameters = inputs_storage[0]
-         out = output_storage[0]
-         # get the numerical solution of ODE states
-         out[0] = self._state(parameters)
+         y0, theta = inputs_storage[0], inputs_storage[1]
+         # simulate states and sensitivities in one forward pass
+         output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)

-     def grad(self, inputs, output_grads):
-         x = inputs[0]
-         g = output_grads[0]
-         # pass the VSP when asked for gradient
-         grad_op_apply = self._grad_op(x, g)
+     def infer_shape(self, node, input_shapes):
+         s_y0, s_theta = input_shapes
+         output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
+         return output_shapes

-         return [grad_op_apply]
+     def grad(self, inputs, output_grads):
+         _log.debug('grad w.r.t. inputs {}'.format(hash(tuple(inputs))))
+
+         # fetch symbolic sensitivity output node from cache
+         ihash = hash(tuple(inputs))
+         if ihash in self._output_sensitivities:
+             sens = self._output_sensitivities[ihash]
+         else:
+             _log.debug('No cached sensitivities found!')
+             _, sens = self.__call__(*inputs, return_sens=True)
+         ograds = output_grads[0]
+
+         # for each parameter, multiply sensitivities with the output gradient and sum the result
+         # sens is (n_times, n_states, n_p)
+         # ograds is (n_times, n_states)
+         grads = [
+             tt.sum(sens[:, :, p] * ograds)
+             for p in range(self.n_p)
+         ]
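+         # the comprehension above is equivalent to
+         # np.einsum('tsp,ts->p', sens, ograds): a vector-Jacobian product
+         # summed over the time and state dimensions; the split below assumes
+         # the first n_states sensitivity columns belong to y0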
+
+         # return separate gradient tensors for y0 and theta inputs
+         result = tt.stack(grads[:self.n_states]), tt.stack(grads[self.n_states:])
+         return result