Skip to content

Commit 404d4e9

Browse files
committed
Avoid recompiling [d]logp in every step of find_MAP
This also fixes a bug where graphs with NotImplemented gradients were not detected before deciding on the optimization method.
1 parent 3dc7d42 commit 404d4e9

File tree

2 files changed

+5
-14
lines changed

2 files changed

+5
-14
lines changed

pymc/tests/test_starting.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_accuracy_non_normal():
4646
close_to(newstart["x"], mu, select_by_precision(float64=1e-5, float32=1e-4))
4747

4848

49-
@pytest.mark.xfail(reason="find_MAP fails with derivatives")
49+
@pytest.mark.xfail(reason="first call to find_MAP is failing")
5050
def test_find_MAP_discrete():
5151
tol = 2.0 ** -11
5252
alpha = 4
@@ -68,15 +68,12 @@ def test_find_MAP_discrete():
6868
assert map_est2["ss"] == 14
6969

7070

71-
@pytest.mark.xfail(reason="find_MAP fails with derivatives")
7271
def test_find_MAP_no_gradient():
7372
_, model = simple_arbitrary_det()
7473
with model:
7574
find_MAP()
7675

7776

78-
@pytest.mark.skip(reason="test is slow because it's failing")
79-
@pytest.mark.xfail(reason="find_MAP fails with derivatives")
8077
def test_find_MAP():
8178
tol = 2.0 ** -11 # 16 bit machine epsilon, a low bar
8279
data = np.random.randn(100)

pymc/tuning/starting.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,15 @@ def find_MAP(
113113

114114
# TODO: If the mapping is fixed, we can simply create graphs for the
115115
# mapping and avoid all this bijection overhead
116-
def logp_func(x):
117-
return DictToArrayBijection.mapf(model.compile_logp(jacobian=False))(
118-
RaveledVars(x, x0.point_map_info)
119-
)
116+
compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False))
117+
logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))
120118

121119
rvs = [model.values_to_rvs[value] for value in vars]
122120
try:
123121
# This might be needed for calls to `dlogp_func`
124122
# start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)
125-
126-
def dlogp_func(x):
127-
return DictToArrayBijection.mapf(model.compile_dlogp(rvs, jacobian=False))(
128-
RaveledVars(x, x0.point_map_info)
129-
)
130-
123+
compiled_dlogp_func = DictToArrayBijection.mapf(model.compile_dlogp(rvs, jacobian=False))
124+
dlogp_func = lambda x: compiled_dlogp_func(RaveledVars(x, x0.point_map_info))
131125
compute_gradient = True
132126
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
133127
compute_gradient = False

0 commit comments

Comments
 (0)