@@ -570,7 +570,8 @@ One good method to keep in mind is Gaussian Naive Bayes
>>> from sklearn.model_selection import train_test_split

>>> # split the data into training and validation sets
- >>> X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
+ >>> X_train, X_test, y_train, y_test = train_test_split(
+ ...     digits.data, digits.target, random_state=42)

>>> # train the model
>>> clf = GaussianNB()
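
As a side note (not part of the diff), here is a minimal end-to-end sketch of the snippet this hunk modifies. It assumes the `digits` dataset and the `GaussianNB` import from earlier sections of the tutorial, which are not shown in this diff:

    # minimal sketch: reproduce the split and the model fit from this hunk
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.naive_bayes import GaussianNB

    digits = load_digits()

    # fixing random_state makes the split, and hence the doctest output,
    # reproducible across runs
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=42)

    clf = GaussianNB()
    clf.fit(X_train, y_train)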
@@ -581,9 +582,9 @@ One good method to keep in mind is Gaussian Naive Bayes
>>> predicted = clf.predict(X_test)
>>> expected = y_test
>>> print(predicted)
- [5 1 7 2 8 9 4 3 9 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 9 8 1 8 ...]
+ [6 9 3 7 2 2 5 8 5 2 1 1 7 0 4 8 3 7 8 8 4 3 9 7 5 6 3 5 6 3 ...]
>>> print(expected)
- [5 8 7 2 8 9 4 3 7 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 3 3 1 8 ...]
+ [6 9 3 7 2 1 5 2 5 2 1 9 4 0 4 2 3 7 8 8 4 3 9 7 5 6 3 5 6 3 ...]

As above, we plot the digits with the predicted labels to get an idea of
how well the classification is working.
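
The plotting code these context lines refer to is defined earlier in the tutorial and is not part of this diff; a rough sketch of such a plot, assuming matplotlib and the variables from the hunk above, could look like this:

    # sketch: show a few test digits with their predicted labels;
    # correct predictions titled in green, mistakes in red
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(4, 6, figsize=(6, 4))
    for ax, image, pred, truth in zip(axes.ravel(), X_test, predicted, expected):
        ax.imshow(image.reshape(8, 8), cmap='binary', interpolation='nearest')
        ax.set_axis_off()
        ax.set_title(str(pred), color='green' if pred == truth else 'red')
    plt.show()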
@@ -607,11 +608,11 @@ the number of matches::

>>> matches = (predicted == expected)
>>> print(matches.sum())
- 371
+ 385
>>> print(len(matches))
450
>>> matches.sum() / float(len(matches))
- 0.82444 ...
+ 0.8555 ...

We see that more than 80% of the 450 predictions match the input. But
there are other more sophisticated metrics that can be used to judge the
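
As an aside (not part of the diff), the same fraction of matches can be obtained directly with scikit-learn's `accuracy_score`, or with the estimator's `score` method:

    # sketch: equivalent accuracy computations
    from sklearn import metrics

    print(metrics.accuracy_score(expected, predicted))  # same value as matches.mean()
    print(clf.score(X_test, y_test))                    # predicts internally, then scores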
@@ -625,20 +626,20 @@ combines several measures and prints a table with the results::
>>> print(metrics.classification_report(expected, predicted))
              precision    recall  f1-score   support
<BLANKLINE>
-            0       1.00      0.98      0.99        45
-            1       0.91      0.66      0.76        44
-            2       0.91      0.56      0.69        36
-            3       0.89      0.67      0.77        49
-            4       0.95      0.83      0.88        46
-            5       0.93      0.93      0.93        45
-            6       0.92      0.98      0.95        47
-            7       0.75      0.96      0.84        50
-            8       0.49      0.97      0.66        39
-            9       0.85      0.67      0.75        49
+            0       1.00      0.95      0.98        43
+            1       0.85      0.78      0.82        37
+            2       0.85      0.61      0.71        38
+            3       0.97      0.83      0.89        46
+            4       0.98      0.84      0.90        55
+            5       0.90      0.95      0.93        59
+            6       0.90      0.96      0.92        45
+            7       0.71      0.98      0.82        41
+            8       0.60      0.89      0.72        38
+            9       0.90      0.73      0.80        48
<BLANKLINE>
-     accuracy                           0.82       450
-    macro avg       0.86      0.82      0.82       450
- weighted avg       0.86      0.82      0.83       450
+     accuracy                           0.86       450
+    macro avg       0.87      0.85      0.85       450
+ weighted avg       0.88      0.86      0.86       450
<BLANKLINE>

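
If the per-class numbers are needed programmatically rather than as printed text, `classification_report` can also return them as a dictionary via `output_dict=True` (available in recent scikit-learn versions). A sketch, not part of the diff:

    # sketch: access the report values as a nested dict instead of a string
    from sklearn import metrics

    report = metrics.classification_report(expected, predicted, output_dict=True)
    print(report['accuracy'])      # overall accuracy (about 0.86 with this split)
    print(report['8']['recall'])   # per-class entries are keyed by the label as a string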
@@ -647,16 +648,16 @@ is a *confusion matrix*: it helps us visualize which labels are being
interchanged in the classification errors::

>>> print(metrics.confusion_matrix(expected, predicted))
- [[44  0  0  0  0  0  0  0  0  1]
-  [ 0 29  0  0  0  0  1  6  6  2]
-  [ 0  1 20  1  0  0  0  0 14  0]
-  [ 0  0  0 33  0  2  0  1 11  2]
-  [ 0  0  0  0 38  1  2  4  1  0]
-  [ 0  0  0  0  0 42  1  0  2  0]
-  [ 0  0  0  0  0  0 46  0  1  0]
-  [ 0  0  0  0  1  0  0 48  0  1]
-  [ 0  1  0  0  0  0  0  0 38  0]
-  [ 0  1  2  3  1  0  0  5  4 33]]
+ [[41  0  0  0  0  1  0  1  0  0]
+  [ 0 29  2  0  0  0  0  0  4  2]
+  [ 0  2 23  0  0  0  1  0 12  0]
+  [ 0  0  1 38  0  1  0  0  5  1]
+  [ 0  0  0  0 46  0  2  7  0  0]
+  [ 0  0  0  0  0 56  1  1  0  1]
+  [ 0  0  0  0  1  1 43  0  0  0]
+  [ 0  0  0  0  0  1  0 40  0  0]
+  [ 0  2  0  0  0  0  0  2 34  0]
+  [ 0  1  1  1  0  2  1  5  2 35]]

We see here that in particular, the numbers 1, 2, 3, and 9 are often
being labeled 8.
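
A graphical view of the same matrix often makes these confusions easier to spot. As a sketch (not part of the diff; `ConfusionMatrixDisplay` requires a reasonably recent scikit-learn, roughly 1.0 or later):

    # sketch: plot the confusion matrix instead of printing it
    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay

    ConfusionMatrixDisplay.from_predictions(expected, predicted)
    plt.show()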