Skip to content

Commit 2e537fd

Browse files
committed
ENH: enable df[bool_df] += value for compatible shapes, close #1366
1 parent 3ad0f0a commit 2e537fd

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

pandas/core/frame.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,12 @@ def __getitem__(self, key):
16531653
return self._getitem_array(key)
16541654
elif isinstance(self.columns, MultiIndex):
16551655
return self._getitem_multilevel(key)
1656+
elif isinstance(key, DataFrame):
1657+
values = key.values
1658+
if values.dtype == bool:
1659+
return self.values[values]
1660+
else:
1661+
raise ValueError('Cannot index using non-boolean DataFrame')
16561662
else:
16571663
return self._get_item_cache(key)
16581664

pandas/core/generic.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,16 +470,10 @@ def _get_item_cache(self, item):
470470
try:
471471
return cache[item]
472472
except Exception:
473-
try:
474-
values = self._data.get(item)
475-
res = self._box_item_values(item, values)
476-
cache[item] = res
477-
return res
478-
except Exception: # pragma: no cover
479-
from pandas.core.frame import DataFrame
480-
if isinstance(item, DataFrame):
481-
raise ValueError('Cannot index using (boolean) dataframe')
482-
raise
473+
values = self._data.get(item)
474+
res = self._box_item_values(item, values)
475+
cache[item] = res
476+
return res
483477

484478
def _box_item_values(self, key, values):
485479
raise NotImplementedError

pandas/tests/test_frame.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ def _checkit(lst):
151151
_checkit([True, True, True])
152152
_checkit([False, False, False])
153153

154+
def test_getitem_boolean_iadd(self):
155+
arr = randn(5, 5)
156+
157+
df = DataFrame(arr.copy())
158+
df[df < 0] += 1
159+
160+
arr[arr < 0] += 1
161+
162+
assert_almost_equal(df.values, arr)
163+
154164
def test_getattr(self):
155165
tm.assert_series_equal(self.frame.A, self.frame['A'])
156166
self.assertRaises(AttributeError, getattr, self.frame,

0 commit comments

Comments
 (0)