diff --git a/doc/source/whatsnew/v0.17.0.txt b/doc/source/whatsnew/v0.17.0.txt index 206c5e2e22711..82fb2b4fe45e7 100644 --- a/doc/source/whatsnew/v0.17.0.txt +++ b/doc/source/whatsnew/v0.17.0.txt @@ -377,16 +377,12 @@ Bug Fixes - Bug in ``Series.plot(kind='hist')`` Y Label not informative (:issue:`10485`) - - - - - - - - - Bug in operator equal on Index not being consistent with Series (:issue:`9947`) - Reading "famafrench" data via ``DataReader`` results in HTTP 404 error because of the website url is changed (:issue:`10591`). - Bug in `read_msgpack` where DataFrame to decode has duplicate column names (:issue:`9618`) + + +- Bug in `get_dummies` with `sparse=True` not returning SparseDataFrame (:issue:`10531`) + diff --git a/pandas/core/reshape.py b/pandas/core/reshape.py index fd786fa30f842..99767ab199843 100644 --- a/pandas/core/reshape.py +++ b/pandas/core/reshape.py @@ -957,13 +957,15 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, If `columns` is None then all the columns with `object` or `category` dtype will be converted. sparse : bool, default False - Whether the returned DataFrame should be sparse or not. + Whether the dummy columns should be sparse or not. Returns + SparseDataFrame if `data` is a Series or if all columns are included. + Otherwise returns a DataFrame with some SparseBlocks. .. versionadded:: 0.16.1 Returns ------- - dummies : DataFrame + dummies : DataFrame or SparseDataFrame Examples -------- @@ -1042,8 +1044,11 @@ def check_len(item, name): elif isinstance(prefix_sep, dict): prefix_sep = [prefix_sep[col] for col in columns_to_encode] - result = data.drop(columns_to_encode, axis=1) - with_dummies = [result] + if set(columns_to_encode) == set(data.columns): + with_dummies = [] + else: + with_dummies = [data.drop(columns_to_encode, axis=1)] + for (col, pre, sep) in zip(columns_to_encode, prefix, prefix_sep): dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep, diff --git a/pandas/tests/test_reshape.py b/pandas/tests/test_reshape.py index 346c9e2598985..f1670ce885d6c 100644 --- a/pandas/tests/test_reshape.py +++ b/pandas/tests/test_reshape.py @@ -8,6 +8,7 @@ import nose from pandas import DataFrame, Series +from pandas.core.sparse import SparseDataFrame import pandas as pd from numpy import nan @@ -171,6 +172,33 @@ def test_basic(self): expected.index = list('ABC') assert_frame_equal(get_dummies(s_series_index, sparse=self.sparse), expected) + def test_basic_types(self): + # GH 10531 + s_list = list('abc') + s_series = Series(s_list) + s_df = DataFrame({'a': [0, 1, 0, 1, 2], + 'b': ['A', 'A', 'B', 'C', 'C'], + 'c': [2, 3, 3, 3, 2]}) + + if not self.sparse: + exp_df_type = DataFrame + exp_blk_type = pd.core.internals.FloatBlock + else: + exp_df_type = SparseDataFrame + exp_blk_type = pd.core.internals.SparseBlock + + self.assertEqual(type(get_dummies(s_list, sparse=self.sparse)), exp_df_type) + self.assertEqual(type(get_dummies(s_series, sparse=self.sparse)), exp_df_type) + + r = get_dummies(s_df, sparse=self.sparse, columns=s_df.columns) + self.assertEqual(type(r), exp_df_type) + + r = get_dummies(s_df, sparse=self.sparse, columns=['a']) + self.assertEqual(type(r[['a_0']]._data.blocks[0]), exp_blk_type) + self.assertEqual(type(r[['a_1']]._data.blocks[0]), exp_blk_type) + self.assertEqual(type(r[['a_2']]._data.blocks[0]), exp_blk_type) + + def test_just_na(self): just_na_list = [np.nan] just_na_series = Series(just_na_list)