Skip to content

Commit bc690dd

Browse files
authored
Update xgboostclassifier.py
1 parent 0519814 commit bc690dd

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

machine_learning/xgboostclassifier.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
# XGBoost Classifier Example
2-
from doctest import testmod
3-
42
from matplotlib import pyplot as plt
53
from sklearn.datasets import load_iris
64
from sklearn.metrics import plot_confusion_matrix
75
from sklearn.model_selection import train_test_split
86
from xgboost import XGBClassifier
97

108

11-
def data_handling(data) -> list:
9+
def data_handling(data: list) -> tuple:
1210
# Split dataset into train and test data
1311
x = data["data"] # features
1412
y = data["target"]
1513
return x, y
1614

1715

18-
def xgboost(features, target):
16+
def xgboost(features: list, target: list):
1917
classifier = XGBClassifier()
2018
classifier.fit(features, target)
2119
return classifier
@@ -54,4 +52,4 @@ def main() -> None:
5452

5553

5654
if __name__ == "__main__":
57-
testmod(name="main", verbose=True)
55+
main()

0 commit comments

Comments
 (0)