Skip to content

Commit 8e97492

Browse files
authored
Created lasso_regression.py
1 parent 03a4251 commit 8e97492

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

machine_learning/lasso_regression.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Lasso regression is a type of linear regression that adds a regularization term to the
3+
ordinary least squares (OLS) objective function. This regularization term is the
4+
L1 norm of the coefficients, which encourages sparsity in the model parameters. The
5+
objective function for Lasso regression is given by:
6+
7+
minimize ||y - Xβ||² + λ||β||₁
8+
9+
where:
10+
- y is the response vector,
11+
- X is the design matrix,
12+
- β is the vector of coefficients,
13+
- λ (lambda) is the regularization parameter controlling the strength of the penalty.
14+
15+
Lasso regression can be solved using coordinate descent or other optimization techniques.
16+
17+
References:
18+
- https://en.wikipedia.org/wiki/Lasso_(statistics)
19+
- https://en.wikipedia.org/wiki/Regularization_(mathematics)
20+
"""
21+
22+
import numpy as np
23+
24+
25+
class LassoRegression:
26+
__slots__ = "alpha", "params"
27+
28+
def __init__(self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000) -> None:
29+
"""
30+
Initializes the Lasso regression model.
31+
32+
@param alpha: regularization strength; must be a positive float
33+
@param tol: tolerance for stopping criteria
34+
@param max_iter: maximum number of iterations
35+
@raises ValueError: if alpha is not positive
36+
"""
37+
if alpha <= 0:
38+
raise ValueError("Regularization strength must be positive")
39+
40+
self.alpha = alpha
41+
self.tol = tol
42+
self.max_iter = max_iter
43+
self.params = None
44+
45+
@staticmethod
46+
def _soft_thresholding(rho: float, alpha: float) -> float:
47+
"""
48+
Applies the soft thresholding operator.
49+
50+
@param rho: the value to be thresholded
51+
@param alpha: the regularization parameter
52+
@returns: the thresholded value
53+
"""
54+
return np.sign(rho) * max(0, abs(rho) - alpha)
55+
56+
def fit(self, X: np.ndarray, y: np.ndarray) -> None:
57+
"""
58+
Fits the Lasso regression model to the data.
59+
60+
@param X: the design matrix (features)
61+
@param y: the response vector (target)
62+
@raises ArithmeticError: if X isn't full rank, can't compute coefficients
63+
"""
64+
n_samples, n_features = X.shape
65+
self.params = np.zeros(n_features)
66+
67+
for _ in range(self.max_iter):
68+
params_old = self.params.copy()
69+
for j in range(n_features):
70+
# Compute the residual
71+
residual = y - X @ self.params + X[:, j] * self.params[j]
72+
# Update the j-th coefficient using soft thresholding
73+
self.params[j] = self._soft_thresholding(X[:, j].T @ residual / n_samples, self.alpha / n_samples)
74+
75+
# Check for convergence
76+
if np.linalg.norm(self.params - params_old, ord=1) < self.tol:
77+
break
78+
79+
def predict(self, X: np.ndarray) -> np.ndarray:
80+
"""
81+
Predicts the response values for the given input data.
82+
83+
@param X: the design matrix (features) for prediction
84+
@returns: the predicted response values
85+
@raises ArithmeticError: if this function is called before the model parameters are fit
86+
"""
87+
if self.params is None:
88+
raise ArithmeticError("Predictor hasn't been fit yet")
89+
90+
return X @ self.params
91+
92+
93+
def main() -> None:
94+
"""
95+
Fit a Lasso regression model to predict a target variable using synthetic data.
96+
"""
97+
import matplotlib.pyplot as plt
98+
from sklearn.datasets import make_regression
99+
100+
# Create synthetic data
101+
X, y = make_regression(n_samples=100, n_features=10, noise=0.1)
102+
103+
lasso_reg = LassoRegression(alpha=0.1)
104+
lasso_reg.fit(X, y)
105+
106+
predictions = lasso_reg.predict(X)
107+
108+
plt.scatter(y, predictions, alpha=0.5)
109+
plt.xlabel("True Values")
110+
plt.ylabel("Predicted Values")
111+
plt.title("Lasso Regression Predictions")
112+
plt.plot([y.min(), y.max()], [y.min(), y.max()], color='red', linewidth=2)
113+
plt.show()
114+
115+
116+
if __name__ == "__main__":
117+
import doctest
118+
119+
doctest.testmod()
120+
121+
main()

0 commit comments

Comments
 (0)