Skip to content

Commit f0ceb04

Browse files
Update lgbm_regressor.py
1 parent e3197a2 commit f0ceb04

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

machine_learning/lgbm_regressor.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# LGBM Regressor Example using Bank Marketing Dataset
22
import numpy as np
3+
from lightgbm import LGBMRegressor
34
from sklearn.datasets import fetch_openml
45
from sklearn.metrics import mean_absolute_error, mean_squared_error
56
from sklearn.model_selection import train_test_split
6-
from lightgbm import LGBMRegressor
77

88

99
def data_handling(data: dict) -> tuple:
@@ -17,9 +17,8 @@ def data_handling(data: dict) -> tuple:
1717
return (data["data"], data["target"])
1818

1919

20-
def lgbm_regressor(
21-
features: np.ndarray, target: np.ndarray, test_features: np.ndarray
22-
) -> np.ndarray:
20+
def lgbm_regressor(features: np.ndarray, target: np.ndarray,
21+
test_features: np.ndarray) -> np.ndarray:
2322
"""
2423
>>> lgbm_regressor(np.array([[0.12, 0.02, 0.01, 0.25, 0.09]]),
2524
... np.array([1]), np.array([[0.11, 0.03, 0.02, 0.28, 0.08]]))
@@ -40,7 +39,7 @@ def main() -> None:
4039
Bank Marketing Dataset is used to demonstrate the algorithm.
4140
"""
4241
# Load Bank Marketing dataset
43-
bank_data = fetch_openml(name="bank-marketing", version=1, as_frame=False)
42+
bank_data = fetch_openml(name='bank-marketing', version=1, as_frame=False)
4443
data, target = data_handling(bank_data)
4544
x_train, x_test, y_train, y_test = train_test_split(
4645
data, target, test_size=0.25, random_state=1

0 commit comments

Comments
 (0)