Skip to content

Commit 65048fd

Browse files
authored
Update catboost_regressor.py
added description
1 parent ce56390 commit 65048fd

File tree

1 file changed

+59
-48
lines changed

1 file changed

+59
-48
lines changed

machine_learning/catboost_regressor.py

+59-48
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,91 @@
1-
# CatBoost Classifier Example
1+
"""
2+
CatBoost Regressor Example.
3+
4+
This script demonstrates the usage of the CatBoost Regressor for a simple regression task.
5+
CatBoost is a powerful gradient boosting library that handles categorical features automatically
6+
and is highly efficient.
7+
8+
Make sure to install CatBoost using:
9+
pip install catboost
10+
11+
Contributed by: @AHuzail
12+
"""
13+
214
import numpy as np
3-
from matplotlib import pyplot as plt
4-
from sklearn.datasets import load_iris
5-
from sklearn.metrics import ConfusionMatrixDisplay
15+
from sklearn.datasets import load_boston
616
from sklearn.model_selection import train_test_split
7-
from catboost import CatBoostClassifier
17+
from sklearn.metrics import mean_squared_error
18+
from catboost import CatBoostRegressor
819

920

10-
def data_handling(data: dict) -> tuple:
21+
def data_handling() -> tuple:
1122
"""
12-
Extracts the features and target values from the provided dataset.
13-
14-
Args:
15-
data (dict): A dictionary containing the dataset's features and targets.
23+
Loads and handles the dataset, splitting it into features and targets.
1624
25+
The Boston dataset is used as a regression example.
26+
1727
Returns:
18-
tuple: A tuple with features and targets.
28+
tuple: A tuple of (features, target), where both are numpy arrays.
1929
2030
Example:
21-
>>> data_handling({'data':'[5.1, 3.5, 1.4, 0.2]', 'target': [0]})
22-
('[5.1, 3.5, 1.4, 0.2]', [0])
31+
>>> features, target = data_handling()
32+
>>> features.shape
33+
(506, 13)
34+
>>> target.shape
35+
(506,)
2336
"""
24-
return data["data"], data["target"]
37+
# Load Boston dataset (note: this dataset may be deprecated, replace if needed)
38+
boston = load_boston()
39+
features = boston.data
40+
target = boston.target
41+
return features, target
2542

2643

27-
def catboost(features: np.ndarray, target: np.ndarray) -> CatBoostClassifier:
44+
def catboost_regressor(features: np.ndarray, target: np.ndarray) -> CatBoostRegressor:
2845
"""
29-
Trains a CatBoostClassifier using the provided features and target.
46+
Trains a CatBoostRegressor using the provided features and target values.
3047
3148
Args:
32-
features (np.ndarray): The input features for training the classifier.
33-
target (np.ndarray): The target labels corresponding to the features.
49+
features (np.ndarray): The input features for the regression model.
50+
target (np.ndarray): The target values for the regression model.
3451
3552
Returns:
36-
CatBoostClassifier: A trained CatBoost classifier.
53+
CatBoostRegressor: A trained CatBoost regressor model.
3754
3855
Example:
39-
>>> catboost(np.array([[5.1, 3.6, 1.4, 0.2]]), np.array([0]))
40-
CatBoostClassifier(...)
56+
>>> features, target = data_handling()
57+
>>> model = catboost_regressor(features, target)
58+
>>> isinstance(model, CatBoostRegressor)
59+
True
4160
"""
42-
classifier = CatBoostClassifier(verbose=0) # Suppressing verbose output
43-
classifier.fit(features, target)
44-
return classifier
61+
regressor = CatBoostRegressor(iterations=100, learning_rate=0.1, depth=6, verbose=0)
62+
regressor.fit(features, target)
63+
return regressor
4564

4665

4766
def main() -> None:
4867
"""
49-
Demonstrates the training and evaluation of a CatBoost classifier
50-
on the Iris dataset, displaying a confusion matrix of the results.
68+
Main function to run the CatBoost Regressor example.
5169
52-
The dataset is split into training and testing sets, the model is
53-
trained on the training data, and then evaluated on the test data.
54-
A normalized confusion matrix is displayed.
70+
It loads the data, splits it into training and testing sets,
71+
trains the regressor on the training data, and evaluates its performance
72+
on the test data.
5573
"""
56-
57-
# Load the Iris dataset
58-
iris = load_iris()
59-
features, targets = data_handling(iris)
74+
# Load and split the dataset
75+
features, target = data_handling()
6076
x_train, x_test, y_train, y_test = train_test_split(
61-
features, targets, test_size=0.25
77+
features, target, test_size=0.25, random_state=42
6278
)
6379

64-
# Train a CatBoost classifier
65-
catboost_classifier = catboost(x_train, y_train)
66-
67-
# Display the confusion matrix for the test data
68-
ConfusionMatrixDisplay.from_estimator(
69-
catboost_classifier,
70-
x_test,
71-
y_test,
72-
display_labels=iris["target_names"],
73-
cmap="Blues",
74-
normalize="true",
75-
)
76-
plt.title("Normalized Confusion Matrix - IRIS Dataset")
77-
plt.show()
80+
# Train CatBoost Regressor
81+
regressor = catboost_regressor(x_train, y_train)
82+
83+
# Predict on the test set
84+
predictions = regressor.predict(x_test)
85+
86+
# Evaluate the performance using Mean Squared Error
87+
mse = mean_squared_error(y_test, predictions)
88+
print(f"Mean Squared Error on Test Set: {mse:.4f}")
7889

7990

8091
if __name__ == "__main__":

0 commit comments

Comments
 (0)