Skip to content

Commit e486852

Browse files
committed
added doctests
1 parent 12c29ce commit e486852

File tree

1 file changed

+79
-23
lines changed

1 file changed

+79
-23
lines changed

linear_programming/interior_point_method.py

+79-23
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
"""
2-
Python implementation of the Primal-Dual Interior-Point Method for solving linear
3-
programs with - `>=`, `<=`, and `=` constraints and - each variable `x1, x2, ... >= 0`.
4-
5-
Resources:
6-
https://en.wikipedia.org/wiki/Interior-point_method
7-
"""
8-
91
import numpy as np
102

113

@@ -15,11 +7,14 @@ class InteriorPointMethod:
157
168
Attributes:
179
objective_coefficients (np.ndarray): Coefficient matrix for the objective
18-
function.
10+
function.
1911
constraint_matrix (np.ndarray): Constraint matrix.
2012
constraint_bounds (np.ndarray): Constraint bounds.
2113
tol (float): Tolerance for stopping criterion.
2214
max_iter (int): Maximum number of iterations.
15+
16+
Methods:
17+
solve: Solve the linear programming problem.
2318
"""
2419

2520
def __init__(
@@ -40,18 +35,53 @@ def __init__(
4035
raise ValueError("Invalid input for the linear programming problem.")
4136

4237
def _is_valid_input(self) -> bool:
43-
"""Validate the input for the linear programming problem."""
38+
"""
39+
Validate the input for the linear programming problem.
40+
41+
Returns:
42+
bool: True if input is valid, False otherwise.
43+
44+
>>> objective_coefficients = np.array([1, 2])
45+
>>> constraint_matrix = np.array([[1, 1], [1, -1]])
46+
>>> constraint_bounds = np.array([2, 0])
47+
>>> ipm = InteriorPointMethod(objective_coefficients, constraint_matrix,
48+
constraint_bounds)
49+
>>> ipm._is_valid_input()
50+
True
51+
>>> constraint_bounds = np.array([2, 0, 1])
52+
>>> ipm = InteriorPointMethod(objective_coefficients, constraint_matrix,
53+
constraint_bounds)
54+
>>> ipm._is_valid_input()
55+
False
56+
"""
4457
return (
4558
self.constraint_matrix.shape[0] == self.constraint_bounds.shape[0]
4659
and self.constraint_matrix.shape[1] == self.objective_coefficients.shape[0]
4760
)
4861

4962
def _convert_to_standard_form(self) -> tuple[np.ndarray, np.ndarray]:
50-
"""Convert constraints to standard form by adding slack and surplus
51-
variables."""
63+
"""
64+
Convert constraints to standard form by adding slack variables.
65+
66+
Returns:
67+
tuple: A tuple of the standard form constraint matrix and objective
68+
coefficients.
69+
70+
>>> objective_coefficients = np.array([1, 2])
71+
>>> constraint_matrix = np.array([[1, 1], [1, -1]])
72+
>>> constraint_bounds = np.array([2, 0])
73+
>>> ipm = InteriorPointMethod(objective_coefficients, constraint_matrix,
74+
constraint_bounds)
75+
>>> a_standard, c_standard = ipm._convert_to_standard_form()
76+
>>> a_standard
77+
array([[ 1., 1., 1., 0.],
78+
[ 1., -1., 0., 1.]])
79+
>>> c_standard
80+
array([1., 2., 0., 0.])
81+
"""
5282
(m, n) = self.constraint_matrix.shape
53-
slack_surplus = np.eye(m)
54-
a_standard = np.hstack([self.constraint_matrix, slack_surplus])
83+
slack = np.eye(m)
84+
a_standard = np.hstack([self.constraint_matrix, slack])
5585
c_standard = np.hstack([self.objective_coefficients, np.zeros(m)])
5686
return a_standard, c_standard
5787

@@ -60,15 +90,24 @@ def solve(self) -> tuple[np.ndarray, float]:
6090
Solve problem with Primal-Dual Interior-Point Method.
6191
6292
Returns:
63-
tuple: A tuple with optimal soln and the optimal value.
93+
tuple: A tuple with optimal solution and the optimal value.
94+
95+
>>> objective_coefficients = np.array([1, 2])
96+
>>> constraint_matrix = np.array([[1, 1], [1, -1]])
97+
>>> constraint_bounds = np.array([2, 0])
98+
>>> ipm = InteriorPointMethod(objective_coefficients, constraint_matrix,
99+
constraint_bounds)
100+
>>> solution, value = ipm.solve()
101+
>>> np.isclose(value, np.dot(objective_coefficients, solution))
102+
True
64103
"""
65104
a, c = self._convert_to_standard_form()
66105
m, n = a.shape
67106
x = np.ones(n)
68107
s = np.ones(n)
69108
y = np.ones(m)
70109

71-
for _ in range(self.max_iter):
110+
for iteration in range(self.max_iter):
72111
x_diag = np.diag(x)
73112
s_diag = np.diag(s)
74113

@@ -99,23 +138,35 @@ def solve(self) -> tuple[np.ndarray, float]:
99138
r = np.hstack([-r2, -r1, -r3 + mu * np.ones(n)])
100139

101140
# Solve the KKT system
102-
delta = np.linalg.solve(kkt, r)
141+
try:
142+
delta = np.linalg.solve(kkt, r)
143+
except np.linalg.LinAlgError:
144+
print("KKT matrix is singular, switching to least squares solution")
145+
delta = np.linalg.lstsq(kkt, r, rcond=None)[0]
103146

104147
dx = delta[:n]
105148
dy = delta[n : n + m]
106149
ds = delta[n + m :]
107150

108151
# Step size
109-
alpha_primal = min(1, 0.99 * min(-x[dx < 0] / dx[dx < 0], default=1))
110-
alpha_dual = min(1, 0.99 * min(-s[ds < 0] / ds[ds < 0], default=1))
152+
alpha_primal = min(
153+
1, 0.99 * min([-x[i] / dx[i] for i in range(n) if dx[i] < 0], default=1)
154+
)
155+
alpha_dual = min(
156+
1, 0.99 * min([-s[i] / ds[i] for i in range(n) if ds[i] < 0], default=1)
157+
)
111158

112159
# Update variables
113160
x += alpha_primal * dx
114161
y += alpha_dual * dy
115162
s += alpha_dual * ds
116163

117-
optimal_value = np.dot(c, x)
118-
return x, optimal_value
164+
print(f"Iteration {iteration}: x = {x}, s = {s}, y = {y}")
165+
166+
# Extract the solution (remove slack variables)
167+
original_vars = x[: self.objective_coefficients.shape[0]]
168+
optimal_value = np.dot(self.objective_coefficients, original_vars)
169+
return original_vars, optimal_value
119170

120171

121172
if __name__ == "__main__":
@@ -127,5 +178,10 @@ def solve(self) -> tuple[np.ndarray, float]:
127178
objective_coefficients, constraint_matrix, constraint_bounds
128179
)
129180
solution, value = ipm.solve()
130-
print("Optimal solution:", solution)
131-
print("Optimal value:", value)
181+
print("Final optimal solution:", solution)
182+
print("Final optimal value:", value)
183+
184+
# Verify the solution
185+
print("\nVerification:")
186+
print("Objective value calculation matches final optimal value:")
187+
print(np.isclose(value, np.dot(objective_coefficients, solution)))

0 commit comments

Comments
 (0)