Skip to content

Commit e3086a8

Browse files
Create lgbm_regressor.py
1 parent 3d46dd0 commit e3086a8

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

machine_learning/lgbm_regressor.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# LGBM Regressor Example using Bank Marketing Dataset
2+
import numpy as np
3+
from sklearn.datasets import fetch_openml
4+
from sklearn.metrics import mean_absolute_error, mean_squared_error
5+
from sklearn.model_selection import train_test_split
6+
from lightgbm import LGBMRegressor
7+
8+
9+
def data_handling(data: dict) -> tuple:
10+
# Split dataset into features and target. Data is features.
11+
"""
12+
>>> data_handling((
13+
... {'data':'[0.12, 0.02, 0.01, 0.25, 0.09]',
14+
... 'target':([1])}))
15+
('[0.12, 0.02, 0.01, 0.25, 0.09]', [1])
16+
"""
17+
return (data["data"], data["target"])
18+
19+
20+
def lgbm_regressor(
21+
features: np.ndarray, target: np.ndarray, test_features: np.ndarray
22+
) -> np.ndarray:
23+
"""
24+
>>> lgbm_regressor(np.array([[ 0.12, 0.02, 0.01, 0.25, 0.09]]), np.array([1]),
25+
... np.array([[0.11, 0.03, 0.02, 0.28, 0.08]]))
26+
array([[0.98]], dtype=float32)
27+
"""
28+
lgbm = LGBMRegressor(
29+
verbosity=0, random_state=42
30+
)
31+
lgbm.fit(features, target)
32+
# Predict target for test data
33+
predictions = lgbm.predict(test_features)
34+
predictions = predictions.reshape(len(predictions), 1)
35+
return predictions
36+
37+
38+
def main() -> None:
39+
"""
40+
The URL for this algorithm:
41+
https://lightgbm.readthedocs.io/en/latest/
42+
Bank Marketing Dataset is used to demonstrate the algorithm.
43+
44+
Expected error values:
45+
Mean Absolute Error: 0.2 (approx.)
46+
Mean Square Error: 0.15 (approx.)
47+
"""
48+
# Load Bank Marketing dataset
49+
bank_data = fetch_openml(name='bank-marketing', version=1, as_frame=False)
50+
data, target = data_handling(bank_data)
51+
x_train, x_test, y_train, y_test = train_test_split(
52+
data, target, test_size=0.25, random_state=1
53+
)
54+
predictions = lgbm_regressor(x_train, y_train, x_test)
55+
# Error printing
56+
print(f"Mean Absolute Error: {mean_absolute_error(y_test, predictions)}")
57+
print(f"Mean Square Error: {mean_squared_error(y_test, predictions)}")
58+
59+
60+
if __name__ == "__main__":
61+
import doctest
62+
63+
doctest.testmod(verbose=True)
64+
main()

0 commit comments

Comments
 (0)