diff --git a/README.rst b/README.rst index 110556c..70b1095 100644 --- a/README.rst +++ b/README.rst @@ -44,6 +44,7 @@ For these examples, we'll also use pandas, numpy, and sklearn:: >>> import numpy as np >>> import sklearn.preprocessing, sklearn.decomposition, \ ... sklearn.linear_model, sklearn.pipeline, sklearn.metrics + >>> from sklearn.feature_extraction.text import CountVectorizer Load some Data ************** @@ -156,6 +157,20 @@ Only columns that are listed in the DataFrameMapper are kept. To keep a column b [ 1., 0., 0., 5.], [ 0., 0., 1., 4.]]) + +Working with sparse features +**************************** + +`DataFrameMapper`s will return a dense feature array by default. Setting `sparse=True` in the mapper will return a sparse array whenever any of the extracted features is sparse. Example: + + >>> mapper4 = DataFrameMapper([ + ... ('pet', CountVectorizer()), + ... ], sparse=True) + >>> type(mapper4.fit_transform(data)) + + +The stacking of the sparse features is done without ever densifying them. + Cross-Validation ---------------- @@ -179,6 +194,7 @@ Changelog ******************** * Raise ``KeyError`` when selecting unexistent columns in the dataframe. Fixes #30. +* Return sparse feature array if any of the features is sparse and `sparse` argument is `True`. Defaults to `False` to avoid potential breaking of existing code. Resolves #34. 0.0.12 (2015-11-07) diff --git a/sklearn_pandas/__init__.py b/sklearn_pandas/__init__.py index edfc0bf..4324684 100644 --- a/sklearn_pandas/__init__.py +++ b/sklearn_pandas/__init__.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +from scipy import sparse from sklearn.base import BaseEstimator, TransformerMixin from sklearn import cross_validation from sklearn import grid_search @@ -55,11 +56,7 @@ def transform(self, X): def _handle_feature(fea): - if hasattr(fea, 'toarray'): - # sparse arrays should be converted to regular arrays - # for hstack. - fea = fea.toarray() - + # convert 1-dimensional arrays to 2-dimensional column vectors if len(fea.shape) == 1: fea = np.array([fea]).T @@ -72,7 +69,7 @@ class DataFrameMapper(BaseEstimator, TransformerMixin): sklearn transformation. """ - def __init__(self, features): + def __init__(self, features, sparse=False): """ Params: @@ -80,8 +77,11 @@ def __init__(self, features): selector. This can be a string (for one column) or a list of strings. The second element is an object that supports sklearn's transform interface. + sparse will return sparse matrix if set True and any of the + extracted features is sparse. Defaults to False. """ self.features = features + self.sparse = sparse def _get_col_subset(self, X, cols): """ @@ -156,4 +156,16 @@ def transform(self, X): # at this point we lose track of which features # were created from which input columns, so it's # assumed that that doesn't matter to the model. - return np.hstack(extracted) + + # If any of the extracted features is sparse, combine sparsely. + # Otherwise, combine as normal arrays. + if any(sparse.issparse(fea) for fea in extracted): + stacked = sparse.hstack(extracted).tocsr() + # return a sparse matrix only if the mapper was initialized + # with sparse=True + if not self.sparse: + stacked = stacked.toarray() + else: + stacked = np.hstack(extracted) + + return stacked diff --git a/tests/test_dataframe_mapper.py b/tests/test_dataframe_mapper.py index 5d90dae..4588369 100644 --- a/tests/test_dataframe_mapper.py +++ b/tests/test_dataframe_mapper.py @@ -9,11 +9,13 @@ from pandas import DataFrame import pandas as pd +from scipy import sparse from sklearn.datasets import load_iris from sklearn.pipeline import Pipeline from sklearn.svm import SVC from sklearn.feature_extraction.text import CountVectorizer from sklearn.preprocessing import Imputer, StandardScaler +from sklearn.base import BaseEstimator, TransformerMixin import numpy as np from sklearn_pandas import ( @@ -23,6 +25,17 @@ ) +class ToSparseTransformer(BaseEstimator, TransformerMixin): + """ + Transforms numpy matrix to sparse format. + """ + def fit(self, X): + return self + + def transform(self, X): + return sparse.csr_matrix(X) + + @pytest.fixture def iris_dataframe(): iris = load_iris() @@ -42,6 +55,11 @@ def cars_dataframe(): return pd.read_csv("tests/test_data/cars.csv.gz", compression='gzip') +@pytest.fixture +def simple_dataframe(): + return pd.DataFrame({'a': [1, 2, 3]}) + + def test_nonexistent_columns_explicit_fail(iris_dataframe): """ If a nonexistent column is selected, KeyError is raised. @@ -92,32 +110,32 @@ def test_with_car_dataframe(cars_dataframe): assert scores.mean() > 0.30 -def test_cols_string_array(): +def test_cols_string_array(simple_dataframe): """ If an string specified as the columns, the transformer is called with a 1-d array as input. """ - dataframe = pd.DataFrame({"a": [1, 2, 3]}) + df = simple_dataframe mock_transformer = Mock() mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing mapper = DataFrameMapper([("a", mock_transformer)]) - mapper.fit_transform(dataframe) + mapper.fit_transform(df) args, kwargs = mock_transformer.fit.call_args assert args[0].shape == (3,) -def test_cols_list_column_vector(): +def test_cols_list_column_vector(simple_dataframe): """ If a one-element list is specified as the columns, the transformer is called with a column vector as input. """ - dataframe = pd.DataFrame({"a": [1, 2, 3]}) + df = simple_dataframe mock_transformer = Mock() mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing mapper = DataFrameMapper([(["a"], mock_transformer)]) - mapper.fit_transform(dataframe) + mapper.fit_transform(df) args, kwargs = mock_transformer.fit.call_args assert args[0].shape == (3, 1) @@ -140,3 +158,31 @@ def test_list_transformers(): # all features have mean 0 and std deviation 1 (standardized) assert (abs(dmatrix.mean(axis=0) - 0) <= 1e-6).all() assert (abs(dmatrix.std(axis=0) - 1) <= 1e-6).all() + + +def test_sparse_features(simple_dataframe): + """ + If any of the extracted features is sparse and "sparse" argument + is true, the hstacked result is also sparse. + """ + df = simple_dataframe + mapper = DataFrameMapper([ + ("a", ToSparseTransformer()) + ], sparse=True) + dmatrix = mapper.fit_transform(df) + + assert type(dmatrix) == sparse.csr.csr_matrix + + +def test_sparse_off(simple_dataframe): + """ + If the resulting features are sparse but the "sparse" argument + of the mapper is False, return a non-sparse matrix. + """ + df = simple_dataframe + mapper = DataFrameMapper([ + ("a", ToSparseTransformer()) + ], sparse=False) + + dmatrix = mapper.fit_transform(df) + assert type(dmatrix) != sparse.csr.csr_matrix