diff --git a/README.rst b/README.rst index 507eed4..242876b 100644 --- a/README.rst +++ b/README.rst @@ -220,7 +220,7 @@ To get around this, sklearn-pandas provides a wrapper on sklearn's ``cross_val_s >>> pipe = sklearn.pipeline.Pipeline([ ... ('featurize', mapper), ... ('lm', sklearn.linear_model.LinearRegression())]) - >>> np.round(cross_val_score(pipe, data.copy(), data.salary, 'r2'), 2) + >>> np.round(cross_val_score(pipe, X=data.copy(), y=data.salary, scoring='r2'), 2) array([ -1.09, -5.3 , -15.38]) Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface as sklearn's function of the same name. diff --git a/sklearn_pandas/cross_validation.py b/sklearn_pandas/cross_validation.py index 2e5d6f9..911b4c0 100644 --- a/sklearn_pandas/cross_validation.py +++ b/sklearn_pandas/cross_validation.py @@ -1,6 +1,12 @@ import warnings -from sklearn import cross_validation -from sklearn import grid_search +try: + from sklearn.model_selection import cross_val_score as sk_cross_val_score + from sklearn.model_selection import GridSearchCV as SKGridSearchCV + from sklearn.model_selection import RandomizedSearchCV as SKRandomizedSearchCV +except ImportError: + from sklearn.cross_validation import cross_val_score as sk_cross_val_score + from sklearn.grid_search import GridSearchCV as SKGridSearchCV + from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV DEPRECATION_MSG = ''' Custom cross-validation compatibility shims are no longer needed for @@ -11,10 +17,10 @@ def cross_val_score(model, X, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) X = DataWrapper(X) - return cross_validation.cross_val_score(model, X, *args, **kwargs) + return sk_cross_val_score(model, X, *args, **kwargs) -class GridSearchCV(grid_search.GridSearchCV): +class GridSearchCV(SKGridSearchCV): def __init__(self, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) super(GridSearchCV, self).__init__(*args, **kwargs) @@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams): try: - class RandomizedSearchCV(grid_search.RandomizedSearchCV): + class RandomizedSearchCV(SKRandomizedSearchCV): def __init__(self, *args, **kwargs): warnings.warn(DEPRECATION_MSG, DeprecationWarning) super(RandomizedSearchCV, self).__init__(*args, **kwargs) diff --git a/tox.ini b/tox.ini index 7258d73..b30e99f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,19 @@ [tox] -envlist = py27, py34 +envlist = {py27,py35}-sklearn{17,18} [testenv] deps = - pip==7.0.1 - pytest==2.7.1 + pytest==3.0.5 setuptools==16.0 wheel==0.24.0 flake8==2.4.1 + numpy==1.11.3 + scipy==0.18.1 + pandas==0.19.2 + sklearn17: scikit-learn==0.17.1 + sklearn18: scikit-learn==0.18.1 py27: mock==1.3.0 commands = - pip install numpy --no-index - pip install scipy --no-index - pip install pandas --no-index - pip install scikit-learn --no-index flake8 tests py.test