2
2
try :
3
3
from sklearn .model_selection import cross_val_score as sk_cross_val_score
4
4
from sklearn .model_selection import GridSearchCV as SKGridSearchCV
5
- from sklearn .model_selection import RandomizedSearchCV as SKRandomizedSearchCV
5
+ from sklearn .model_selection import RandomizedSearchCV as \
6
+ SKRandomizedSearchCV
6
7
except ImportError :
7
8
from sklearn .cross_validation import cross_val_score as sk_cross_val_score
8
9
from sklearn .grid_search import GridSearchCV as SKGridSearchCV
@@ -21,33 +22,40 @@ def cross_val_score(model, X, *args, **kwargs):
21
22
22
23
23
24
class GridSearchCV (SKGridSearchCV ):
25
+
24
26
def __init__ (self , * args , ** kwargs ):
25
27
warnings .warn (DEPRECATION_MSG , DeprecationWarning )
26
28
super (GridSearchCV , self ).__init__ (* args , ** kwargs )
27
29
28
30
def fit (self , X , * params , ** kwparams ):
29
- return super (GridSearchCV , self ).fit (DataWrapper (X ), * params , ** kwparams )
31
+ return super (GridSearchCV , self ).fit (
32
+ DataWrapper (X ), * params , ** kwparams )
30
33
31
34
def predict (self , X , * params , ** kwparams ):
32
- return super (GridSearchCV , self ).predict (DataWrapper (X ), * params , ** kwparams )
35
+ return super (GridSearchCV , self ).predict (
36
+ DataWrapper (X ), * params , ** kwparams )
33
37
34
38
35
39
try :
36
40
class RandomizedSearchCV (SKRandomizedSearchCV ):
41
+
37
42
def __init__ (self , * args , ** kwargs ):
38
43
warnings .warn (DEPRECATION_MSG , DeprecationWarning )
39
44
super (RandomizedSearchCV , self ).__init__ (* args , ** kwargs )
40
45
41
46
def fit (self , X , * params , ** kwparams ):
42
- return super (RandomizedSearchCV , self ).fit (DataWrapper (X ), * params , ** kwparams )
47
+ return super (RandomizedSearchCV , self ).fit (
48
+ DataWrapper (X ), * params , ** kwparams )
43
49
44
50
def predict (self , X , * params , ** kwparams ):
45
- return super (RandomizedSearchCV , self ).predict (DataWrapper (X ), * params , ** kwparams )
51
+ return super (RandomizedSearchCV , self ).predict (
52
+ DataWrapper (X ), * params , ** kwparams )
46
53
except AttributeError :
47
54
pass
48
55
49
56
50
57
class DataWrapper (object ):
58
+
51
59
def __init__ (self , df ):
52
60
self .df = df
53
61
0 commit comments