Skip to content

Commit 42c0b2a

Browse files
aphc14fonnesbeck
authored andcommitted
Fix LBFGS iteration conditions and status handling (#461)
- Added entry_condition_met method that combines the logic for determining if LBFGS iterations should be stored. - Added LBFGS statuses: LOW_UPDATE_MASK_RATIO to inform user when majority of LBFGS iters aren't being used. - Added LBFGS statuses: INIT_FAILED_LOW_UPDATE_MASK to inform user when LBFGS failed to initialised due to failing to meet the update conditions. - Added status messages for LOW_UPDATE_MASK_RATIO and INIT_FAILED_LOW_UPDATE_MASK. - Renamed LBFGSStatus.DIVERGED to LBFGSStatus.NON_FINITE. - Implemented a test for unstable LBFGS update mask scenarios, adding robustness against rejected iterations.
1 parent 5d0908c commit 42c0b2a

File tree

3 files changed

+138
-20
lines changed

3 files changed

+138
-20
lines changed

pymc_extras/inference/pathfinder/lbfgs.py

+44-11
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __post_init__(self) -> None:
5252
self.count = 0
5353

5454
value, grad = self.value_grad_fn(self.x0)
55-
if np.all(np.isfinite(grad)) and np.isfinite(value):
55+
if self.entry_condition_met(self.x0, value, grad):
5656
self.add_entry(self.x0, grad)
5757

5858
def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
@@ -75,18 +75,40 @@ def get_history(self) -> LBFGSHistory:
7575
x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
7676
)
7777

78+
def entry_condition_met(self, x, value, grad) -> bool:
79+
"""Checks if the LBFGS iteration should continue."""
80+
81+
if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1):
82+
if self.count == 0:
83+
return True
84+
else:
85+
s = x - self.x_history[self.count - 1]
86+
z = grad - self.g_history[self.count - 1]
87+
sz = (s * z).sum(axis=-1)
88+
epsilon = 1e-8
89+
update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
90+
91+
if update_mask:
92+
return True
93+
else:
94+
return False
95+
else:
96+
return False
97+
7898
def __call__(self, x: NDArray[np.float64]) -> None:
7999
value, grad = self.value_grad_fn(x)
80-
if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
100+
if self.entry_condition_met(x, value, grad):
81101
self.add_entry(x, grad)
82102

83103

84104
class LBFGSStatus(Enum):
85105
CONVERGED = auto()
86106
MAX_ITER_REACHED = auto()
87-
DIVERGED = auto()
107+
NON_FINITE = auto()
108+
LOW_UPDATE_MASK_RATIO = auto()
88109
# Statuses that lead to Exceptions:
89110
INIT_FAILED = auto()
111+
INIT_FAILED_LOW_UPDATE_MASK = auto()
90112
LBFGS_FAILED = auto()
91113

92114

@@ -101,8 +123,8 @@ def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED)
101123
class LBFGSInitFailed(LBFGSException):
102124
DEFAULT_MESSAGE = "LBFGS failed to initialise."
103125

104-
def __init__(self, message=None):
105-
super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
126+
def __init__(self, status: LBFGSStatus, message=None):
127+
super().__init__(message or self.DEFAULT_MESSAGE, status)
106128

107129

108130
class LBFGS:
@@ -177,13 +199,24 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
177199
history = history_manager.get_history()
178200

179201
# warnings and suggestions for LBFGSStatus are displayed at the end
180-
if result.status == 1:
181-
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
182-
elif (result.status == 2) or (history.count <= 1):
183-
if result.nit <= 1:
202+
# threshold determining if the number of update mask is low compared to iterations
203+
low_update_threshold = 3
204+
205+
logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}")
206+
207+
if history.count <= 1: # triggers LBFGSInitFailed
208+
if result.nit < low_update_threshold:
184209
lbfgs_status = LBFGSStatus.INIT_FAILED
185-
elif result.fun == np.inf:
186-
lbfgs_status = LBFGSStatus.DIVERGED
210+
else:
211+
lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK
212+
elif result.status == 1:
213+
# (result.nit > maxiter) or (result.nit > maxls)
214+
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
215+
elif result.status == 2:
216+
# precision loss resulting to inf or nan
217+
lbfgs_status = LBFGSStatus.NON_FINITE
218+
elif history.count < low_update_threshold * result.nit:
219+
lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO
187220
else:
188221
lbfgs_status = LBFGSStatus.CONVERGED
189222

