1
1
"""
2
2
Lasso regression is a type of linear regression that adds a regularization term to the
3
- ordinary least squares (OLS) objective function. This regularization term is the
4
- L1 norm of the coefficients, which encourages sparsity in the model parameters. The
3
+ ordinary least squares (OLS) objective function. This regularization term is the
4
+ L1 norm of the coefficients, which encourages sparsity in the model parameters. The
5
5
objective function for Lasso regression is given by:
6
6
7
7
minimize ||y - Xβ||² + λ||β||₁
12
12
- β is the vector of coefficients,
13
13
- λ (lambda) is the regularization parameter controlling the strength of the penalty.
14
14
15
- Lasso regression can be solved using coordinate descent or other optimization techniques.
15
+ Lasso regression can be solved using coordinate descent or other optimization techniques.
16
16
17
17
References:
18
18
- https://en.wikipedia.org/wiki/Lasso_(statistics)
25
25
class LassoRegression :
26
26
__slots__ = "alpha" , "params"
27
27
28
- def __init__ (self , alpha : float = 1.0 , tol : float = 1e-4 , max_iter : int = 1000 ) -> None :
28
+ def __init__ (
29
+ self , alpha : float = 1.0 , tol : float = 1e-4 , max_iter : int = 1000
30
+ ) -> None :
29
31
"""
30
32
Initializes the Lasso regression model.
31
33
@@ -36,7 +38,7 @@ def __init__(self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000)
36
38
"""
37
39
if alpha <= 0 :
38
40
raise ValueError ("Regularization strength must be positive" )
39
-
41
+
40
42
self .alpha = alpha
41
43
self .tol = tol
42
44
self .max_iter = max_iter
@@ -70,7 +72,9 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> None:
70
72
# Compute the residual
71
73
residual = y - X @ self .params + X [:, j ] * self .params [j ]
72
74
# Update the j-th coefficient using soft thresholding
73
- self .params [j ] = self ._soft_thresholding (X [:, j ].T @ residual / n_samples , self .alpha / n_samples )
75
+ self .params [j ] = self ._soft_thresholding (
76
+ X [:, j ].T @ residual / n_samples , self .alpha / n_samples
77
+ )
74
78
75
79
# Check for convergence
76
80
if np .linalg .norm (self .params - params_old , ord = 1 ) < self .tol :
@@ -109,7 +113,7 @@ def main() -> None:
109
113
plt .xlabel ("True Values" )
110
114
plt .ylabel ("Predicted Values" )
111
115
plt .title ("Lasso Regression Predictions" )
112
- plt .plot ([y .min (), y .max ()], [y .min (), y .max ()], color = ' red' , linewidth = 2 )
116
+ plt .plot ([y .min (), y .max ()], [y .min (), y .max ()], color = " red" , linewidth = 2 )
113
117
plt .show ()
114
118
115
119
0 commit comments