Skip to content

Import from sklearn.model_selection to avoid deprecation warnings #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ To get around this, sklearn-pandas provides a wrapper on sklearn's ``cross_val_s
>>> pipe = sklearn.pipeline.Pipeline([
... ('featurize', mapper),
... ('lm', sklearn.linear_model.LinearRegression())])
>>> np.round(cross_val_score(pipe, data.copy(), data.salary, 'r2'), 2)
>>> np.round(cross_val_score(pipe, X=data.copy(), y=data.salary, scoring='r2'), 2)
array([ -1.09, -5.3 , -15.38])

Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface as sklearn's function of the same name.
Expand Down
16 changes: 11 additions & 5 deletions sklearn_pandas/cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import warnings
from sklearn import cross_validation
from sklearn import grid_search
try:
from sklearn.model_selection import cross_val_score as sk_cross_val_score
from sklearn.model_selection import GridSearchCV as SKGridSearchCV
from sklearn.model_selection import RandomizedSearchCV as SKRandomizedSearchCV
except ImportError:
from sklearn.cross_validation import cross_val_score as sk_cross_val_score
from sklearn.grid_search import GridSearchCV as SKGridSearchCV
from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV

DEPRECATION_MSG = '''
Custom cross-validation compatibility shims are no longer needed for
Expand All @@ -11,10 +17,10 @@
def cross_val_score(model, X, *args, **kwargs):
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
X = DataWrapper(X)
return cross_validation.cross_val_score(model, X, *args, **kwargs)
return sk_cross_val_score(model, X, *args, **kwargs)


class GridSearchCV(grid_search.GridSearchCV):
class GridSearchCV(SKGridSearchCV):
def __init__(self, *args, **kwargs):
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
super(GridSearchCV, self).__init__(*args, **kwargs)
Expand All @@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams):


try:
class RandomizedSearchCV(grid_search.RandomizedSearchCV):
class RandomizedSearchCV(SKRandomizedSearchCV):
def __init__(self, *args, **kwargs):
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
super(RandomizedSearchCV, self).__init__(*args, **kwargs)
Expand Down
14 changes: 7 additions & 7 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
[tox]
envlist = py27, py34
envlist = {py27,py35}-sklearn{17,18}

[testenv]
deps =
pip==7.0.1
pytest==2.7.1
pytest==3.0.5
setuptools==16.0
wheel==0.24.0
flake8==2.4.1
numpy==1.11.3
scipy==0.18.1
pandas==0.19.2
sklearn17: scikit-learn==0.17.1
sklearn18: scikit-learn==0.18.1
py27: mock==1.3.0

commands =
pip install numpy --no-index
pip install scipy --no-index
pip install pandas --no-index
pip install scikit-learn --no-index
flake8 tests
py.test