File tree Expand file tree Collapse file tree 2 files changed +26
-7
lines changed Expand file tree Collapse file tree 2 files changed +26
-7
lines changed Original file line number Diff line number Diff line change 2
2
3
3
import numpy as np
4
4
import pandas as pd
5
+ from scipy import sparse
5
6
from sklearn .base import BaseEstimator , TransformerMixin
6
7
from sklearn import cross_validation
7
8
from sklearn import grid_search
@@ -55,11 +56,7 @@ def transform(self, X):
55
56
56
57
57
58
def _handle_feature (fea ):
58
- if hasattr (fea , 'toarray' ):
59
- # sparse arrays should be converted to regular arrays
60
- # for hstack.
61
- fea = fea .toarray ()
62
-
59
+ # convert 1-dimensional arrays to 2-dimensional column vectors
63
60
if len (fea .shape ) == 1 :
64
61
fea = np .array ([fea ]).T
65
62
@@ -156,4 +153,11 @@ def transform(self, X):
156
153
# at this point we lose track of which features
157
154
# were created from which input columns, so it's
158
155
# assumed that that doesn't matter to the model.
159
- return np .hstack (extracted )
156
+
157
+ # If any of the extracted features is sparse, combine to produce a
158
+ # sparse matrix. Otherwise, produce a dense one.
159
+ if any (sparse .issparse (fea ) for fea in extracted ):
160
+ stacked = sparse .hstack (extracted ).tocsr ()
161
+ else :
162
+ stacked = np .hstack (extracted )
163
+ return stacked
Original file line number Diff line number Diff line change 9
9
10
10
from pandas import DataFrame
11
11
import pandas as pd
12
+ from scipy import sparse
12
13
from sklearn .datasets import load_iris
13
14
from sklearn .pipeline import Pipeline
14
15
from sklearn .svm import SVC
15
16
from sklearn .feature_extraction .text import CountVectorizer
16
- from sklearn .preprocessing import Imputer , StandardScaler
17
+ from sklearn .preprocessing import Imputer , StandardScaler , LabelBinarizer
17
18
import numpy as np
18
19
19
20
from sklearn_pandas import (
@@ -140,3 +141,17 @@ def test_list_transformers():
140
141
# all features have mean 0 and std deviation 1 (standardized)
141
142
assert (abs (dmatrix .mean (axis = 0 ) - 0 ) <= 1e-6 ).all ()
142
143
assert (abs (dmatrix .std (axis = 0 ) - 1 ) <= 1e-6 ).all ()
144
+
145
+
146
+ def test_sparse_features (cars_dataframe ):
147
+ """
148
+ If any of the extracted features is sparse, the hstacked
149
+ is also sparse.
150
+ """
151
+ mapper = DataFrameMapper ([
152
+ ("description" , CountVectorizer ()), # sparse feature
153
+ ("model" , LabelBinarizer ()), # dense feature
154
+ ])
155
+ dmatrix = mapper .fit_transform (cars_dataframe )
156
+
157
+ assert type (dmatrix ) == sparse .csr .csr_matrix
You can’t perform that action at this time.
0 commit comments