Skip to content

Commit daf0f8b

Browse files
Update lgbm_classifier.py
1 parent 3e93a6c commit daf0f8b

File tree

1 file changed

+23
-33
lines changed

1 file changed

+23
-33
lines changed

machine_learning/lgbm_classifier.py

+23-33
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,58 @@
1-
# LGBM Classifier Example
1+
# LGBM Classifier Example using Bank Marketing Dataset
22
import numpy as np
33
from matplotlib import pyplot as plt
4-
from sklearn.datasets import load_iris
4+
from sklearn.datasets import fetch_openml
55
from sklearn.metrics import ConfusionMatrixDisplay
66
from sklearn.model_selection import train_test_split
77
from lightgbm import LGBMClassifier
88

99

1010
def data_handling(data: dict) -> tuple:
11+
# Split the dataset into its features ("data") and its target labels ("target").
1112
"""
12-
Splits dataset into features and target labels.
13-
14-
>>> data_handling({'data': '[5.1, 3.5, 1.4, 0.2]', 'target': [0]})
15-
('[5.1, 3.5, 1.4, 0.2]', [0])
16-
>>> data_handling({'data': '[4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]', 'target': [0, 0]})
17-
('[4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]', [0, 0])
13+
>>> data_handling((
14+
... {'data':'[0.12, 0.02, 0.01, 0.25, 0.09]',
15+
... 'target':([1])}))
16+
('[0.12, 0.02, 0.01, 0.25, 0.09]', [1])
1817
"""
19-
return data["data"], data["target"]
18+
return (data["data"], data["target"])
2019

2120

2221
def lgbm_classifier(features: np.ndarray, target: np.ndarray) -> LGBMClassifier:
2322
"""
24-
Trains an LGBM Classifier on the given features and target labels.
25-
26-
>>> lgbm_classifier(np.array([[5.1, 3.6, 1.4, 0.2]]), np.array([0]))
27-
LGBMClassifier()
23+
>>> lgbm_classifier(np.array([[0.12, 0.02, 0.01, 0.25, 0.09]]), np.array([1]))
24+
LGBMClassifier(...)
2825
"""
29-
classifier = LGBMClassifier()
26+
classifier = LGBMClassifier(random_state=42)
3027
classifier.fit(features, target)
3128
return classifier
3229

3330

3431
def main() -> None:
3532
"""
36-
Main function to demonstrate LGBM classification on the Iris dataset.
37-
38-
URL for LightGBM documentation:
33+
URL for the LightGBM documentation:
3934
https://lightgbm.readthedocs.io/en/latest/
35+
The Bank Marketing dataset is used to demonstrate the algorithm.
4036
"""
41-
# Load the Iris dataset
42-
iris = load_iris()
43-
features, targets = data_handling(iris)
44-
45-
# Split the dataset into training and testing sets
37+
# Load Bank Marketing dataset
38+
bank_data = fetch_openml(name='bank-marketing', version=1, as_frame=False)
39+
data, target = data_handling(bank_data)
4640
x_train, x_test, y_train, y_test = train_test_split(
47-
features, targets, test_size=0.25, random_state=42
41+
data, target, test_size=0.25, random_state=1
4842
)
43+
# Create an LGBM Classifier from the training data
44+
lgbm_classifier_model = lgbm_classifier(x_train, y_train)
4945

50-
# Class names for display
51-
names = iris["target_names"]
52-
53-
# Train the LGBM classifier
54-
lgbm_clf = lgbm_classifier(x_train, y_train)
55-
56-
# Display the confusion matrix for the classifier
46+
# Display the confusion matrix of the classifier
5747
ConfusionMatrixDisplay.from_estimator(
58-
lgbm_clf,
48+
lgbm_classifier_model,
5949
x_test,
6050
y_test,
61-
display_labels=names,
51+
display_labels=['No', 'Yes'],
6252
cmap="Blues",
6353
normalize="true",
6454
)
65-
plt.title("Normalized Confusion Matrix - IRIS Dataset")
55+
plt.title("Normalized Confusion Matrix - Bank Marketing Dataset")
6656
plt.show()
6757

6858

0 commit comments

Comments
 (0)