Skip to content

Commit 32ee973

Browse files
toobaz authored and jreback committed
BUG: fix DataFrame.__getitem__ and .loc with non-list listlikes (#21313)
* BUG: fix DataFrame.__getitem__ and .loc with non-list listlikes close #21294 close #21428
1 parent 9ad8aa7 commit 32ee973

File tree

5 files changed

+122
-106
lines changed

5 files changed

+122
-106
lines changed

doc/source/whatsnew/v0.24.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ Indexing
355355
- When ``.ix`` is asked for a missing integer label in a :class:`MultiIndex` with a first level of integer type, it now raises a ``KeyError``, consistently with the case of a flat :class:`Int64Index`, rather than falling back to positional indexing (:issue:`21593`)
356356
- Bug in :meth:`DatetimeIndex.reindex` when reindexing a tz-naive and tz-aware :class:`DatetimeIndex` (:issue:`8306`)
357357
- Bug in :class:`DataFrame` when setting values with ``.loc`` and a timezone aware :class:`DatetimeIndex` (:issue:`11365`)
358+
- ``DataFrame.__getitem__`` now accepts dictionaries and dictionary keys as list-likes of labels, consistently with ``Series.__getitem__`` (:issue:`21294`)
359+
- Fixed ``DataFrame[np.nan]`` when columns are non-unique (:issue:`21428`)
358360
- Bug when indexing :class:`DatetimeIndex` with nanosecond resolution dates and timezones (:issue:`11679`)
359361

360362
-

pandas/core/frame.py

+60-48
Original file line numberDiff line numberDiff line change
@@ -2672,68 +2672,80 @@ def _ixs(self, i, axis=0):
26722672
def __getitem__(self, key):
26732673
key = com._apply_if_callable(key, self)
26742674

2675-
# shortcut if we are an actual column
2676-
is_mi_columns = isinstance(self.columns, MultiIndex)
2675+
# shortcut if the key is in columns
26772676
try:
2678-
if key in self.columns and not is_mi_columns:
2679-
return self._getitem_column(key)
2680-
except:
2677+
if self.columns.is_unique and key in self.columns:
2678+
if self.columns.nlevels > 1:
2679+
return self._getitem_multilevel(key)
2680+
return self._get_item_cache(key)
2681+
except (TypeError, ValueError):
2682+
# The TypeError correctly catches non hashable "key" (e.g. list)
2683+
# The ValueError can be removed once GH #21729 is fixed
26812684
pass
26822685

2683-
# see if we can slice the rows
2686+
# Do we have a slicer (on rows)?
26842687
indexer = convert_to_index_sliceable(self, key)
26852688
if indexer is not None:
2686-
return self._getitem_slice(indexer)
2689+
return self._slice(indexer, axis=0)
26872690

2688-
if isinstance(key, (Series, np.ndarray, Index, list)):
2689-
# either boolean or fancy integer index
2690-
return self._getitem_array(key)
2691-
elif isinstance(key, DataFrame):
2691+
# Do we have a (boolean) DataFrame?
2692+
if isinstance(key, DataFrame):
26922693
return self._getitem_frame(key)
2693-
elif is_mi_columns:
2694-
return self._getitem_multilevel(key)
2694+
2695+
# Do we have a (boolean) 1d indexer?
2696+
if com.is_bool_indexer(key):
2697+
return self._getitem_bool_array(key)
2698+
2699+
# We are left with two options: a single key, and a collection of keys,
2700+
# We interpret tuples as collections only for non-MultiIndex
2701+
is_single_key = isinstance(key, tuple) or not is_list_like(key)
2702+
2703+
if is_single_key:
2704+
if self.columns.nlevels > 1:
2705+
return self._getitem_multilevel(key)
2706+
indexer = self.columns.get_loc(key)
2707+
if is_integer(indexer):
2708+
indexer = [indexer]
26952709
else:
2696-
return self._getitem_column(key)
2710+
if is_iterator(key):
2711+
key = list(key)
2712+
indexer = self.loc._convert_to_indexer(key, axis=1,
2713+
raise_missing=True)
26972714

2698-
def _getitem_column(self, key):
2699-
""" return the actual column """
2715+
# take() does not accept boolean indexers
2716+
if getattr(indexer, "dtype", None) == bool:
2717+
indexer = np.where(indexer)[0]
27002718

2701-
# get column
2702-
if self.columns.is_unique:
2703-
return self._get_item_cache(key)
2719+
data = self._take(indexer, axis=1)
27042720

2705-
# duplicate columns & possible reduce dimensionality
2706-
result = self._constructor(self._data.get(key))
2707-
if result.columns.is_unique:
2708-
result = result[key]
2721+
if is_single_key:
2722+
# What does looking for a single key in a non-unique index return?
2723+
# The behavior is inconsistent. It returns a Series, except when
2724+
# - the key itself is repeated (test on data.shape, #9519), or
2725+
# - we have a MultiIndex on columns (test on self.columns, #21309)
2726+
if data.shape[1] == 1 and not isinstance(self.columns, MultiIndex):
2727+
data = data[key]
27092728

2710-
return result
2711-
2712-
def _getitem_slice(self, key):
2713-
return self._slice(key, axis=0)
2729+
return data
27142730

2715-
def _getitem_array(self, key):
2731+
def _getitem_bool_array(self, key):
27162732
# also raises Exception if object array with NA values
2717-
if com.is_bool_indexer(key):
2718-
# warning here just in case -- previously __setitem__ was
2719-
# reindexing but __getitem__ was not; it seems more reasonable to
2720-
# go with the __setitem__ behavior since that is more consistent
2721-
# with all other indexing behavior
2722-
if isinstance(key, Series) and not key.index.equals(self.index):
2723-
warnings.warn("Boolean Series key will be reindexed to match "
2724-
"DataFrame index.", UserWarning, stacklevel=3)
2725-
elif len(key) != len(self.index):
2726-
raise ValueError('Item wrong length %d instead of %d.' %
2727-
(len(key), len(self.index)))
2728-
# check_bool_indexer will throw exception if Series key cannot
2729-
# be reindexed to match DataFrame rows
2730-
key = check_bool_indexer(self.index, key)
2731-
indexer = key.nonzero()[0]
2732-
return self._take(indexer, axis=0)
2733-
else:
2734-
indexer = self.loc._convert_to_indexer(key, axis=1,
2735-
raise_missing=True)
2736-
return self._take(indexer, axis=1)
2733+
# warning here just in case -- previously __setitem__ was
2734+
# reindexing but __getitem__ was not; it seems more reasonable to
2735+
# go with the __setitem__ behavior since that is more consistent
2736+
# with all other indexing behavior
2737+
if isinstance(key, Series) and not key.index.equals(self.index):
2738+
warnings.warn("Boolean Series key will be reindexed to match "
2739+
"DataFrame index.", UserWarning, stacklevel=3)
2740+
elif len(key) != len(self.index):
2741+
raise ValueError('Item wrong length %d instead of %d.' %
2742+
(len(key), len(self.index)))
2743+
2744+
# check_bool_indexer will throw exception if Series key cannot
2745+
# be reindexed to match DataFrame rows
2746+
key = check_bool_indexer(self.index, key)
2747+
indexer = key.nonzero()[0]
2748+
return self._take(indexer, axis=0)
27372749

27382750
def _getitem_multilevel(self, key):
27392751
loc = self.columns.get_loc(key)

pandas/core/sparse/frame.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,10 @@ def _init_dict(self, data, index, columns, dtype=None):
148148
if index is None:
149149
index = extract_index(list(data.values()))
150150

151-
sp_maker = lambda x: SparseArray(x, kind=self._default_kind,
152-
fill_value=self._default_fill_value,
153-
copy=True, dtype=dtype)
151+
def sp_maker(x):
152+
return SparseArray(x, kind=self._default_kind,
153+
fill_value=self._default_fill_value,
154+
copy=True, dtype=dtype)
154155
sdict = {}
155156
for k, v in compat.iteritems(data):
156157
if isinstance(v, Series):
@@ -397,9 +398,10 @@ def _sanitize_column(self, key, value, **kwargs):
397398
sanitized_column : SparseArray
398399
399400
"""
400-
sp_maker = lambda x, index=None: SparseArray(
401-
x, index=index, fill_value=self._default_fill_value,
402-
kind=self._default_kind)
401+
def sp_maker(x, index=None):
402+
return SparseArray(x, index=index,
403+
fill_value=self._default_fill_value,
404+
kind=self._default_kind)
403405
if isinstance(value, SparseSeries):
404406
clean = value.reindex(self.index).as_sparse_array(
405407
fill_value=self._default_fill_value, kind=self._default_kind)
@@ -428,18 +430,6 @@ def _sanitize_column(self, key, value, **kwargs):
428430
# always return a SparseArray!
429431
return clean
430432

431-
def __getitem__(self, key):
432-
"""
433-
Retrieve column or slice from DataFrame
434-
"""
435-
if isinstance(key, slice):
436-
date_rng = self.index[key]
437-
return self.reindex(date_rng)
438-
elif isinstance(key, (np.ndarray, list, Series)):
439-
return self._getitem_array(key)
440-
else:
441-
return self._get_item_cache(key)
442-
443433
def get_value(self, index, col, takeable=False):
444434
"""
445435
Quickly retrieve single value at passed column and index

pandas/tests/frame/test_constructors.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,11 @@ def test_constructor_dict_of_tuples(self):
501501
tm.assert_frame_equal(result, expected, check_dtype=False)
502502

503503
def test_constructor_dict_multiindex(self):
504-
check = lambda result, expected: tm.assert_frame_equal(
505-
result, expected, check_dtype=True, check_index_type=True,
506-
check_column_type=True, check_names=True)
504+
def check(result, expected):
505+
return tm.assert_frame_equal(result, expected, check_dtype=True,
506+
check_index_type=True,
507+
check_column_type=True,
508+
check_names=True)
507509
d = {('a', 'a'): {('i', 'i'): 0, ('i', 'j'): 1, ('j', 'i'): 2},
508510
('b', 'a'): {('i', 'i'): 6, ('i', 'j'): 5, ('j', 'i'): 4},
509511
('b', 'c'): {('i', 'i'): 7, ('i', 'j'): 8, ('j', 'i'): 9}}
@@ -1655,19 +1657,21 @@ def check(df):
16551657
for i in range(len(df.columns)):
16561658
df.iloc[:, i]
16571659

1658-
# allow single nans to succeed
16591660
indexer = np.arange(len(df.columns))[isna(df.columns)]
16601661

1661-
if len(indexer) == 1:
1662-
tm.assert_series_equal(df.iloc[:, indexer[0]],
1663-
df.loc[:, np.nan])
1664-
1665-
# multiple nans should fail
1666-
else:
1667-
1662+
# No NaN found -> error
1663+
if len(indexer) == 0:
16681664
def f():
16691665
df.loc[:, np.nan]
16701666
pytest.raises(TypeError, f)
1667+
# single nan should result in Series
1668+
elif len(indexer) == 1:
1669+
tm.assert_series_equal(df.iloc[:, indexer[0]],
1670+
df.loc[:, np.nan])
1671+
# multiple nans should result in DataFrame
1672+
else:
1673+
tm.assert_frame_equal(df.iloc[:, indexer],
1674+
df.loc[:, np.nan])
16711675

16721676
df = DataFrame([[1, 2, 3], [4, 5, 6]], index=[1, np.nan])
16731677
check(df)
@@ -1683,6 +1687,11 @@ def f():
16831687
columns=[np.nan, 1.1, 2.2, np.nan])
16841688
check(df)
16851689

