1
+ """
2
+ Implementation of the XGBoost Classifier.
3
+ Dataset - Iris Dataset
4
+ Know more about XGBoost - https://en.wikipedia.org/wiki/XGBoost
5
+ """
6
+
7
+ import pandas as pd
8
+ import matplotlib .pyplot as plt
9
+ from sklearn .datasets import load_iris
10
+ from sklearn .model_selection import train_test_split
11
+ from sklearn .metrics import accuracy_score , confusion_matrix , plot_confusion_matrix
12
+ from xgboost import XGBClassifier # might have to do `!pip install xgboost`
13
+ import warnings
14
+ warnings .filterwarnings ("ignore" )
15
+
16
+ def main ():
17
+ # load the iris dataset
18
+ iris = load_iris ()
19
+
20
+ print (type (iris )) # Currently the type is sklearn.utils.Bunch
21
+
22
+ # convert it to a dataframe
23
+ df = pd .DataFrame (iris .data , columns = iris .feature_names )
24
+
25
+ # add the target column
26
+ df ['target' ] = iris .target
27
+
28
+ print (df ['target' ].unique ()) # 0 -> sentosa, 1 -> versicolor, 2 -> virginica
29
+
30
+ # how the dataset looks
31
+ print (df .head ())
32
+
33
+ X = df .drop ('target' , axis = 1 )
34
+ y = df ['target' ]
35
+
36
+ # split dataset into training and testing sets
37
+ X_train , X_test , y_train , y_test = train_test_split (X , y , test_size = 0.25 , random_state = 0 )
38
+
39
+ # initialize the model
40
+ xgb = XGBClassifier (random_state = 0 )
41
+
42
+ # start the training on training set
43
+ xgb .fit (X_train , y_train )
44
+
45
+ # get the prediction on testing set
46
+ y_pred = xgb .predict (X_test )
47
+
48
+ # calculate the accuracy
49
+ acc = accuracy_score (y_test , y_pred )
50
+
51
+ # print the accuracy
52
+ print (f"Accuracy: { round (100 * acc , 2 )} %" )
53
+
54
+ # print the confusion matrix
55
+ print (confusion_matrix (y_test , y_pred ))
56
+
57
+ # plot the confusion matrix
58
+ plot_confusion_matrix (
59
+ xgb ,
60
+ X_test ,
61
+ y_test ,
62
+ display_labels = iris ["target_names" ],
63
+ cmap = "Blues" ,
64
+ normalize = "true" ,
65
+ )
66
+ plt .title ("Normalized Confusion Matrix - IRIS Dataset" )
67
+ plt .show ()
68
+
69
+ if __name__ == "__main__" :
70
+ main ()
0 commit comments