Skip to content

Commit 10a43e4

Browse files
authored
Merge pull request #71 from paulgb/sklearn18
Import from sklearn.model_selection to avoid deprecation warnings
2 parents edfbe94 + 3dbb369 commit 10a43e4

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ To get around this, sklearn-pandas provides a wrapper on sklearn's ``cross_val_s
220220
>>> pipe = sklearn.pipeline.Pipeline([
221221
... ('featurize', mapper),
222222
... ('lm', sklearn.linear_model.LinearRegression())])
223-
>>> np.round(cross_val_score(pipe, data.copy(), data.salary, 'r2'), 2)
223+
>>> np.round(cross_val_score(pipe, X=data.copy(), y=data.salary, scoring='r2'), 2)
224224
array([ -1.09, -5.3 , -15.38])
225225

226226
Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface as sklearn's function of the same name.

sklearn_pandas/cross_validation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import warnings
2-
from sklearn import cross_validation
3-
from sklearn import grid_search
2+
try:
3+
from sklearn.model_selection import cross_val_score as sk_cross_val_score
4+
from sklearn.model_selection import GridSearchCV as SKGridSearchCV
5+
from sklearn.model_selection import RandomizedSearchCV as SKRandomizedSearchCV
6+
except ImportError:
7+
from sklearn.cross_validation import cross_val_score as sk_cross_val_score
8+
from sklearn.grid_search import GridSearchCV as SKGridSearchCV
9+
from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV
410

511
DEPRECATION_MSG = '''
612
Custom cross-validation compatibility shims are no longer needed for
@@ -11,10 +17,10 @@
1117
def cross_val_score(model, X, *args, **kwargs):
1218
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
1319
X = DataWrapper(X)
14-
return cross_validation.cross_val_score(model, X, *args, **kwargs)
20+
return sk_cross_val_score(model, X, *args, **kwargs)
1521

1622

17-
class GridSearchCV(grid_search.GridSearchCV):
23+
class GridSearchCV(SKGridSearchCV):
1824
def __init__(self, *args, **kwargs):
1925
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
2026
super(GridSearchCV, self).__init__(*args, **kwargs)
@@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams):
2733

2834

2935
try:
30-
class RandomizedSearchCV(grid_search.RandomizedSearchCV):
36+
class RandomizedSearchCV(SKRandomizedSearchCV):
3137
def __init__(self, *args, **kwargs):
3238
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
3339
super(RandomizedSearchCV, self).__init__(*args, **kwargs)

tox.ini

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
[tox]
2-
envlist = py27, py34
2+
envlist = {py27,py35}-sklearn{17,18}
33

44
[testenv]
55
deps =
6-
pip==7.0.1
7-
pytest==2.7.1
6+
pytest==3.0.5
87
setuptools==16.0
98
wheel==0.24.0
109
flake8==2.4.1
10+
numpy==1.11.3
11+
scipy==0.18.1
12+
pandas==0.19.2
13+
sklearn17: scikit-learn==0.17.1
14+
sklearn18: scikit-learn==0.18.1
1115
py27: mock==1.3.0
1216

1317
commands =
14-
pip install numpy --no-index
15-
pip install scipy --no-index
16-
pip install pandas --no-index
17-
pip install scikit-learn --no-index
1818
flake8 tests
1919
py.test

0 commit comments

Comments
 (0)