1690+
# GH 21428 (non-unique columns)
1691+
df = DataFrame([[0.0, 1, 2, 3.0], [4, 5, 6, 7]],
1692+
columns=[np.nan, 1, 2, 2])
1693+
check(df)
1694+
16861695
def test_constructor_lists_to_object_dtype(self):
16871696
# from #1074
16881697
d = DataFrame({'a': [np.nan, False]})

pandas/tests/frame/test_indexing.py

+32-29
Original file line numberDiff line numberDiff line change
@@ -92,45 +92,46 @@ def test_get(self):
9292
result = df.get(None)
9393
assert result is None
9494

95-
def test_getitem_iterator(self):
95+
def test_loc_iterable(self):
9696
idx = iter(['A', 'B', 'C'])
9797
result = self.frame.loc[:, idx]
9898
expected = self.frame.loc[:, ['A', 'B', 'C']]
9999
assert_frame_equal(result, expected)
100100

101-
idx = iter(['A', 'B', 'C'])
102-
result = self.frame.loc[:, idx]
103-
expected = self.frame.loc[:, ['A', 'B', 'C']]
104-
assert_frame_equal(result, expected)
101+
@pytest.mark.parametrize(
102+
"idx_type",
103+
[list, iter, Index, set,
104+
lambda l: dict(zip(l, range(len(l)))),
105+
lambda l: dict(zip(l, range(len(l)))).keys()],
106+
ids=["list", "iter", "Index", "set", "dict", "dict_keys"])
107+
@pytest.mark.parametrize("levels", [1, 2])
108+
def test_getitem_listlike(self, idx_type, levels):
109+
# GH 21294
110+
111+
if levels == 1:
112+
frame, missing = self.frame, 'food'
113+
else:
114+
# MultiIndex columns
115+
frame = DataFrame(randn(8, 3),
116+
columns=Index([('foo', 'bar'), ('baz', 'qux'),
117+
('peek', 'aboo')],
118+
name=('sth', 'sth2')))
119+
missing = ('good', 'food')
105120

