Skip to content

Commit c467051

Browse files
committed
Merge pull request #4458 from hayd/get_dummies
Get dummies
2 parents 4f46814 + 8765288 commit c467051

File tree

5 files changed

+131
-8
lines changed

5 files changed

+131
-8
lines changed

doc/source/api.rst

+7
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ Data manipulations
126126
merge
127127
concat
128128

129+
.. currentmodule:: pandas.core.reshape
130+
131+
.. autosummary::
132+
:toctree: generated/
133+
134+
get_dummies
135+
129136
Top-level missing data
130137
~~~~~~~~~~~~~~~~~~~~~~
131138

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pandas 0.13
4444
``ValueError`` (:issue:`4303`, :issue:`4305`)
4545
- ``read_excel`` now supports an integer in its ``sheetname`` argument giving
4646
the index of the sheet to read in (:issue:`4301`).
47+
- ``get_dummies`` works with NaN (:issue:`4446`)
4748
- Added a test for ``read_clipboard()`` and ``to_clipboard()`` (:issue:`4282`)
4849
- Text parser now treats anything that reads like inf ("inf", "Inf", "-Inf",
4950
"iNf", etc.) to infinity. (:issue:`4220`, :issue:`4219`), affecting

doc/source/v0.13.0.txt

+11
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ Enhancements
129129
- Added a more informative error message when plot arguments contain
130130
overlapping color and style arguments (:issue:`4402`)
131131

132+
- NaN handling in ``get_dummies`` (:issue:`4446`) with ``dummy_na``
133+
134+
.. ipython:: python
135+
# previously, nan was erroneously counted as 2 here
136+
# now it is not counted at all
137+
get_dummies([1, 2, np.nan])
138+
139+
# unless requested
140+
get_dummies([1, 2, np.nan], dummy_na=True)
141+
142+
132143
- ``timedelta64[ns]`` operations
133144

134145
- A Series of dtype ``timedelta64[ns]`` can now be divided by another

pandas/core/reshape.py

+54-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pandas.core.common as com
1919
import pandas.algos as algos
2020

21-
from pandas.core.index import MultiIndex
21+
from pandas.core.index import Index, MultiIndex
2222

2323

2424
class ReshapeError(Exception):
@@ -805,7 +805,7 @@ def convert_dummies(data, cat_variables, prefix_sep='_'):
805805
return result
806806

807807

808-
def get_dummies(data, prefix=None, prefix_sep='_'):
808+
def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False):
809809
"""
810810
Convert categorical variable into dummy/indicator variables
811811
@@ -816,19 +816,67 @@ def get_dummies(data, prefix=None, prefix_sep='_'):
816816
String to append DataFrame column names
817817
prefix_sep : string, default '_'
818818
If appending prefix, separator/delimiter to use
819+
dummy_na : bool, default False
820+
Add a column to indicate NaNs; if False, NaNs are ignored.
819821
820822
Returns
821823
-------
822824
dummies : DataFrame
825+
826+
Examples
827+
--------
828+
>>> s = pd.Series(list('abca'))
829+
830+
>>> get_dummies(s)
831+
a b c
832+
0 1 0 0
833+
1 0 1 0
834+
2 0 0 1
835+
3 1 0 0
836+
837+
>>> s1 = ['a', 'b', np.nan]
838+
839+
>>> get_dummies(s1)
840+
a b
841+
0 1 0
842+
1 0 1
843+
2 0 0
844+
845+
>>> get_dummies(s1, dummy_na=True)
846+
a b NaN
847+
0 1 0 0
848+
1 0 1 0
849+
2 0 0 1
850+
823851
"""
824-
cat = Categorical.from_array(np.asarray(data))
825-
dummy_mat = np.eye(len(cat.levels)).take(cat.labels, axis=0)
852+
cat = Categorical.from_array(Series(data)) # Series avoids inconsistent NaN handling
853+
levels = cat.levels
854+
855+
# if all NaN
856+
if not dummy_na and len(levels) == 0:
857+
if isinstance(data, Series):
858+
index = data.index
859+
else:
860+
index = np.arange(len(data))
861+
return DataFrame(index=index)
862+
863+
number_of_cols = len(levels)
864+
if dummy_na:
865+
number_of_cols += 1
866+
867+
dummy_mat = np.eye(number_of_cols).take(cat.labels, axis=0)
868+
869+
if dummy_na:
870+
levels = np.append(cat.levels, np.nan)
871+
else:
872+
# reset NaN GH4446
873+
dummy_mat[cat.labels == -1] = 0
826874

