Skip to content

Commit 758ae1d

Browse files
committed
modified: machine_learning/catboost_regressor.py
1 parent ff79acd commit 758ae1d

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

machine_learning/catboost_regressor.py

+19 -15
Original file line number | Diff line number | Diff line change
@@ -12,43 +12,44 @@
1212
"""
1313

1414
import numpy as np
15-
from sklearn.datasets import load_boston
15+
from sklearn.datasets import fetch_california_housing
1616
from sklearn.model_selection import train_test_split
1717
from sklearn.metrics import mean_squared_error
1818
from catboost import CatBoostRegressor
1919

2020

2121
def data_handling() -> tuple:
2222
"""
23-
Loads and handles the dataset, splitting it into features and targets.
24-
25-
The Boston dataset is used as a regression example.
26-
23+
Loads and handles the California Housing dataset (replacement for deprecated Boston dataset).
24+
2725
Returns:
2826
tuple: A tuple of (features, target), where both are numpy arrays.
2927
3028
Example:
3129
>>> features, target = data_handling()
30+
>>> isinstance(features, np.ndarray)
31+
True
32+
>>> isinstance(target, np.ndarray)
33+
True
3234
>>> features.shape
33-
(506, 13)
35+
(20640, 8)
3436
>>> target.shape
35-
(506,)
37+
(20640,)
3638
"""
37-
# Load Boston dataset (note: this dataset may be deprecated, replace if needed)
38-
boston = load_boston()
39-
features = boston.data
40-
target = boston.target
39+
housing = fetch_california_housing()
40+
features = housing.data
41+
target = housing.target
4142
return features, target
4243

4344

4445
def catboost_regressor(features: np.ndarray, target: np.ndarray) -> CatBoostRegressor:
4546
"""
4647
Trains a CatBoostRegressor using the provided features and target values.
47-
48+
4849
Args:
4950
features (np.ndarray): The input features for the regression model.
5051
target (np.ndarray): The target values for the regression model.
51-
52+
5253
Returns:
5354
CatBoostRegressor: A trained CatBoost regressor model.
5455
@@ -66,10 +67,14 @@ def catboost_regressor(features: np.ndarray, target: np.ndarray) -> CatBoostRegr
6667
def main() -> None:
6768
"""
6869
Main function to run the CatBoost Regressor example.
69-
70+
7071
It loads the data, splits it into training and testing sets,
7172
trains the regressor on the training data, and evaluates its performance
7273
on the test data.
74+
75+
Example:
76+
>>> main()
77+
Mean Squared Error on Test Set:
7378
"""
7479
# Load and split the dataset
7580
features, target = data_handling()
@@ -90,6 +95,5 @@ def main() -> None:
9095

9196
if __name__ == "__main__":
9297
import doctest
93-
9498
doctest.testmod(verbose=True)
9599
main()

0 commit comments

Comments
 (0)