Skip to content

Commit 9c68059

Browse files
committed
Split up test_ode.py
1 parent 2cd6a25 commit 9c68059

File tree

4 files changed

+60
-43
lines changed

4 files changed

+60
-43
lines changed

.github/workflows/tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ jobs:
8080
pymc/tests/gp/test_util.py
8181
pymc/tests/test_model.py
8282
pymc/tests/test_model_graph.py
83-
pymc/tests/test_ode.py
83+
pymc/tests/ode/test_ode.py
84+
pymc/tests/ode/test_utils.py
8485
pymc/tests/test_profile.py
8586
pymc/tests/test_quadpotential.py
8687
@@ -151,7 +152,7 @@ jobs:
151152
test-subset:
152153
- pymc/tests/test_variational_inference.py pymc/tests/test_initial_point.py
153154
- pymc/tests/test_pickling.py pymc/tests/test_profile.py pymc/tests/test_step.py
154-
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/test_ode.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py
155+
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py
155156
- pymc/tests/test_sampling.py pymc/tests/test_posteriors.py
156157

157158
fail-fast: false
@@ -364,7 +365,7 @@ jobs:
364365
floatx: [float32]
365366
python-version: ["3.10"]
366367
test-subset:
367-
- pymc/tests/test_sampling.py pymc/tests/test_ode.py
368+
- pymc/tests/test_sampling.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py
368369
fail-fast: false
369370
runs-on: ${{ matrix.os }}
370371
env:

pymc/tests/ode/__init__.py

Whitespace-only changes.

pymc/tests/test_ode.py renamed to pymc/tests/ode/test_ode.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,54 +18,14 @@
1818
import numpy as np
1919
import pytest
2020

21-
from scipy.integrate import odeint
2221
from scipy.stats import norm
2322

2423
import pymc as pm
2524

2625
from pymc.ode import DifferentialEquation
27-
from pymc.ode.utils import augment_system
2826
from pymc.tests.helpers import fast_unstable_sampling_mode
2927

3028

31-
def test_gradients():
32-
"""Tests the computation of the sensitivities from the Aesara computation graph"""
33-
34-
# ODE system for which to compute gradients
35-
def ode_func(y, t, p):
36-
return np.exp(-t) - p[0] * y[0]
37-
38-
# Computation of graidients with Aesara
39-
augmented_ode_func = augment_system(ode_func, 1, 1 + 1)
40-
41-
# This is the new system, ODE + Sensitivities, which will be integrated
42-
def augmented_system(Y, t, p):
43-
dydt, ddt_dydp = augmented_ode_func(Y[:1], t, p, Y[1:])
44-
derivatives = np.concatenate([dydt, ddt_dydp])
45-
return derivatives
46-
47-
# Create real sensitivities
48-
y0 = 0.0
49-
t = np.arange(0, 12, 0.25).reshape(-1, 1)
50-
a = 0.472
51-
p = np.array([y0, a])
52-
53-
# Derivatives of the analytic solution with respect to y0 and alpha
54-
# Treat y0 like a parameter and solve analytically. Then differentiate.
55-
# I used CAS to get these derivatives
56-
y0_sensitivity = np.exp(-a * t)
57-
a_sensitivity = (
58-
-(np.exp(t * (a - 1)) - 1 + (a - 1) * (y0 * a - y0 - 1) * t) * np.exp(-a * t) / (a - 1) ** 2
59-
)
60-
61-
sensitivity = np.c_[y0_sensitivity, a_sensitivity]
62-
63-
integrated_solutions = odeint(func=augmented_system, y0=[y0, 1, 0], t=t.ravel(), args=(p,))
64-
simulated_sensitivity = integrated_solutions[:, 1:]
65-
66-
np.testing.assert_allclose(sensitivity, simulated_sensitivity, rtol=1e-5)
67-
68-
6929
def test_simulate():
7030
"""Tests the integration in DifferentialEquation"""
7131

pymc/tests/ode/test_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import scipy.integrate as ode
17+
18+
from pymc.ode.utils import augment_system
19+
20+
21+
def test_gradients():
22+
"""Tests the computation of the sensitivities from the Aesara computation graph"""
23+
24+
# ODE system for which to compute gradients
25+
def ode_func(y, t, p):
26+
return np.exp(-t) - p[0] * y[0]
27+
28+
# Computation of graidients with Aesara
29+
augmented_ode_func = augment_system(ode_func, 1, 1 + 1)
30+
31+
# This is the new system, ODE + Sensitivities, which will be integrated
32+
def augmented_system(Y, t, p):
33+
dydt, ddt_dydp = augmented_ode_func(Y[:1], t, p, Y[1:])
34+
derivatives = np.concatenate([dydt, ddt_dydp])
35+
return derivatives
36+
37+
# Create real sensitivities
38+
y0 = 0.0
39+
t = np.arange(0, 12, 0.25).reshape(-1, 1)
40+
a = 0.472
41+
p = np.array([y0, a])
42+
43+
# Derivatives of the analytic solution with respect to y0 and alpha
44+
# Treat y0 like a parameter and solve analytically. Then differentiate.
45+
# I used CAS to get these derivatives
46+
y0_sensitivity = np.exp(-a * t)
47+
a_sensitivity = (
48+
-(np.exp(t * (a - 1)) - 1 + (a - 1) * (y0 * a - y0 - 1) * t) * np.exp(-a * t) / (a - 1) ** 2
49+
)
50+
51+
sensitivity = np.c_[y0_sensitivity, a_sensitivity]
52+
53+
integrated_solutions = ode.odeint(func=augmented_system, y0=[y0, 1, 0], t=t.ravel(), args=(p,))
54+
simulated_sensitivity = integrated_solutions[:, 1:]
55+
56+
np.testing.assert_allclose(sensitivity, simulated_sensitivity, rtol=1e-5)

0 commit comments

Comments
 (0)