1
1
import warnings
2
- from sklearn import cross_validation
3
- from sklearn import grid_search
2
+ try :
3
+ from sklearn .model_selection import cross_val_score as sk_cross_val_score
4
+ from sklearn .model_selection import GridSearchCV as SKGridSearchCV
5
+ from sklearn .model_selection import RandomizedSearchCV as SKRandomizedSearchCV
6
+ except ImportError :
7
+ from sklearn .cross_validation import cross_val_score as sk_cross_val_score
8
+ from sklearn .grid_search import GridSearchCV as SKGridSearchCV
9
+ from sklearn .grid_search import RandomizedSearchCV as SKRandomizedSearchCV
4
10
5
11
DEPRECATION_MSG = '''
6
12
Custom cross-validation compatibility shims are no longer needed for
11
17
def cross_val_score (model , X , * args , ** kwargs ):
12
18
warnings .warn (DEPRECATION_MSG , DeprecationWarning )
13
19
X = DataWrapper (X )
14
- return cross_validation . cross_val_score (model , X , * args , ** kwargs )
20
+ return sk_cross_val_score (model , X , * args , ** kwargs )
15
21
16
22
17
- class GridSearchCV (grid_search . GridSearchCV ):
23
+ class GridSearchCV (SKGridSearchCV ):
18
24
def __init__ (self , * args , ** kwargs ):
19
25
warnings .warn (DEPRECATION_MSG , DeprecationWarning )
20
26
super (GridSearchCV , self ).__init__ (* args , ** kwargs )
@@ -27,7 +33,7 @@ def predict(self, X, *params, **kwparams):
27
33
28
34
29
35
try :
30
- class RandomizedSearchCV (grid_search . RandomizedSearchCV ):
36
+ class RandomizedSearchCV (SKRandomizedSearchCV ):
31
37
def __init__ (self , * args , ** kwargs ):
32
38
warnings .warn (DEPRECATION_MSG , DeprecationWarning )
33
39
super (RandomizedSearchCV , self ).__init__ (* args , ** kwargs )
0 commit comments