Skip to content

Commit bd4332f

Browse files
kunalgosar authored and jreback committed
Handle duplicate column names in select_dtypes and get_dummies (#20839)
1 parent e8e6e89 commit bd4332f

File tree

5 files changed

+66
-22
lines changed

5 files changed

+66
-22
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line number | Diff line number | Diff line change
@@ -1359,6 +1359,7 @@ Reshaping
13591359
- Bug in :meth:`DataFrame.astype` where column metadata is lost when converting to categorical or a dictionary of dtypes (:issue:`19920`)
13601360
- Bug in :func:`cut` and :func:`qcut` where timezone information was dropped (:issue:`19872`)
13611361
- Bug in :class:`Series` constructor with a ``dtype=str``, previously raised in some cases (:issue:`19853`)
1362+
- Bug in :func:`get_dummies` and :meth:`DataFrame.select_dtypes` where duplicate column names caused incorrect behavior (:issue:`20848`)
13621363

13631364
Other
13641365
^^^^^

pandas/core/frame.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -3076,15 +3076,15 @@ def select_dtypes(self, include=None, exclude=None):
30763076
include_these = Series(not bool(include), index=self.columns)
30773077
exclude_these = Series(not bool(exclude), index=self.columns)
30783078

3079-
def is_dtype_instance_mapper(column, dtype):
3080-
return column, functools.partial(issubclass, dtype.type)
3079+
def is_dtype_instance_mapper(idx, dtype):
3080+
return idx, functools.partial(issubclass, dtype.type)
30813081

3082-
for column, f in itertools.starmap(is_dtype_instance_mapper,
3083-
self.dtypes.iteritems()):
3082+
for idx, f in itertools.starmap(is_dtype_instance_mapper,
3083+
enumerate(self.dtypes)):
30843084
if include: # checks for the case of empty include or exclude
3085-
include_these[column] = any(map(f, include))
3085+
include_these.iloc[idx] = any(map(f, include))
30863086
if exclude:
3087-
exclude_these[column] = not any(map(f, exclude))
3087+
exclude_these.iloc[idx] = not any(map(f, exclude))
30883088

30893089
dtype_indexer = include_these & exclude_these
30903090
return self.loc[com._get_info_slice(self, dtype_indexer)]

pandas/core/reshape/reshape.py

+27-16
Original file line number | Diff line number | Diff line change
@@ -821,50 +821,61 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
821821
from pandas.core.reshape.concat import concat
822822
from itertools import cycle
823823

824+
dtypes_to_encode = ['object', 'category']
825+
824826
if isinstance(data, DataFrame):
825827
# determine columns being encoded
826-
827828
if columns is None:
828-
columns_to_encode = data.select_dtypes(
829-
include=['object', 'category']).columns
829+
data_to_encode = data.select_dtypes(
830+
include=dtypes_to_encode)
830831
else:
831-
columns_to_encode = columns
832+
data_to_encode = data[columns]
832833

833834
# validate prefixes and separator to avoid silently dropping cols
834835
def check_len(item, name):
835836
len_msg = ("Length of '{name}' ({len_item}) did not match the "
836837
"length of the columns being encoded ({len_enc}).")
837838

838839
if is_list_like(item):
839-
if not len(item) == len(columns_to_encode):
840-
len_msg = len_msg.format(name=name, len_item=len(item),
841-
len_enc=len(columns_to_encode))
840+
if not len(item) == data_to_encode.shape[1]:
841+
len_msg = \
842+
len_msg.format(name=name, len_item=len(item),
843+
len_enc=data_to_encode.shape[1])
842844
raise ValueError(len_msg)
843845

844846
check_len(prefix, 'prefix')
845847
check_len(prefix_sep, 'prefix_sep')
848+
846849
if isinstance(prefix, compat.string_types):
847850
prefix = cycle([prefix])
848851
if isinstance(prefix, dict):
849-
prefix = [prefix[col] for col in columns_to_encode]
852+
prefix = [prefix[col] for col in data_to_encode.columns]
850853

851854
if prefix is None:
852-
prefix = columns_to_encode
855+
prefix = data_to_encode.columns
853856

854857
# validate separators
855858
if isinstance(prefix_sep, compat.string_types):
856859
prefix_sep = cycle([prefix_sep])
857860
elif isinstance(prefix_sep, dict):
858-
prefix_sep = [prefix_sep[col] for col in columns_to_encode]
861+
prefix_sep = [prefix_sep[col] for col in data_to_encode.columns]
859862

860-
if set(columns_to_encode) == set(data.columns):
863+
if data_to_encode.shape == data.shape:
864+
# Encoding the entire df, do not prepend any dropped columns
861865
with_dummies = []
866+
elif columns is not None:
867+
# Encoding only cols specified in columns. Get all cols not in
868+
# columns to prepend to result.
869+
with_dummies = [data.drop(columns, axis=1)]
862870
else:
863-
with_dummies = [data.drop(columns_to_encode, axis=1)]
864-
865-
for (col, pre, sep) in zip(columns_to_encode, prefix, prefix_sep):
866-
867-
dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep,
871+
# Encoding only object and category dtype columns. Get remaining
872+
# columns to prepend to result.
873+
with_dummies = [data.select_dtypes(exclude=dtypes_to_encode)]
874+
875+
for (col, pre, sep) in zip(data_to_encode.iteritems(), prefix,
876+
prefix_sep):
877+
# col is (column_name, column), use just column data here
878+
dummy = _get_dummies_1d(col[1], prefix=pre, prefix_sep=sep,
868879
dummy_na=dummy_na, sparse=sparse,
869880
drop_first=drop_first, dtype=dtype)
870881
with_dummies.append(dummy)

pandas/tests/frame/test_dtypes.py

+17
Original file line number | Diff line number | Diff line change
@@ -287,6 +287,23 @@ def test_select_dtypes_include_exclude_mixed_scalars_lists(self):
287287
ei = df[['b', 'c', 'f', 'k']]
288288
assert_frame_equal(ri, ei)
289289

290+
def test_select_dtypes_duplicate_columns(self):
291+
# GH20839
292+
odict = compat.OrderedDict
293+
df = DataFrame(odict([('a', list('abc')),
294+
('b', list(range(1, 4))),
295+
('c', np.arange(3, 6).astype('u1')),
296+
('d', np.arange(4.0, 7.0, dtype='float64')),
297+
('e', [True, False, True]),
298+
('f', pd.date_range('now', periods=3).values)]))
299+
df.columns = ['a', 'a', 'b', 'b', 'b', 'c']
300+
301+
expected = DataFrame({'a': list(range(1, 4)),
302+
'b': np.arange(3, 6).astype('u1')})
303+
304+
result = df.select_dtypes(include=[np.number], exclude=['floating'])
305+
assert_frame_equal(result, expected)
306+
290307
def test_select_dtypes_not_an_attr_but_still_valid_dtype(self):
291308
df = DataFrame({'a': list('abc'),
292309
'b': list(range(1, 4)),

pandas/tests/reshape/test_reshape.py

+15
Original file line number | Diff line number | Diff line change
@@ -465,6 +465,21 @@ def test_get_dummies_dont_sparsify_all_columns(self, sparse):
465465

466466
tm.assert_frame_equal(df[['GDP']], df2)
467467

468+
def test_get_dummies_duplicate_columns(self, df):
469+
# GH20839
470+
df.columns = ["A", "A", "A"]
471+
result = get_dummies(df).sort_index(axis=1)
472+
473+
expected = DataFrame([[1, 1, 0, 1, 0],
474+
[2, 0, 1, 1, 0],
475+
[3, 1, 0, 0, 1]],
476+
columns=['A', 'A_a', 'A_b', 'A_b', 'A_c'],
477+
dtype=np.uint8).sort_index(axis=1)
478+
479+
expected = expected.astype({"A": np.int64})
480+
481+
tm.assert_frame_equal(result, expected)
482+
468483

469484
class TestCategoricalReshape(object):
470485

0 commit comments

Comments
 (0)