Skip to content

Commit fedc503

Browse files
Scorpil authored and jreback committed
ENH: allow get_dummies to accept dtype argument (#18330)
1 parent bd145c8 commit fedc503

File tree

5 files changed

+243
-191
lines changed

5 files changed

+243
-191
lines changed

doc/source/reshaping.rst

+12-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ values will be set to ``NaN``.
240240
df3
241241
df3.unstack()
242242
243-
.. versionadded: 0.18.0
243+
.. versionadded:: 0.18.0
244244

245245
Alternatively, unstack takes an optional ``fill_value`` argument, for specifying
246246
the value of missing data.
@@ -634,6 +634,17 @@ When a column contains only one level, it will be omitted in the result.
634634
635635
pd.get_dummies(df, drop_first=True)
636636
637+
By default, new columns will have ``np.uint8`` dtype. To choose another dtype, use the ``dtype`` argument:
638+
639+
.. ipython:: python
640+
641+
df = pd.DataFrame({'A': list('abc'), 'B': [1.1, 2.2, 3.3]})
642+
643+
pd.get_dummies(df, dtype=bool).dtypes
644+
645+
.. versionadded:: 0.22.0
646+
647+
637648
.. _reshaping.factorize:
638649

639650
Factorizing values

doc/source/whatsnew/v0.22.0.txt

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,21 @@ New features
1717
-
1818
-
1919

20+
21+
.. _whatsnew_0220.enhancements.get_dummies_dtype:
22+
23+
``get_dummies`` now supports ``dtype`` argument
24+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
25+
26+
The :func:`get_dummies` function now accepts a ``dtype`` argument, which specifies the dtype of the new columns. The default remains ``uint8``. (:issue:`18330`)
27+
28+
.. ipython:: python
29+
30+
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4], 'c': [5, 6]})
31+
pd.get_dummies(df, columns=['c']).dtypes
32+
pd.get_dummies(df, columns=['c'], dtype=bool).dtypes
33+
34+
2035
.. _whatsnew_0220.enhancements.other:
2136

2237
Other Enhancements

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def _set_axis_name(self, name, axis=0, inplace=False):
965965
inplace : bool
966966
whether to modify `self` directly or return a copy
967967
968-
.. versionadded: 0.21.0
968+
.. versionadded:: 0.21.0
969969
970970
Returns
971971
-------

pandas/core/reshape/reshape.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas.core.dtypes.common import (
1111
_ensure_platform_int,
1212
is_list_like, is_bool_dtype,
13-
needs_i8_conversion, is_sparse)
13+
needs_i8_conversion, is_sparse, is_object_dtype)
1414
from pandas.core.dtypes.cast import maybe_promote
1515
from pandas.core.dtypes.missing import notna
1616

@@ -697,7 +697,7 @@ def _convert_level_number(level_num, columns):
697697

698698

699699
def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
700-
columns=None, sparse=False, drop_first=False):
700+
columns=None, sparse=False, drop_first=False, dtype=None):
701701
"""
702702
Convert categorical variable into dummy/indicator variables
703703
@@ -728,6 +728,11 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
728728
729729
.. versionadded:: 0.18.0
730730
731+
dtype : dtype, default np.uint8
732+
Data type for new columns. Only a single dtype is allowed.
733+
734+
.. versionadded:: 0.22.0
735+
731736
Returns
732737
-------
733738
dummies : DataFrame or SparseDataFrame
@@ -783,6 +788,12 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
783788
3 0 0
784789
4 0 0
785790
791+
>>> pd.get_dummies(pd.Series(list('abc')), dtype=float)
792+
a b c
793+
0 1.0 0.0 0.0
794+
1 0.0 1.0 0.0
795+
2 0.0 0.0 1.0
796+
786797
See Also
787798
--------
788799
Series.str.get_dummies
@@ -835,20 +846,29 @@ def check_len(item, name):
835846

836847
dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep,
837848
dummy_na=dummy_na, sparse=sparse,
838-
drop_first=drop_first)
849+
drop_first=drop_first, dtype=dtype)
839850
with_dummies.append(dummy)
840851
result = concat(with_dummies, axis=1)
841852
else:
842853
result = _get_dummies_1d(data, prefix, prefix_sep, dummy_na,
843-
sparse=sparse, drop_first=drop_first)
854+
sparse=sparse,
855+
drop_first=drop_first,
856+
dtype=dtype)
844857
return result
845858

846859

847860
def _get_dummies_1d(data, prefix, prefix_sep='_', dummy_na=False,
848-
sparse=False, drop_first=False):
861+
sparse=False, drop_first=False, dtype=None):
849862
# Series avoids inconsistent NaN handling
850863
codes, levels = _factorize_from_iterable(Series(data))
851864

865+
if dtype is None:
866+
dtype = np.uint8
867+
dtype = np.dtype(dtype)
868+
869+
if is_object_dtype(dtype):
870+
raise ValueError("dtype=object is not a valid dtype for get_dummies")
871+
852872
def get_empty_Frame(data, sparse):
853873
if isinstance(data, Series):
854874
index = data.index
@@ -903,18 +923,18 @@ def get_empty_Frame(data, sparse):
903923
sp_indices = sp_indices[1:]
904924
dummy_cols = dummy_cols[1:]
905925
for col, ixs in zip(dummy_cols, sp_indices):
906-
sarr = SparseArray(np.ones(len(ixs), dtype=np.uint8),
926+
sarr = SparseArray(np.ones(len(ixs), dtype=dtype),
907927
sparse_index=IntIndex(N, ixs), fill_value=0,
908-
dtype=np.uint8)
928+
dtype=dtype)
909929
sparse_series[col] = SparseSeries(data=sarr, index=index)
910930

911931
out = SparseDataFrame(sparse_series, index=index, columns=dummy_cols,
912932
default_fill_value=0,
913-
dtype=np.uint8)
933+
dtype=dtype)
914934
return out
915935

916936
else:
917-
dummy_mat = np.eye(number_of_cols, dtype=np.uint8).take(codes, axis=0)
937+
dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=0)
918938

919939
if not dummy_na:
920940
# reset NaN GH4446

0 commit comments

Comments
 (0)