Skip to content

Commit 0a231c5

Browse files
authored
Created lasso_regression.py
1 parent 1fdb7a2 commit 0a231c5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

machine_learning/lasso_regression.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
class LassoRegression:
26-
__slots__ = "alpha", "params"
26+
__slots__ = "alpha", "params", "tol", "max_iter"
2727

2828
def __init__(
2929
self, alpha: float = 1.0, tol: float = 1e-4, max_iter: int = 1000
@@ -55,32 +55,32 @@ def _soft_thresholding(rho: float, alpha: float) -> float:
5555
"""
5656
return np.sign(rho) * max(0, abs(rho) - alpha)
5757

58-
def fit(self, X: np.ndarray, y: np.ndarray) -> None:
58+
def fit(self, x: np.ndarray, y: np.ndarray) -> None:
5959
"""
6060
Fits the Lasso regression model to the data.
6161
62-
@param X: the design matrix (features)
62+
@param x: the design matrix (features)
6363
@param y: the response vector (target)
64-
@raises ArithmeticError: if X isn't full rank, can't compute coefficients
64+
@raises ArithmeticError: if x isn't full rank, can't compute coefficients
6565
"""
66-
n_samples, n_features = X.shape
66+
n_samples, n_features = x.shape
6767
self.params = np.zeros(n_features)
6868

6969
for _ in range(self.max_iter):
7070
params_old = self.params.copy()
7171
for j in range(n_features):
7272
# Compute the residual
73-
residual = y - X @ self.params + X[:, j] * self.params[j]
73+
residual = y - x @ self.params + x[:, j] * self.params[j]
7474
# Update the j-th coefficient using soft thresholding
7575
self.params[j] = self._soft_thresholding(
76-
X[:, j].T @ residual / n_samples, self.alpha / n_samples
76+
x[:, j].T @ residual / n_samples, self.alpha / n_samples
7777
)
7878

7979
# Check for convergence
8080
if np.linalg.norm(self.params - params_old, ord=1) < self.tol:
8181
break
8282

83-
def predict(self, X: np.ndarray) -> np.ndarray:
83+
def predict(self, x: np.ndarray) -> np.ndarray:
8484
"""
8585
Predicts the response values for the given input data.
8686
@@ -91,7 +91,7 @@ def predict(self, X: np.ndarray) -> np.ndarray:
9191
if self.params is None:
9292
raise ArithmeticError("Predictor hasn't been fit yet")
9393

94-
return X @ self.params
94+
return x @ self.params
9595

9696

9797
def main() -> None:

0 commit comments

Comments
 (0)