Skip to content

Commit cace0f7

Browse files
committed
BUG: int dtype for get_dummies
Closes pandas-dev#8725. Ensures that get_dummies on a DataFrame whose output is a mix of floats / ints & dummy-encoded columns doesn't coerce the dummy-encoded columns from uint8 to ints / floats.
1 parent 58199c5 commit cace0f7

File tree

5 files changed

+163
-93
lines changed

5 files changed

+163
-93
lines changed

doc/source/whatsnew/v0.19.0.txt

+11-1
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,17 @@ Previous versions of pandas would permanently silence numpy's ufunc error handli
371371

372372
After upgrading pandas, you may see *new* ``RuntimeWarnings`` being issued from your code. These are likely legitimate, and the underlying cause likely existed in the code when using previous versions of pandas that simply silenced the warning. Use `numpy.errstate <http://docs.scipy.org/doc/numpy/reference/generated/numpy.errstate.html>`__ around the source of the ``RuntimeWarning`` to control how these conditions are handled.
373373

374+
get_dummies dtypes
375+
^^^^^^^^^^^^^^^^^^
376+
377+
The ``pd.get_dummies`` function now returns dummy-encoded columns as small integers (``uint8``), rather than floats.
378+
379+
.. ipython:: python
380+
381+
pd.get_dummies(['a', 'b', 'a', 'c']).dtypes
382+
383+
Previously, this would have been a DataFrame of float columns (:issue:`8725`).
384+
374385
.. _whatsnew_0190.enhancements.other:
375386

376387
Other enhancements
@@ -479,7 +490,6 @@ API changes
479490
- ``Series.unique()`` with datetime and timezone now returns an array of ``Timestamp`` with timezone (:issue:`13565`)
480491

481492

482-
483493
.. _whatsnew_0190.api.tolist:
484494

485495
``Series.tolist()`` will now return Python types

pandas/core/reshape.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1161,14 +1161,17 @@ def get_empty_Frame(data, sparse):
11611161
sp_indices = sp_indices[1:]
11621162
dummy_cols = dummy_cols[1:]
11631163
for col, ixs in zip(dummy_cols, sp_indices):
1164-
sarr = SparseArray(np.ones(len(ixs)),
1165-
sparse_index=IntIndex(N, ixs), fill_value=0)
1164+
sarr = SparseArray(np.ones(len(ixs), dtype=np.uint8),
1165+
sparse_index=IntIndex(N, ixs), fill_value=0,
1166+
dtype=np.uint8)
11661167
sparse_series[col] = SparseSeries(data=sarr, index=index)
11671168

1168-
return SparseDataFrame(sparse_series, index=index, columns=dummy_cols)
1169+
out = SparseDataFrame(sparse_series, index=index, columns=dummy_cols,
1170+
dtype=np.uint8)
1171+
return out
11691172

11701173
else:
1171-
dummy_mat = np.eye(number_of_cols).take(codes, axis=0)
1174+
dummy_mat = np.eye(number_of_cols, dtype=np.uint8).take(codes, axis=0)
11721175

11731176
if not dummy_na:
11741177
# reset NaN GH4446

pandas/stats/tests/test_ols.py

+2
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ def testWithXEffects(self):
645645
exp_x = DataFrame([[0., 0., 14., 1.], [0, 1, 17, 1], [1, 0, 48, 1]],
646646
columns=['x1_30', 'x1_9', 'x2', 'intercept'],
647647
index=res.index, dtype=float)
648+
exp_x[['x1_30', 'x1_9']] = exp_x[['x1_30', 'x1_9']].astype(np.uint8)
648649
assert_frame_equal(res, exp_x.reindex(columns=res.columns))
649650

650651
def testWithXEffectsAndDroppedDummies(self):
@@ -659,6 +660,7 @@ def testWithXEffectsAndDroppedDummies(self):
659660
exp_x = DataFrame([[1., 0., 14., 1.], [0, 1, 17, 1], [0, 0, 48, 1]],
660661
columns=['x1_6', 'x1_9', 'x2', 'intercept'],
661662
index=res.index, dtype=float)
663+
exp_x[['x1_6', 'x1_9']] = exp_x[['x1_6', 'x1_9']].astype(np.uint8)
662664

663665
assert_frame_equal(res, exp_x.reindex(columns=res.columns))
664666

pandas/tests/test_panel.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2429,18 +2429,18 @@ def test_truncate(self):
24292429
def test_axis_dummies(self):
24302430
from pandas.core.reshape import make_axis_dummies
24312431

2432-
minor_dummies = make_axis_dummies(self.panel, 'minor')
2432+
minor_dummies = make_axis_dummies(self.panel, 'minor').astype(np.uint8)
24332433
self.assertEqual(len(minor_dummies.columns),
24342434
len(self.panel.index.levels[1]))
24352435

2436-
major_dummies = make_axis_dummies(self.panel, 'major')
2436+
major_dummies = make_axis_dummies(self.panel, 'major').astype(np.uint8)
24372437
self.assertEqual(len(major_dummies.columns),
24382438
len(self.panel.index.levels[0]))
24392439

24402440
mapping = {'A': 'one', 'B': 'one', 'C': 'two', 'D': 'two'}
24412441

24422442
transformed = make_axis_dummies(self.panel, 'minor',
2443-
transform=mapping.get)
2443+
transform=mapping.get).astype(np.uint8)
24442444
self.assertEqual(len(transformed.columns), 2)
24452445
self.assert_index_equal(transformed.columns, Index(['one', 'two']))
24462446

@@ -2450,7 +2450,7 @@ def test_get_dummies(self):
24502450
from pandas.core.reshape import get_dummies, make_axis_dummies
24512451

24522452
self.panel['Label'] = self.panel.index.labels[1]
2453-
minor_dummies = make_axis_dummies(self.panel, 'minor')
2453+
minor_dummies = make_axis_dummies(self.panel, 'minor').astype(np.uint8)
24542454
dummies = get_dummies(self.panel['Label'])
24552455
self.assert_numpy_array_equal(dummies.values, minor_dummies.values)
24562456

0 commit comments

Comments
 (0)