pymc_extras/inference/pathfinder/pathfinder.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,8 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
957957
x0 = x_base + jitter_value
958958
x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)
959959

960-
if lbfgs_status == LBFGSStatus.INIT_FAILED:
961-
raise LBFGSInitFailed()
960+
if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
961+
raise LBFGSInitFailed(lbfgs_status)
962962
elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
963963
raise LBFGSException()
964964

@@ -1396,15 +1396,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
13961396
warnings = []
13971397

13981398
lbfgs_status_message = {
1399-
LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
1400-
LBFGSStatus.INIT_FAILED: "LBFGS failed to initialise. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
1401-
LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1399+
LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
1400+
LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
1401+
LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1402+
LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
1403+
LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
14021404
}
14031405

14041406
path_status_message = {
1405-
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
1406-
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
1407-
PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1407+
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
1408+
PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
14081409
}
14091410

14101411
for lbfgs_status in mpr.lbfgs_status:
@@ -1626,7 +1627,7 @@ def fit_pathfinder(
16261627
maxiter: int = 1000, # L^max
16271628
ftol: float = 1e-5,
16281629
gtol: float = 1e-8,
1629-
maxls=1000,
1630+
maxls: int = 1000,
16301631
num_elbo_draws: int = 10, # K
16311632
jitter: float = 2.0,
16321633
epsilon: float = 1e-8,

tests/test_pathfinder.py

+84
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import re
1516
import sys
1617

1718
import numpy as np
1819
import pymc as pm
20+
import pytensor.tensor as pt
1921
import pytest
2022

2123
import pymc_extras as pmx
@@ -50,6 +52,88 @@ def reference_idata():
5052
return idata
5153

5254

55+
def unstable_lbfgs_update_mask_model() -> pm.Model:
56+
# data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
57+
# this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1).
58+
# this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion.
59+
60+
# fmt: off
61+
inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
62+
63+
res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
64+
# fmt: on
65+
66+
n_ordered = res.shape[1]
67+
coords = {
68+
"obs": np.arange(len(inp)),
69+
"inp": np.arange(max(inp) + 1),
70+
"outp": np.arange(res.shape[1]),
71+
}
72+
with pm.Model(coords=coords) as mdl:
73+
mu = pm.Normal("intercept", sigma=3.5)[None]
74+
75+
offset = pm.Normal(
76+
"offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
77+
)
78+
79+
scale = 3.5 * pm.HalfStudentT("scale", nu=5)
80+
mu += (scale * offset)[inp]
81+
82+
phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
83+
phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
84+
s_mu = pm.Normal(
85+
"stereotype_intercept",
86+
size=n_ordered,
87+
transform=pm.distributions.transforms.ZeroSumTransform([-1]),
88+
)
89+
fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
90+
91+
pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
92+
93+
return mdl
94+
95+
96+
@pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
97+
def test_unstable_lbfgs_update_mask(capsys, jitter):
98+
model = unstable_lbfgs_update_mask_model()
99+
100+
if jitter < 1000:
101+
with model:
102+
idata = pmx.fit(
103+
method="pathfinder",
104+
jitter=jitter,
105+
random_seed=4,
106+
)
107+
out, err = capsys.readouterr()
108+
status_pattern = [
109+
r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
110+
r"LOW_UPDATE_MASK_RATIO\s+\d+",
111+
r"LBFGS_FAILED\s+\d+",
112+
r"SUCCESS\s+\d+",
113+
]
114+
for pattern in status_pattern:
115+
assert re.search(pattern, out) is not None
116+
117+
else:
118+
with pytest.raises(ValueError, match="All paths failed"):
119+
with model:
120+
idata = pmx.fit(
121+
method="pathfinder",
122+
jitter=1000,
123+
random_seed=2,
124+
num_paths=4,
125+
)
126+
out, err = capsys.readouterr()
127+
128+
status_pattern = [
129+
r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
130+
r"LOW_UPDATE_MASK_RATIO\s+2",
131+
r"LBFGS_FAILED\s+4",
132+
]
133+
for pattern in status_pattern:
134+
assert re.search(pattern, out) is not None
135+
136+
53137
@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
54138
@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
55139
def test_pathfinder(inference_backend, reference_idata):

0 commit comments

Comments
 (0)