Skip to content

Commit 3dbb369

Browse files
committed
Import from sklearn 0.18 new module sklearn.model_selection to avoid DeprecationWarnings
1 parent 74ad89e commit 3dbb369

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ To get around this, sklearn-pandas provides a wrapper on sklearn's ``cross_val_s
220220
>>> pipe = sklearn.pipeline.Pipeline([
221221
... ('featurize', mapper),
222222
... ('lm', sklearn.linear_model.LinearRegression())])
223-
>>> np.round(cross_val_score(pipe, data.copy(), data.salary, 'r2'), 2)
223+
>>> np.round(cross_val_score(pipe, X=data.copy(), y=data.salary, scoring='r2'), 2)
224224
array([ -1.09, -5.3 , -15.38])
225225

226226
Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface as sklearn's function of the same name.

sklearn_pandas/cross_validation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import warnings
2-
from sklearn import cross_validation
3-
from sklearn import grid_search
2+
try:
3+
from sklearn.model_selection import cross_val_score as sk_cross_val_score
4+
from sklearn.model_selection import GridSearchCV as SKGridSearchCV
5+
from sklearn.model_selection import RandomizedSearchCV as SKRandomizedSearchCV
6+
except ImportError:
7+
from sklearn.cross_validation import cross_val_score as sk_cross_val_score
8+
from sklearn.grid_search import GridSearchCV as SKGridSearchCV
9+
from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV
410

511
DEPRECATION_MSG = '''
612
Custom cross-validation compatibility shims are no longer needed for
@@ -11,10 +17,10 @@
1117
def cross_val_score(model, X, *args, **kwargs):
1218
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
1319
X = DataWrapper(X)
14-
return cross_validation.cross_val_score(model, X, *args, **kwargs)
20+
return sk_cross_val_score(model, X, *args, **kwargs)
1521

1622

17-
class GridSearchCV(grid_search.GridSearchCV):
23+
class GridSearchCV(SKGridSearchCV):
1824
def __init__(self, *args, **kwargs):
1925
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
2026
super(GridSearchCV, self).__init__(*args, **kwargs)
@@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams):
2733

2834

2935
try:
30-
class RandomizedSearchCV(grid_search.RandomizedSearchCV):
36+
class RandomizedSearchCV(SKRandomizedSearchCV):
3137
def __init__(self, *args, **kwargs):
3238
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
3339
super(RandomizedSearchCV, self).__init__(*args, **kwargs)

0 commit comments

Comments
 (0)