Skip to content

Commit 8cb5708

Browse files
glemaitreviclafargue
authored andcommitted
FIX/ENH CheckingClassifier support parameters and sparse matrices (scikit-learn#17259)
1 parent d3063a0 commit 8cb5708

File tree

2 files changed

+216
-27
lines changed

2 files changed

+216
-27
lines changed

sklearn/utils/_mocking.py

+108-27
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,46 @@ def __ne__(self, other):
5151
class CheckingClassifier(ClassifierMixin, BaseEstimator):
5252
"""Dummy classifier to test pipelining and meta-estimators.
5353
54-
Checks some property of X and y in fit / predict.
54+
Checks some property of `X` and `y`in fit / predict.
5555
This allows testing whether pipelines / cross-validation or metaestimators
5656
changed the input.
5757
5858
Parameters
5959
----------
60-
check_y
61-
check_X
62-
foo_param
63-
expected_fit_params
60+
check_y, check_X : callable, default=None
61+
The callable used to validate `X` and `y`. These callable should return
62+
a bool where `False` will trigger an `AssertionError`.
63+
64+
check_y_params, check_X_params : dict, default=None
65+
The optional parameters to pass to `check_X` and `check_y`.
66+
67+
foo_param : int, default=0
68+
A `foo` param. When `foo > 1`, the output of :meth:`score` will be 1
69+
otherwise it is 0.
70+
71+
expected_fit_params : list of str, default=None
72+
A list of the expected parameters given when calling `fit`.
6473
6574
Attributes
6675
----------
67-
classes_
76+
classes_ : int
77+
The classes seen during `fit`.
78+
79+
n_features_in_ : int
80+
The number of features seen during `fit`.
6881
"""
69-
def __init__(self, check_y=None, check_X=None, foo_param=0,
82+
def __init__(self, *, check_y=None, check_y_params=None,
83+
check_X=None, check_X_params=None, foo_param=0,
7084
expected_fit_params=None):
7185
self.check_y = check_y
86+
self.check_y_params = check_y_params
7287
self.check_X = check_X
88+
self.check_X_params = check_X_params
7389
self.foo_param = foo_param
7490
self.expected_fit_params = expected_fit_params
7591

7692
def fit(self, X, y, **fit_params):
77-
"""
78-
Fit classifier
93+
"""Fit classifier.
7994
8095
Parameters
8196
----------
@@ -89,48 +104,114 @@ def fit(self, X, y, **fit_params):
89104
90105
**fit_params : dict of string -> object
91106
Parameters passed to the ``fit`` method of the estimator
107+
108+
Returns
109+
-------
110+
self
92111
"""
93-
assert len(X) == len(y)
112+
assert _num_samples(X) == _num_samples(y)
94113
if self.check_X is not None:
95-
assert self.check_X(X)
114+
params = {} if self.check_X_params is None else self.check_X_params
115+
assert self.check_X(X, **params)
96116
if self.check_y is not None:
117+
params = {} if self.check_y_params is None else self.check_y_params
97118
assert self.check_y(y)
98-
self.n_features_in_ = len(X)
99-
self.classes_ = np.unique(check_array(y, ensure_2d=False,
100-
allow_nd=True))
119+
self.n_features_in_ = np.shape(X)[1]
120+
self.classes_ = np.unique(
121+
check_array(y, ensure_2d=False, allow_nd=True)
122+
)
101123
if self.expected_fit_params:
102124
missing = set(self.expected_fit_params) - set(fit_params)
103-
assert len(missing) == 0, 'Expected fit parameter(s) %s not ' \
104-
'seen.' % list(missing)
125+
if missing:
126+
raise AssertionError(
127+
f'Expected fit parameter(s) {list(missing)} not seen.'
128+
)
105129
for key, value in fit_params.items():
106-
assert len(value) == len(X), (
107-
'Fit parameter %s has length %d; '
108-
'expected %d.'
109-
% (key, len(value), len(X)))
130+
if _num_samples(value) != _num_samples(X):
131+
raise AssertionError(
132+
f'Fit parameter {key} has length {_num_samples(value)}'
133+
f'; expected {_num_samples(X)}.'
134+
)
110135

111136
return self
112137

113-
def predict(self, T):
114-
"""
138+
def predict(self, X):
139+
"""Predict the first class seen in `classes_`.
140+
115141
Parameters
116142
----------
117-
T : indexable, length n_samples
143+
X : array-like of shape (n_samples, n_features)
144+
The input data.
145+
146+
Returns
147+
-------
148+
preds : ndarray of shape (n_samples,)
149+
Predictions of the first class seens in `classes_`.
118150
"""
119151
if self.check_X is not None:
120-
assert self.check_X(T)
121-
return self.classes_[np.zeros(_num_samples(T), dtype=np.int)]
152+
params = {} if self.check_X_params is None else self.check_X_params
153+
assert self.check_X(X, **params)
154+
return self.classes_[np.zeros(_num_samples(X), dtype=np.int)]
122155

123-
def score(self, X=None, Y=None):
156+
def predict_proba(self, X):
157+
"""Predict probabilities for each class.
158+
159+
Here, the dummy classifier will provide a probability of 1 for the
160+
first class of `classes_` and 0 otherwise.
161+
162+
Parameters
163+
----------
164+
X : array-like of shape (n_samples, n_features)
165+
The input data.
166+
167+
Returns
168+
-------
169+
proba : ndarray of shape (n_samples, n_classes)
170+
The probabilities for each sample and class.
124171
"""
172+
proba = np.zeros((_num_samples(X), len(self.classes_)))
173+
proba[:, 0] = 1
174+
return proba
175+
176+
def decision_function(self, X):
177+
"""Confidence score.
178+
179+
Parameters
180+
----------
181+
X : array-like of shape (n_samples, n_features)
182+
The input data.
183+
184+
Returns
185+
-------
186+
decision : ndarray of shape (n_samples,) if n_classes == 2\
187+
else (n_samples, n_classes)
188+
Confidence score.
189+
"""
190+
if len(self.classes_) == 2:
191+
# for binary classifier, the confidence score is related to
192+
# classes_[1] and therefore should be null.
193+
return np.zeros(_num_samples(X))
194+
else:
195+
return self.predict_proba(X)
196+
197+
def score(self, X=None, Y=None):
198+
"""Fake score.
199+
125200
Parameters
126201
----------
127202
X : array-like of shape (n_samples, n_features)
128203
Input data, where n_samples is the number of samples and
129204
n_features is the number of features.
130205
131-
Y : array-like of shape (n_samples, n_output) or (n_samples,), optional
206+
Y : array-like of shape (n_samples, n_output) or (n_samples,)
132207
Target relative to X for classification or regression;
133208
None for unsupervised learning.
209+
210+
Returns
211+
-------
212+
score : float
213+
Either 0 or 1 depending of `foo_param` (i.e. `foo_param > 1 =>
214+
score=1` otherwise `score=0`).
134215
"""
135216
if self.foo_param > 1:
136217
score = 1.

sklearn/utils/tests/test_mocking.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import numpy as np
2+
import pytest
3+
from scipy import sparse
4+
5+
from numpy.testing import assert_array_equal
6+
from numpy.testing import assert_allclose
7+
8+
from sklearn.datasets import load_iris
9+
from sklearn.utils import check_array
10+
from sklearn.utils import _safe_indexing
11+
from sklearn.utils._testing import _convert_container
12+
13+
from sklearn.utils._mocking import CheckingClassifier
14+
15+
16+
@pytest.fixture
17+
def iris():
18+
return load_iris(return_X_y=True)
19+
20+
21+
@pytest.mark.parametrize(
22+
"input_type", ["list", "array", "sparse", "dataframe"]
23+
)
24+
def test_checking_classifier(iris, input_type):
25+
# Check that the CheckingClassifier outputs what we expect
26+
X, y = iris
27+
X = _convert_container(X, input_type)
28+
clf = CheckingClassifier()
29+
clf.fit(X, y)
30+
31+
assert_array_equal(clf.classes_, np.unique(y))
32+
assert len(clf.classes_) == 3
33+
assert clf.n_features_in_ == 4
34+
35+
y_pred = clf.predict(X)
36+
assert_array_equal(y_pred, np.zeros(y_pred.size, dtype=np.int))
37+
38+
assert clf.score(X) == pytest.approx(0)
39+
clf.set_params(foo_param=10)
40+
assert clf.fit(X, y).score(X) == pytest.approx(1)
41+
42+
y_proba = clf.predict_proba(X)
43+
assert y_proba.shape == (150, 3)
44+
assert_allclose(y_proba[:, 0], 1)
45+
assert_allclose(y_proba[:, 1:], 0)
46+
47+
y_decision = clf.decision_function(X)
48+
assert y_decision.shape == (150, 3)
49+
assert_allclose(y_decision[:, 0], 1)
50+
assert_allclose(y_decision[:, 1:], 0)
51+
52+
# check the shape in case of binary classification
53+
first_2_classes = np.logical_or(y == 0, y == 1)
54+
X = _safe_indexing(X, first_2_classes)
55+
y = _safe_indexing(y, first_2_classes)
56+
clf.fit(X, y)
57+
58+
y_proba = clf.predict_proba(X)
59+
assert y_proba.shape == (100, 2)
60+
assert_allclose(y_proba[:, 0], 1)
61+
assert_allclose(y_proba[:, 1], 0)
62+
63+
y_decision = clf.decision_function(X)
64+
assert y_decision.shape == (100,)
65+
assert_allclose(y_decision, 0)
66+
67+
68+
def test_checking_classifier_with_params(iris):
69+
X, y = iris
70+
X_sparse = sparse.csr_matrix(X)
71+
72+
def check_X_is_sparse(X):
73+
if not sparse.issparse(X):
74+
raise ValueError("X is not sparse")
75+
return True
76+
77+
clf = CheckingClassifier(check_X=check_X_is_sparse)
78+
with pytest.raises(ValueError, match="X is not sparse"):
79+
clf.fit(X, y)
80+
clf.fit(X_sparse, y)
81+
82+
def _check_array(X, **params):
83+
check_array(X, **params)
84+
return True
85+
86+
clf = CheckingClassifier(
87+
check_X=_check_array, check_X_params={"accept_sparse": False}
88+
)
89+
clf.fit(X, y)
90+
with pytest.raises(TypeError, match="A sparse matrix was passed"):
91+
clf.fit(X_sparse, y)
92+
93+
94+
def test_checking_classifier_fit_params(iris):
95+
# check the error raised when the number of samples is not the one expected
96+
X, y = iris
97+
clf = CheckingClassifier(expected_fit_params=["sample_weight"])
98+
sample_weight = np.ones(len(X) // 2)
99+
100+
with pytest.raises(AssertionError, match="Fit parameter sample_weight"):
101+
clf.fit(X, y, sample_weight=sample_weight)
102+
103+
104+
def test_checking_classifier_missing_fit_params(iris):
105+
X, y = iris
106+
clf = CheckingClassifier(expected_fit_params=["sample_weight"])
107+
with pytest.raises(AssertionError, match="Expected fit parameter"):
108+
clf.fit(X, y)

0 commit comments

Comments
 (0)