827875
if prefix is not None:
828876
dummy_cols = ['%s%s%s' % (prefix, prefix_sep, str(v))
829-
for v in cat.levels]
877+
for v in levels]
830878
else:
831-
dummy_cols = cat.levels
879+
dummy_cols = levels
832880

833881
if isinstance(data, Series):
834882
index = data.index

pandas/tests/test_reshape.py

+58-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
import nose
99

10-
from pandas import DataFrame
10+
from pandas import DataFrame, Series
1111
import pandas as pd
1212

1313
from numpy import nan
1414
import numpy as np
1515

16-
from pandas.core.reshape import melt, convert_dummies, lreshape
16+
from pandas.util.testing import assert_frame_equal
17+
18+
from pandas.core.reshape import melt, convert_dummies, lreshape, get_dummies
1719
import pandas.util.testing as tm
1820
from pandas.compat import StringIO, cPickle, range
1921

@@ -145,6 +147,60 @@ def test_multiindex(self):
145147
self.assertEqual(res.columns.tolist(), ['CAP', 'low', 'value'])
146148

147149

150+
class TestGetDummies(unittest.TestCase):
151+
def test_basic(self):
152+
s_list = list('abc')
153+
s_series = Series(s_list)
154+
s_series_index = Series(s_list, list('ABC'))
155+
156+
expected = DataFrame({'a': {0: 1.0, 1: 0.0, 2: 0.0},
157+
'b': {0: 0.0, 1: 1.0, 2: 0.0},
158+
'c': {0: 0.0, 1: 0.0, 2: 1.0}})
159+
assert_frame_equal(get_dummies(s_list), expected)
160+
assert_frame_equal(get_dummies(s_series), expected)
161+
162+
expected.index = list('ABC')
163+
assert_frame_equal(get_dummies(s_series_index), expected)
164+
165+
def test_just_na(self):
166+
just_na_list = [np.nan]
167+
just_na_series = Series(just_na_list)
168+
just_na_series_index = Series(just_na_list, index = ['A'])
169+
170+
res_list = get_dummies(just_na_list)
171+
res_series = get_dummies(just_na_series)
172+
res_series_index = get_dummies(just_na_series_index)
173+
174+
self.assertEqual(res_list.empty, True)
175+
self.assertEqual(res_series.empty, True)
176+
self.assertEqual(res_series_index.empty, True)
177+
178+
self.assertEqual(res_list.index.tolist(), [0])
179+
self.assertEqual(res_series.index.tolist(), [0])
180+
self.assertEqual(res_series_index.index.tolist(), ['A'])
181+
182+
def test_include_na(self):
183+
s = ['a', 'b', np.nan]
184+
res = get_dummies(s)
185+
exp = DataFrame({'a': {0: 1.0, 1: 0.0, 2: 0.0},
186+
'b': {0: 0.0, 1: 1.0, 2: 0.0}})
187+
assert_frame_equal(res, exp)
188+
189+
res_na = get_dummies(s, dummy_na=True)
190+
exp_na = DataFrame({nan: {0: 0.0, 1: 0.0, 2: 1.0},
191+
'a': {0: 1.0, 1: 0.0, 2: 0.0},
192+
'b': {0: 0.0, 1: 1.0, 2: 0.0}}).iloc[:, [1, 2, 0]]
193+
# hack (NaN handling in assert_index_equal)
194+
exp_na.columns = res_na.columns
195+
assert_frame_equal(res_na, exp_na)
196+
197+
res_just_na = get_dummies([nan], dummy_na=True)
198+
exp_just_na = DataFrame({nan: {0: 1.0}})
199+
# hack (NaN handling in assert_index_equal)
200+
exp_just_na.columns = res_just_na.columns
201+
assert_frame_equal(res_just_na, exp_just_na)
202+
203+
148204
class TestConvertDummies(unittest.TestCase):
149205
def test_convert_dummies(self):
150206
df = DataFrame({'A': ['foo', 'bar', 'foo', 'bar',

0 commit comments

Comments
 (0)