Commit d49f9eb

Set random state when splitting data (#762)
1 parent 0ffc235 commit d49f9eb

File tree

2 files changed: +30 −29 lines changed

packages/scikit-learn/index.rst

Lines changed: 29 additions & 28 deletions
@@ -570,7 +570,8 @@ One good method to keep in mind is Gaussian Naive Bayes
 >>> from sklearn.model_selection import train_test_split
 
 >>> # split the data into training and validation sets
->>> X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
+>>> X_train, X_test, y_train, y_test = train_test_split(
+...     digits.data, digits.target, random_state=42)
 
 >>> # train the model
 >>> clf = GaussianNB()
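Aside (not part of the commit): passing `random_state` seeds the shuffle inside `train_test_split`, which is what lets the doctest outputs below be pinned to exact values. A minimal sketch of the reproducibility this buys:

```python
# Sketch: with a fixed random_state, train_test_split is deterministic,
# so two calls produce identical splits.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
a = train_test_split(digits.data, digits.target, random_state=42)
b = train_test_split(digits.data, digits.target, random_state=42)

# Same seed -> identical X_train arrays; the default test_size of 0.25
# puts 450 of the 1797 digits samples in the test set.
print((a[0] == b[0]).all())
print(a[1].shape[0])
```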
@@ -581,9 +582,9 @@ One good method to keep in mind is Gaussian Naive Bayes
 >>> predicted = clf.predict(X_test)
 >>> expected = y_test
 >>> print(predicted)
-[5 1 7 2 8 9 4 3 9 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 9 8 1 8...]
+[6 9 3 7 2 2 5 8 5 2 1 1 7 0 4 8 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
 >>> print(expected)
-[5 8 7 2 8 9 4 3 7 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 3 3 1 8...]
+[6 9 3 7 2 1 5 2 5 2 1 9 4 0 4 2 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
 
 As above, we plot the digits with the predicted labels to get an idea of
 how well the classification is working.
@@ -607,11 +608,11 @@ the number of matches::
 
 >>> matches = (predicted == expected)
 >>> print(matches.sum())
-371
+385
 >>> print(len(matches))
 450
 >>> matches.sum() / float(len(matches))
-0.82444...
+0.8555...
 
 We see that more than 80% of the 450 predictions match the input. But
 there are other more sophisticated metrics that can be used to judge the
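For context (not from the diff): the `matches.sum() / len(matches)` ratio in this hunk is plain accuracy. A small sketch with toy arrays, not the digits data, showing it agrees with `metrics.accuracy_score`:

```python
# Sketch: fraction of matching labels == sklearn's accuracy_score.
import numpy as np
from sklearn import metrics

predicted = np.array([6, 9, 3, 7, 2, 2, 5, 8])
expected = np.array([6, 9, 3, 7, 2, 1, 5, 2])

matches = (predicted == expected)          # boolean array of hits
frac = matches.sum() / float(len(matches))  # 6 of 8 match
print(frac)                                 # 0.75
print(metrics.accuracy_score(expected, predicted))  # 0.75
```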
@@ -625,20 +626,20 @@ combines several measures and prints a table with the results::
 >>> print(metrics.classification_report(expected, predicted))
               precision    recall  f1-score   support
 <BLANKLINE>
-           0       1.00      0.98      0.99        45
-           1       0.91      0.66      0.76        44
-           2       0.91      0.56      0.69        36
-           3       0.89      0.67      0.77        49
-           4       0.95      0.83      0.88        46
-           5       0.93      0.93      0.93        45
-           6       0.92      0.98      0.95        47
-           7       0.75      0.96      0.84        50
-           8       0.49      0.97      0.66        39
-           9       0.85      0.67      0.75        49
+           0       1.00      0.95      0.98        43
+           1       0.85      0.78      0.82        37
+           2       0.85      0.61      0.71        38
+           3       0.97      0.83      0.89        46
+           4       0.98      0.84      0.90        55
+           5       0.90      0.95      0.93        59
+           6       0.90      0.96      0.92        45
+           7       0.71      0.98      0.82        41
+           8       0.60      0.89      0.72        38
+           9       0.90      0.73      0.80        48
 <BLANKLINE>
-    accuracy                           0.82       450
-   macro avg       0.86      0.82      0.82       450
-weighted avg       0.86      0.82      0.83       450
+    accuracy                           0.86       450
+   macro avg       0.87      0.85      0.85       450
+weighted avg       0.88      0.86      0.86       450
 <BLANKLINE>
 
 
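For reference (not part of the diff): `classification_report` derives each row from the per-class true/predicted counts, and `support` is the number of true samples of each class. A minimal sketch with toy labels, not the digits data:

```python
# Sketch: per-class precision/recall/F1 plus support for tiny toy labels.
from sklearn.metrics import classification_report

expected = [0, 0, 1, 1, 1, 2]
predicted = [0, 1, 1, 1, 0, 2]

# Class 0: precision 0.50, recall 0.50; class 1: 0.67/0.67; class 2: 1.00/1.00
print(classification_report(expected, predicted))
```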
@@ -647,16 +648,16 @@ is a *confusion matrix*: it helps us visualize which labels are being
 interchanged in the classification errors::
 
 >>> print(metrics.confusion_matrix(expected, predicted))
-[[44  0  0  0  0  0  0  0  0  1]
- [ 0 29  0  0  0  0  1  6  6  2]
- [ 0  1 20  1  0  0  0  0 14  0]
- [ 0  0  0 33  0  2  0  1 11  2]
- [ 0  0  0  0 38  1  2  4  1  0]
- [ 0  0  0  0  0 42  1  0  2  0]
- [ 0  0  0  0  0  0 46  0  1  0]
- [ 0  0  0  0  1  0  0 48  0  1]
- [ 0  1  0  0  0  0  0  0 38  0]
- [ 0  1  2  3  1  0  0  5  4 33]]
+[[41  0  0  0  0  1  0  1  0  0]
+ [ 0 29  2  0  0  0  0  0  4  2]
+ [ 0  2 23  0  0  0  1  0 12  0]
+ [ 0  0  1 38  0  1  0  0  5  1]
+ [ 0  0  0  0 46  0  2  7  0  0]
+ [ 0  0  0  0  0 56  1  1  0  1]
+ [ 0  0  0  0  1  1 43  0  0  0]
+ [ 0  0  0  0  0  1  0 40  0  0]
+ [ 0  2  0  0  0  0  0  2 34  0]
+ [ 0  1  1  1  0  2  1  5  2 35]]
 
 We see here that in particular, the numbers 1, 2, 3, and 9 are often
 being labeled 8.
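To make the matrix easier to read: rows are the expected labels, columns the predicted ones, so off-diagonal entries count confusions. A small sketch with toy labels (not the digits data):

```python
# Sketch: rows = true labels, columns = predicted labels; the [9, 8] and
# [3, 8] cells below count the 9->8 and 3->8 confusions.
from sklearn.metrics import confusion_matrix

expected = [3, 3, 8, 8, 9, 9]
predicted = [3, 8, 8, 8, 9, 8]

cm = confusion_matrix(expected, predicted, labels=[3, 8, 9])
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [0 1 1]]
```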

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ matplotlib==3.9.0
 pandas==2.2.2
 patsy==0.5.6
 pyarrow==16.1.0
-scikit-learn==1.4.2
+scikit-learn==1.5.0
 scikit-image==0.23.2
 sympy==1.12.1
 statsmodels==0.14.2

0 commit comments
