23
23
24
24
25
25
class LassoRegression :
26
- __slots__ = "alpha" , "params"
26
+ __slots__ = "alpha" , "params" , "tol" , "max_iter"
27
27
28
28
def __init__ (
29
29
self , alpha : float = 1.0 , tol : float = 1e-4 , max_iter : int = 1000
@@ -55,32 +55,32 @@ def _soft_thresholding(rho: float, alpha: float) -> float:
55
55
"""
56
56
return np .sign (rho ) * max (0 , abs (rho ) - alpha )
57
57
58
- def fit (self , X : np .ndarray , y : np .ndarray ) -> None :
58
+ def fit (self , x : np .ndarray , y : np .ndarray ) -> None :
59
59
"""
60
60
Fits the Lasso regression model to the data.
61
61
62
- @param X : the design matrix (features)
62
+ @param x : the design matrix (features)
63
63
@param y: the response vector (target)
64
- @raises ArithmeticError: if X isn't full rank, can't compute coefficients
64
+ @raises ArithmeticError: if x isn't full rank, can't compute coefficients
65
65
"""
66
- n_samples , n_features = X .shape
66
+ n_samples , n_features = x .shape
67
67
self .params = np .zeros (n_features )
68
68
69
69
for _ in range (self .max_iter ):
70
70
params_old = self .params .copy ()
71
71
for j in range (n_features ):
72
72
# Compute the residual
73
- residual = y - X @ self .params + X [:, j ] * self .params [j ]
73
+ residual = y - x @ self .params + x [:, j ] * self .params [j ]
74
74
# Update the j-th coefficient using soft thresholding
75
75
self .params [j ] = self ._soft_thresholding (
76
- X [:, j ].T @ residual / n_samples , self .alpha / n_samples
76
+ x [:, j ].T @ residual / n_samples , self .alpha / n_samples
77
77
)
78
78
79
79
# Check for convergence
80
80
if np .linalg .norm (self .params - params_old , ord = 1 ) < self .tol :
81
81
break
82
82
83
- def predict (self , X : np .ndarray ) -> np .ndarray :
83
+ def predict (self , x : np .ndarray ) -> np .ndarray :
84
84
"""
85
85
Predicts the response values for the given input data.
86
86
@@ -91,7 +91,7 @@ def predict(self, X: np.ndarray) -> np.ndarray:
91
91
if self .params is None :
92
92
raise ArithmeticError ("Predictor hasn't been fit yet" )
93
93
94
- return X @ self .params
94
+ return x @ self .params
95
95
96
96
97
97
def main () -> None :
0 commit comments