Skip to content

Commit ae885db

Browse files
committed
If any of the extracted features is sparse, make the hstacked result sparse as well.
1 parent 70a224a commit ae885db

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

sklearn_pandas/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pandas as pd
5+
from scipy import sparse
56
from sklearn.base import BaseEstimator, TransformerMixin
67
from sklearn import cross_validation
78
from sklearn import grid_search
@@ -55,11 +56,7 @@ def transform(self, X):
5556

5657

5758
def _handle_feature(fea):
58-
if hasattr(fea, 'toarray'):
59-
# sparse arrays should be converted to regular arrays
60-
# for hstack.
61-
fea = fea.toarray()
62-
59+
# convert 1-dimensional arrays to 2-dimensional column vectors
6360
if len(fea.shape) == 1:
6461
fea = np.array([fea]).T
6562

@@ -156,4 +153,11 @@ def transform(self, X):
156153
# at this point we lose track of which features
157154
# were created from which input columns, so it's
158155
# assumed that that doesn't matter to the model.
159-
return np.hstack(extracted)
156+
157+
# If any of the extracted features is sparse, combine to produce a
158+
# sparse matrix. Otherwise, produce a dense one.
159+
if any(sparse.issparse(fea) for fea in extracted):
160+
stacked = sparse.hstack(extracted).tocsr()
161+
else:
162+
stacked = np.hstack(extracted)
163+
return stacked

tests/test_dataframe_mapper.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
from pandas import DataFrame
1111
import pandas as pd
12+
from scipy import sparse
1213
from sklearn.datasets import load_iris
1314
from sklearn.pipeline import Pipeline
1415
from sklearn.svm import SVC
1516
from sklearn.feature_extraction.text import CountVectorizer
16-
from sklearn.preprocessing import Imputer, StandardScaler
17+
from sklearn.preprocessing import Imputer, StandardScaler, LabelBinarizer
1718
import numpy as np
1819

1920
from sklearn_pandas import (
@@ -140,3 +141,17 @@ def test_list_transformers():
140141
# all features have mean 0 and std deviation 1 (standardized)
141142
assert (abs(dmatrix.mean(axis=0) - 0) <= 1e-6).all()
142143
assert (abs(dmatrix.std(axis=0) - 1) <= 1e-6).all()
144+
145+
146+
def test_sparse_features(cars_dataframe):
147+
"""
148+
If any of the extracted features is sparse, the hstacked
149+
result is also sparse.
150+
"""
151+
mapper = DataFrameMapper([
152+
("description", CountVectorizer()), # sparse feature
153+
("model", LabelBinarizer()), # dense feature
154+
])
155+
dmatrix = mapper.fit_transform(cars_dataframe)
156+
157+
assert type(dmatrix) == sparse.csr.csr_matrix

0 commit comments

Comments
 (0)