Skip to content

Commit d150d55

Browse files
mirko-mMirko Moellermichaelosthege
authored
Add tests for the output type of ode_func (#5414)
Co-authored-by: Mirko Moeller <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent 069de73 commit d150d55

File tree

2 files changed

+103
-16
lines changed

2 files changed

+103
-16
lines changed

pymc/ode/utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,19 @@ def augment_system(ode_func, n_states, n_theta):
103103

104104
# Get symbolic representation of the ODEs by passing tensors for y, t and theta
105105
yhat = ode_func(t_y, t_t, t_p[n_states:])
106-
# Stack the results of the ode_func into a single tensor variable
107-
if not isinstance(yhat, (list, tuple)):
108-
yhat = (yhat,)
109-
t_yhat = at.stack(yhat, axis=0)
106+
if isinstance(yhat, at.TensorVariable):
107+
t_yhat = at.atleast_1d(yhat)
108+
else:
109+
# Stack the results of the ode_func into a single tensor variable
110+
if not isinstance(yhat, (list, tuple)):
111+
raise TypeError(
112+
f"Unexpected type, {type(yhat)}, returned by ode_func. TensorVariable, list or tuple is expected."
113+
)
114+
t_yhat = at.stack(yhat, axis=0)
115+
if t_yhat.ndim > 1:
116+
raise ValueError(
117+
f"The odefunc returned a {t_yhat.ndim}-dimensional tensor, but 0 or 1 dimensions were expected."
118+
)
110119

111120
# Now compute gradients
112121
J = at.jacobian(t_yhat, t_y)

pymc/tests/test_ode.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616

1717
import aesara
18+
import aesara.tensor as at
1819
import numpy as np
1920
import pytest
2021

@@ -154,6 +155,24 @@ def ode_func_4(y, t, p):
154155

155156
np.testing.assert_array_equal(model4_sens_ic, model4._sens_ic)
156157

158+
def test_sens_ic_vector_2_param_tensor(self):
159+
# Vector ODE 2 Param with return type at.TensorVariable
160+
def ode_func_4_t(y, t, p):
161+
# Make sure that ds and di are vectors by slicing
162+
ds = -p[0:1] * y[0:1] * y[1:]
163+
di = p[0:1] * y[0:1] * y[1:] - p[1:] * y[1:]
164+
165+
return at.concatenate([ds, di], axis=0)
166+
167+
# Instantiate ODE model
168+
model4_t = DifferentialEquation(
169+
func=ode_func_4_t, t0=0, times=self.t, n_states=2, n_theta=2
170+
)
171+
172+
model4_sens_ic_t = np.array([1, 0, 0, 0, 0, 1, 0, 0])
173+
174+
np.testing.assert_array_equal(model4_sens_ic_t, model4_t._sens_ic)
175+
157176
def test_sens_ic_vector_3_params(self):
158177
# Big System with Many Parameters
159178
def ode_func_5(y, t, p):
@@ -209,45 +228,104 @@ def system_1(y, t, p):
209228
class TestErrors:
210229
"""Test running model for a scalar ODE with 1 parameter"""
211230

212-
def system(y, t, p):
213-
return np.exp(-t) - p[0] * y[0]
214-
215-
times = np.arange(0, 9)
231+
def setup_method(self, method):
232+
def system(y, t, p):
233+
return np.exp(-t) - p[0] * y[0]
216234

217-
ode_model = DifferentialEquation(func=system, t0=0, times=times, n_states=1, n_theta=1)
235+
self.system = system
236+
self.times = np.arange(0, 9)
237+
self.ode_model = DifferentialEquation(
238+
func=system, t0=0, times=self.times, n_states=1, n_theta=1
239+
)
218240

219241
@pytest.mark.xfail(condition=(IS_FLOAT32 and IS_WINDOWS), reason="Fails on float32 on Windows")
220242
def test_too_many_params(self):
221-
with pytest.raises(pm.ShapeError):
243+
with pytest.raises(
244+
pm.ShapeError,
245+
match="Length of theta is wrong. \\(actual \\(2,\\) != expected \\(1,\\)\\)",
246+
):
222247
self.ode_model(theta=[1, 1], y0=[0])
223248

224249
@pytest.mark.xfail(condition=(IS_FLOAT32 and IS_WINDOWS), reason="Fails on float32 on Windows")
225250
def test_too_many_y0(self):
226-
with pytest.raises(pm.ShapeError):
251+
with pytest.raises(
252+
pm.ShapeError, match="Length of y0 is wrong. \\(actual \\(2,\\) != expected \\(1,\\)\\)"
253+
):
227254
self.ode_model(theta=[1], y0=[0, 0])
228255

229256
@pytest.mark.xfail(condition=(IS_FLOAT32 and IS_WINDOWS), reason="Fails on float32 on Windows")
230257
def test_too_few_params(self):
231-
with pytest.raises(pm.ShapeError):
258+
with pytest.raises(
259+
pm.ShapeError,
260+
match="Length of theta is wrong. \\(actual \\(0,\\) != expected \\(1,\\)\\)",
261+
):
232262
self.ode_model(theta=[], y0=[1])
233263

234264
@pytest.mark.xfail(condition=(IS_FLOAT32 and IS_WINDOWS), reason="Fails on float32 on Windows")
235265
def test_too_few_y0(self):
236-
with pytest.raises(pm.ShapeError):
266+
with pytest.raises(
267+
pm.ShapeError, match="Length of y0 is wrong. \\(actual \\(0,\\) != expected \\(1,\\)\\)"
268+
):
237269
self.ode_model(theta=[1], y0=[])
238270

239271
def test_func_callable(self):
240-
with pytest.raises(ValueError):
272+
with pytest.raises(ValueError, match="Argument func must be callable."):
241273
DifferentialEquation(func=1, t0=0, times=self.times, n_states=1, n_theta=1)
242274

243275
def test_number_of_states(self):
244-
with pytest.raises(ValueError):
276+
with pytest.raises(ValueError, match="Argument n_states must be at least 1."):
245277
DifferentialEquation(func=self.system, t0=0, times=self.times, n_states=0, n_theta=1)
246278

247279
def test_number_of_params(self):
248-
with pytest.raises(ValueError):
280+
with pytest.raises(ValueError, match="Argument n_theta must be positive"):
249281
DifferentialEquation(func=self.system, t0=0, times=self.times, n_states=1, n_theta=0)
250282

283+
def test_tensor_shape(self):
284+
with pytest.raises(ValueError, match="returned a 2-dimensional tensor"):
285+
286+
def system_2d_tensor(y, t, p):
287+
s0 = np.exp(-t) - p[0] * y[0]
288+
s1 = np.exp(-t) - p[0] * y[1]
289+
s2 = np.exp(-t) - p[0] * y[2]
290+
s3 = np.exp(-t) - p[0] * y[3]
291+
return at.stack((s0, s1, s2, s3)).reshape((2, 2))
292+
293+
DifferentialEquation(
294+
func=system_2d_tensor, t0=0, times=self.times, n_states=4, n_theta=1
295+
)
296+
297+
def test_list_shape(self):
298+
with pytest.raises(ValueError, match="returned a 2-dimensional tensor"):
299+
300+
def system_2d_list(y, t, p):
301+
s0 = np.exp(-t) - p[0] * y[0]
302+
s1 = np.exp(-t) - p[0] * y[1]
303+
s2 = np.exp(-t) - p[0] * y[2]
304+
s3 = np.exp(-t) - p[0] * y[3]
305+
return [[s0, s1], [s2, s3]]
306+
307+
DifferentialEquation(func=system_2d_list, t0=0, times=self.times, n_states=4, n_theta=1)
308+
309+
def test_unexpected_return_type_set(self):
310+
with pytest.raises(
311+
TypeError, match="Unexpected type, <class 'set'>, returned by ode_func."
312+
):
313+
314+
def system_set(y, t, p):
315+
return {np.exp(-t) - p[0] * y[0]}
316+
317+
DifferentialEquation(func=system_set, t0=0, times=self.times, n_states=4, n_theta=1)
318+
319+
def test_unexpected_return_type_dict(self):
320+
with pytest.raises(
321+
TypeError, match="Unexpected type, <class 'dict'>, returned by ode_func."
322+
):
323+
324+
def system_dict(y, t, p):
325+
return {"rhs": np.exp(-t) - p[0] * y[0]}
326+
327+
DifferentialEquation(func=system_dict, t0=0, times=self.times, n_states=4, n_theta=1)
328+
251329

252330
class TestDiffEqModel:
253331
def test_op_equality(self):

0 commit comments

Comments
 (0)