From 74ad89e43cb1878e3e835d706713b39eed477baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Israel=20Saeta=20P=C3=A9rez?= Date: Sat, 14 Jan 2017 21:01:46 +0100 Subject: [PATCH 1/2] Refactor tox.ini to test with sklearn 0.17 and 0.18 --- tox.ini | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tox.ini b/tox.ini index 7258d73..b30e99f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,19 @@ [tox] -envlist = py27, py34 +envlist = {py27,py35}-sklearn{17,18} [testenv] deps = - pip==7.0.1 - pytest==2.7.1 + pytest==3.0.5 setuptools==16.0 wheel==0.24.0 flake8==2.4.1 + numpy==1.11.3 + scipy==0.18.1 + pandas==0.19.2 + sklearn17: scikit-learn==0.17.1 + sklearn18: scikit-learn==0.18.1 py27: mock==1.3.0 commands = - pip install numpy --no-index - pip install scipy --no-index - pip install pandas --no-index - pip install scikit-learn --no-index flake8 tests py.test From 3dbb369ae928a29d77424ea1d0559896b93de3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Israel=20Saeta=20P=C3=A9rez?= Date: Sat, 14 Jan 2017 21:03:24 +0100 Subject: [PATCH 2/2] Import from sklearn 0.18 new module sklearn.model_selection to avoid DeprecationWarnings --- README.rst | 2 +- sklearn_pandas/cross_validation.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 507eed4..242876b 100644 --- a/README.rst +++ b/README.rst @@ -220,7 +220,7 @@ To get around this, sklearn-pandas provides a wrapper on sklearn's ``cross_val_s >>> pipe = sklearn.pipeline.Pipeline([ ... ('featurize', mapper), ... ('lm', sklearn.linear_model.LinearRegression())]) - >>> np.round(cross_val_score(pipe, data.copy(), data.salary, 'r2'), 2) + >>> np.round(cross_val_score(pipe, X=data.copy(), y=data.salary, scoring='r2'), 2) array([ -1.09, -5.3 , -15.38]) Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface as sklearn's function of the same name. diff --git a/sklearn_pandas/cross_validation.py b/sklearn_pandas/cross_validation.py index 2e5d6f9..911b4c0 100644 --- a/sklearn_pandas/cross_validation.py +++ b/sklearn_pandas/cross_validation.py @@ -1,6 +1,12 @@ import warnings -from sklearn import cross_validation -from sklearn import grid_search +try: + from sklearn.model_selection import cross_val_score as sk_cross_val_score + from sklearn.model_selection import GridSearchCV as SKGridSearchCV + from sklearn.model_selection import RandomizedSearchCV as SKRandomizedSearchCV +except ImportError: + from sklearn.cross_validation import cross_val_score as sk_cross_val_score + from sklearn.grid_search import GridSearchCV as SKGridSearchCV + from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV DEPRECATION_MSG = ''' Custom cross-validation compatibility shims are no longer needed for @@ -11,10 +17,10 @@ def cross_val_score(model, X, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) X = DataWrapper(X) - return cross_validation.cross_val_score(model, X, *args, **kwargs) + return sk_cross_val_score(model, X, *args, **kwargs) -class GridSearchCV(grid_search.GridSearchCV): +class GridSearchCV(SKGridSearchCV): def __init__(self, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) super(GridSearchCV, self).__init__(*args, **kwargs) @@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams): try: - class RandomizedSearchCV(grid_search.RandomizedSearchCV): + class RandomizedSearchCV(SKRandomizedSearchCV): def __init__(self, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) super(RandomizedSearchCV, self).__init__(*args, **kwargs)