Skip to content

Commit 1fdb7a2

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8e97492 commit 1fdb7a2

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

machine_learning/lasso_regression.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""
22
Lasso regression is a type of linear regression that adds a regularization term to the
3-
ordinary least squares (OLS) objective function. This regularization term is the
4-
L1 norm of the coefficients, which encourages sparsity in the model parameters. The
3+
ordinary least squares (OLS) objective function. This regularization term is the
4+
L1 norm of the coefficients, which encourages sparsity in the model parameters. The
55
objective function for Lasso regression is given by:
66
77
minimize ||y - Xβ||² + λ||β||₁
@@ -12,7 +12,7 @@
1212
- β is the vector of coefficients,
1313
- λ (lambda) is the regularization parameter controlling the strength of the penalty.
1414
15-
Lasso regression can be solved using coordinate descent or other optimization techniques.
15+
Lasso regression can be solved using coordinate descent or other optimization techniques.
1616
1717
References:
1818
- https://en.wikipedia.org/wiki/Lasso_(statistics)
@@ -25,7 +25,9 @@
2525
class LassoRegression:
2626
__slots__ = "alpha", "params"
2727

28-
def __init__(self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000) -> None:
28+
def __init__(
29+
self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000
30+
) -> None:
2931
"""
3032
Initializes the Lasso regression model.
3133
@@ -36,7 +38,7 @@ def __init__(self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000)
3638
"""
3739
if alpha <= 0:
3840
raise ValueError("Regularization strength must be positive")
39-
41+
4042
self.alpha = alpha
4143
self.tol = tol
4244
self.max_iter = max_iter
@@ -70,7 +72,9 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> None:
7072
# Compute the residual
7173
residual = y - X @ self.params + X[:, j] * self.params[j]
7274
# Update the j-th coefficient using soft thresholding
73-
self.params[j] = self._soft_thresholding(X[:, j].T @ residual / n_samples, self.alpha / n_samples)
75+
self.params[j] = self._soft_thresholding(
76+
X[:, j].T @ residual / n_samples, self.alpha / n_samples
77+
)
7478

7579
# Check for convergence
7680
if np.linalg.norm(self.params - params_old, ord=1) < self.tol:
@@ -109,7 +113,7 @@ def main() -> None:
109113
plt.xlabel("True Values")
110114
plt.ylabel("Predicted Values")
111115
plt.title("Lasso Regression Predictions")
112-
plt.plot([y.min(), y.max()], [y.min(), y.max()], color='red', linewidth=2)
116+
plt.plot([y.min(), y.max()], [y.min(), y.max()], color="red", linewidth=2)
113117
plt.show()
114118

115119

0 commit comments

Comments (0)