Skip to content

Commit fb1f008

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0da881b commit fb1f008

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

linear_programming/interior_point_method.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ class InteriorPointMethod:
2323
max_iter (int): Maximum number of iterations.
2424
"""
2525

26-
def __init__(self, c: np.ndarray, A: np.ndarray, b: np.ndarray, tol: float = 1e-8, max_iter: int = 100) -> None:
26+
def __init__(
27+
self,
28+
c: np.ndarray,
29+
A: np.ndarray,
30+
b: np.ndarray,
31+
tol: float = 1e-8,
32+
max_iter: int = 100,
33+
) -> None:
2734
self.c = c
2835
self.A = A
2936
self.b = b
@@ -35,7 +42,9 @@ def __init__(self, c: np.ndarray, A: np.ndarray, b: np.ndarray, tol: float = 1e-
3542

3643
def _is_valid_input(self) -> bool:
3744
"""Validate the input for the linear programming problem."""
38-
return (self.A.shape[0] == self.b.shape[0]) and (self.A.shape[1] == self.c.shape[0])
45+
return (self.A.shape[0] == self.b.shape[0]) and (
46+
self.A.shape[1] == self.c.shape[0]
47+
)
3948

4049
def _convert_to_standard_form(self):
4150
"""Convert constraints to standard form by adding slack and surplus variables."""
@@ -67,17 +76,23 @@ def solve(self) -> tuple[np.ndarray, float]:
6776
r2 = A.T @ y + s - c
6877
r3 = x * s
6978

70-
if np.linalg.norm(r1) < self.tol and np.linalg.norm(r2) < self.tol and np.linalg.norm(r3) < self.tol:
79+
if (
80+
np.linalg.norm(r1) < self.tol
81+
and np.linalg.norm(r2) < self.tol
82+
and np.linalg.norm(r3) < self.tol
83+
):
7184
break
7285

7386
mu = np.dot(x, s) / n
7487

7588
# Form the KKT matrix
76-
KKT = np.block([
77-
[np.zeros((n, n)), A.T, np.eye(n)],
78-
[A, np.zeros((m, m)), np.zeros((m, n))],
79-
[S, np.zeros((n, m)), X]
80-
])
89+
KKT = np.block(
90+
[
91+
[np.zeros((n, n)), A.T, np.eye(n)],
92+
[A, np.zeros((m, m)), np.zeros((m, n))],
93+
[S, np.zeros((n, m)), X],
94+
]
95+
)
8196

8297
# Right-hand side
8398
r = np.hstack([-r2, -r1, -r3 + mu * np.ones(n)])
@@ -86,8 +101,8 @@ def solve(self) -> tuple[np.ndarray, float]:
86101
delta = np.linalg.solve(KKT, r)
87102

88103
dx = delta[:n]
89-
dy = delta[n:n + m]
90-
ds = delta[n + m:]
104+
dy = delta[n : n + m]
105+
ds = delta[n + m :]
91106

92107
# Step size
93108
alpha_primal = min(1, 0.99 * min(-x[dx < 0] / dx[dx < 0], default=1))

0 commit comments

Comments
 (0)