106-
def test_getitem_list(self):
107-
self.frame.columns.name = 'foo'
121+
keys = [frame.columns[1], frame.columns[0]]
122+
idx = idx_type(keys)
123+
idx_check = list(idx_type(keys))
108124

109-
result = self.frame[['B', 'A']]
110-
result2 = self.frame[Index(['B', 'A'])]
125+
result = frame[idx]
111126

112-
expected = self.frame.loc[:, ['B', 'A']]
113-
expected.columns.name = 'foo'
127+
expected = frame.loc[:, idx_check]
128+
expected.columns.names = frame.columns.names
114129

115130
assert_frame_equal(result, expected)
116-
assert_frame_equal(result2, expected)
117131

118-
assert result.columns.name == 'foo'
119-
120-
with tm.assert_raises_regex(KeyError, 'not in index'):
121-
self.frame[['B', 'A', 'food']]
132+
idx = idx_type(keys + [missing])
122133
with tm.assert_raises_regex(KeyError, 'not in index'):
123-
self.frame[Index(['B', 'A', 'foo'])]
124-
125-
# tuples
126-
df = DataFrame(randn(8, 3),
127-
columns=Index([('foo', 'bar'), ('baz', 'qux'),
128-
('peek', 'aboo')], name=('sth', 'sth2')))
129-
130-
result = df[[('foo', 'bar'), ('baz', 'qux')]]
131-
expected = df.iloc[:, :2]
132-
assert_frame_equal(result, expected)
133-
assert result.columns.names == ('sth', 'sth2')
134+
frame[idx]
134135

135136
def test_getitem_callable(self):
136137
# GH 12533
@@ -223,7 +224,8 @@ def test_setitem_callable(self):
223224

224225
def test_setitem_other_callable(self):
225226
# GH 13299
226-
inc = lambda x: x + 1
227+
def inc(x):
228+
return x + 1
227229

228230
df = pd.DataFrame([[-1, 1], [1, -1]])
229231
df[df > 0] = inc
@@ -2082,7 +2084,8 @@ def test_reindex_level(self):
20822084
icol = ['jim', 'joe', 'jolie']
20832085

20842086
def verify_first_level(df, level, idx, check_index_type=True):
2085-
f = lambda val: np.nonzero(df[level] == val)[0]
2087+
def f(val):
2088+
return np.nonzero(df[level] == val)[0]
20862089
i = np.concatenate(list(map(f, idx)))
20872090
left = df.set_index(icol).reindex(idx, level=level)
20882091
right = df.iloc[i].set_index(icol)

0 commit comments

Comments
 (0)