Skip to content

Commit 82711a4

Browse files
Create lgbm_classifier.py
Adding LGBM Classifier Script for this repository.
1 parent 9a572de commit 82711a4

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

machine_learning/lgbm_classifier.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# LGBM Classifier Example
2+
import numpy as np
3+
from matplotlib import pyplot as plt
4+
from sklearn.datasets import load_iris
5+
from sklearn.metrics import ConfusionMatrixDisplay
6+
from sklearn.model_selection import train_test_split
7+
from lightgbm import LGBMClassifier
8+
9+
10+
def data_handling(data: dict) -> tuple:
    """
    Pull the feature matrix and the label vector out of a dataset mapping.

    The mapping must provide the keys ``"data"`` and ``"target"``; their values
    are returned unchanged, in that order, as a 2-tuple. A missing key raises
    ``KeyError``.

    >>> data_handling({'data': '[5.1, 3.5, 1.4, 0.2]', 'target': [0]})
    ('[5.1, 3.5, 1.4, 0.2]', [0])
    >>> data_handling({'data': '[4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]', 'target': [0, 0]})
    ('[4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]', [0, 0])
    """
    features = data["data"]
    labels = data["target"]
    return features, labels
20+
21+
22+
def lgbm_classifier(features: np.ndarray, target: np.ndarray) -> LGBMClassifier:
    """
    Fit a default-configured LGBMClassifier and return the fitted estimator.

    ``features`` holds one row per sample and ``target`` the matching class
    label for each row. No hyper-parameters are overridden.

    >>> lgbm_classifier(np.array([[5.1, 3.6, 1.4, 0.2]]), np.array([0]))
    LGBMClassifier()
    """
    # fit() returns the estimator itself (scikit-learn convention), so the
    # construct-and-train steps can be chained into a single expression.
    return LGBMClassifier().fit(features, target)
32+
33+
34+
def main() -> None:
    """
    Run an end-to-end LGBM classification demo on the Iris dataset.

    Loads Iris, holds out 25% of the samples for testing, fits the classifier
    on the remaining samples and shows a normalized confusion matrix for the
    held-out samples.

    URL for LightGBM documentation:
    https://lightgbm.readthedocs.io/en/latest/
    """
    dataset = load_iris()
    # Human-readable class names used as the confusion-matrix tick labels.
    class_names = dataset["target_names"]
    samples, labels = data_handling(dataset)

    # A fixed seed keeps the 75/25 train/test split reproducible between runs.
    split = train_test_split(samples, labels, test_size=0.25, random_state=42)
    train_samples, test_samples, train_labels, test_labels = split

    model = lgbm_classifier(train_samples, train_labels)

    # Evaluate on the held-out samples only and plot the result.
    ConfusionMatrixDisplay.from_estimator(
        model,
        test_samples,
        test_labels,
        display_labels=class_names,
        cmap="Blues",
        normalize="true",
    )
    plt.title("Normalized Confusion Matrix - IRIS Dataset")
    plt.show()
67+
68+
69+
if __name__ == "__main__":
    from doctest import testmod

    # Run the embedded doctests first (verbose report), then the Iris demo.
    testmod(verbose=True)
    main()

0 commit comments

Comments
 (0)