Skip to content

Commit ce56390

Browse files
authored
Create catboost_regressor.py
1 parent 2d671df commit ce56390

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# CatBoost Classifier Example
2+
import numpy as np
3+
from matplotlib import pyplot as plt
4+
from sklearn.datasets import load_iris
5+
from sklearn.metrics import ConfusionMatrixDisplay
6+
from sklearn.model_selection import train_test_split
7+
from catboost import CatBoostClassifier
8+
9+
10+
def data_handling(data: dict) -> tuple:
11+
"""
12+
Extracts the features and target values from the provided dataset.
13+
14+
Args:
15+
data (dict): A dictionary containing the dataset's features and targets.
16+
17+
Returns:
18+
tuple: A tuple with features and targets.
19+
20+
Example:
21+
>>> data_handling({'data':'[5.1, 3.5, 1.4, 0.2]', 'target': [0]})
22+
('[5.1, 3.5, 1.4, 0.2]', [0])
23+
"""
24+
return data["data"], data["target"]
25+
26+
27+
def catboost(features: np.ndarray, target: np.ndarray) -> CatBoostClassifier:
28+
"""
29+
Trains a CatBoostClassifier using the provided features and target.
30+
31+
Args:
32+
features (np.ndarray): The input features for training the classifier.
33+
target (np.ndarray): The target labels corresponding to the features.
34+
35+
Returns:
36+
CatBoostClassifier: A trained CatBoost classifier.
37+
38+
Example:
39+
>>> catboost(np.array([[5.1, 3.6, 1.4, 0.2]]), np.array([0]))
40+
CatBoostClassifier(...)
41+
"""
42+
classifier = CatBoostClassifier(verbose=0) # Suppressing verbose output
43+
classifier.fit(features, target)
44+
return classifier
45+
46+
47+
def main() -> None:
48+
"""
49+
Demonstrates the training and evaluation of a CatBoost classifier
50+
on the Iris dataset, displaying a confusion matrix of the results.
51+
52+
The dataset is split into training and testing sets, the model is
53+
trained on the training data, and then evaluated on the test data.
54+
A normalized confusion matrix is displayed.
55+
"""
56+
57+
# Load the Iris dataset
58+
iris = load_iris()
59+
features, targets = data_handling(iris)
60+
x_train, x_test, y_train, y_test = train_test_split(
61+
features, targets, test_size=0.25
62+
)
63+
64+
# Train a CatBoost classifier
65+
catboost_classifier = catboost(x_train, y_train)
66+
67+
# Display the confusion matrix for the test data
68+
ConfusionMatrixDisplay.from_estimator(
69+
catboost_classifier,
70+
x_test,
71+
y_test,
72+
display_labels=iris["target_names"],
73+
cmap="Blues",
74+
normalize="true",
75+
)
76+
plt.title("Normalized Confusion Matrix - IRIS Dataset")
77+
plt.show()
78+
79+
80+
if __name__ == "__main__":
81+
import doctest
82+
doctest.testmod(verbose=True)
83+
main()

0 commit comments

Comments
 (0)