Skip to content

Commit 22d8913

Browse files
authored
1 parent 5f9cae8 commit 22d8913

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

machine_learning/xgboostclassifier.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,27 @@
44
from sklearn.metrics import plot_confusion_matrix
55
from sklearn.model_selection import train_test_split
66
from xgboost import XGBClassifier
7+
import numpy as np
78

89

910
def data_handling(data: dict) -> tuple:
1011
# Split dataset into train and test data
11-
features = data["data"] # data is features
12-
target = data["target"]
13-
x = train_test_split(features, target, test_size=0.25)
12+
# data is features
13+
"""
14+
>>> data_handling(({'data':'[5.1, 3.5, 1.4, 0.2],[4.6, 3.4, 1.4, 0.3]','target':([0,1])}))
15+
('[5.1, 3.5, 1.4, 0.2],[4.6, 3.4, 1.4, 0.3]', [0, 1])
16+
>>> data_handling({'data':'[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],[4.6, 3.1, 1.5, 0.2],[5. , 3.6, 1.4, 0.2],[5.4, 3.9, 1.7, 0.4]','target':([0,0, 0, 0, 0])})
17+
('[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],[4.6, 3.1, 1.5, 0.2],[5. , 3.6, 1.4, 0.2],[5.4, 3.9, 1.7, 0.4]', [0, 0, 0, 0, 0])
18+
"""
19+
x = (data["data"],data["target"])
1420
return x
1521

1622

17-
def xgboost(features: list, target: list): # -> returns a trained model:
23+
def xgboost(features: np.ndarray, target: np.ndarray): -> XGBClassifier:
24+
"""
25+
>>> xgboost(np.array([[5.1, 3.5, 1.4, 0.2],[4.6, 3.4, 1.4, 0.3]]), np.array([1,2]))
26+
XGBClassifier()
27+
"""
1828
classifier = XGBClassifier()
1929
classifier.fit(features, target)
2030
return classifier
@@ -23,17 +33,20 @@ def xgboost(features: list, target: list): # -> returns a trained model:
2333
def main() -> None:
2434

2535
"""
36+
>>> main()
37+
2638
The Url for the algorithm
2739
https://xgboost.readthedocs.io/en/stable/
2840
Iris type dataset is used to demonstrate algorithm.
2941
"""
3042

3143
# Load Iris dataset
3244
iris = load_iris()
45+
features,targets= data_handling(iris)
46+
x_train, x_test, y_train, y_test=train_test_split(features, targets, test_size=0.25)
3347

3448
names = iris["target_names"]
3549

36-
x_train, x_test, y_train, y_test = data_handling(iris)
3750

3851
# XGBoost Classifier
3952
xgb = xgboost(x_train, y_train)

0 commit comments

Comments
 (0)