Skip to content

Commit b55ca5c

Browse files
committed
Merge pull request pandas-dev#10283 from mortada/df_where
BUG: DataFrame.where does not handle Series slice correctly (pandas-dev#10218)
2 parents 3e83459 + a25a664 commit b55ca5c

File tree

6 files changed

+90
-7
lines changed

6 files changed

+90
-7
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -829,3 +829,4 @@ Bug Fixes
829829
- Bug in ``read_msgpack`` where encoding is not respected (:issue:`10580`)
830830
- Bug preventing access to the first index when using ``iloc`` with a list containing the appropriate negative integer (:issue:`10547`, :issue:`10779`)
831831
- Bug in ``TimedeltaIndex`` formatter causing error while trying to save ``DataFrame`` with ``TimedeltaIndex`` using ``to_csv`` (:issue:`10833`)
832+
- Bug in ``DataFrame.where`` when handling Series slicing (:issue:`10218`, :issue:`9558`)

pandas/core/frame.py

+8
Original file line numberDiff line numberDiff line change
@@ -2591,6 +2591,14 @@ def _reindex_multi(self, axes, copy, fill_value):
25912591
copy=copy,
25922592
fill_value=fill_value)
25932593

2594+
@Appender(_shared_docs['align'] % _shared_doc_kwargs)
2595+
def align(self, other, join='outer', axis=None, level=None, copy=True,
2596+
fill_value=None, method=None, limit=None, fill_axis=0,
2597+
broadcast_axis=None):
2598+
return super(DataFrame, self).align(other, join=join, axis=axis, level=level, copy=copy,
2599+
fill_value=fill_value, method=method, limit=limit,
2600+
fill_axis=fill_axis, broadcast_axis=broadcast_axis)
2601+
25942602
@Appender(_shared_docs['reindex'] % _shared_doc_kwargs)
25952603
def reindex(self, index=None, columns=None, **kwargs):
25962604
return super(DataFrame, self).reindex(index=index, columns=columns,

pandas/core/generic.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -3433,8 +3433,7 @@ def last(self, offset):
34333433
start = self.index.searchsorted(start_date, side='right')
34343434
return self.ix[start:]
34353435

3436-
def align(self, other, join='outer', axis=None, level=None, copy=True,
3437-
fill_value=None, method=None, limit=None, fill_axis=0):
3436+
_shared_docs['align'] = (
34383437
"""
34393438
Align two objects on their axes with the
34403439
specified join method for each axis Index
@@ -3456,17 +3455,46 @@ def align(self, other, join='outer', axis=None, level=None, copy=True,
34563455
"compatible" value
34573456
method : str, default None
34583457
limit : int, default None
3459-
fill_axis : {0, 1}, default 0
3458+
fill_axis : %(axes_single_arg)s, default 0
34603459
Filling axis, method and limit
3460+
broadcast_axis : %(axes_single_arg)s, default None
3461+
Broadcast values along this axis, if aligning two objects of
3462+
different dimensions
3463+
3464+
.. versionadded:: 0.17.0
34613465
34623466
Returns
34633467
-------
3464-
(left, right) : (type of input, type of other)
3468+
(left, right) : (%(klass)s, type of other)
34653469
Aligned objects
34663470
"""
3471+
)
3472+
3473+
@Appender(_shared_docs['align'] % _shared_doc_kwargs)
3474+
def align(self, other, join='outer', axis=None, level=None, copy=True,
3475+
fill_value=None, method=None, limit=None, fill_axis=0,
3476+
broadcast_axis=None):
34673477
from pandas import DataFrame, Series
34683478
method = com._clean_fill_method(method)
34693479

3480+
if broadcast_axis == 1 and self.ndim != other.ndim:
3481+
if isinstance(self, Series):
3482+
# this means other is a DataFrame, and we need to broadcast self
3483+
df = DataFrame(dict((c, self) for c in other.columns),
3484+
**other._construct_axes_dict())
3485+
return df._align_frame(other, join=join, axis=axis, level=level,
3486+
copy=copy, fill_value=fill_value,
3487+
method=method, limit=limit,
3488+
fill_axis=fill_axis)
3489+
elif isinstance(other, Series):
3490+
# this means self is a DataFrame, and we need to broadcast other
3491+
df = DataFrame(dict((c, other) for c in self.columns),
3492+
**self._construct_axes_dict())
3493+
return self._align_frame(df, join=join, axis=axis, level=level,
3494+
copy=copy, fill_value=fill_value,
3495+
method=method, limit=limit,
3496+
fill_axis=fill_axis)
3497+
34703498
if axis is not None:
34713499
axis = self._get_axis_number(axis)
34723500
if isinstance(other, DataFrame):
@@ -3502,11 +3530,11 @@ def _align_frame(self, other, join='outer', axis=None, level=None,
35023530
self.columns.join(other.columns, how=join, level=level,
35033531
return_indexers=True)
35043532

3505-
left = self._reindex_with_indexers({0: [join_index, ilidx],
3533+
left = self._reindex_with_indexers({0: [join_index, ilidx],
35063534
1: [join_columns, clidx]},
35073535
copy=copy, fill_value=fill_value,
35083536
allow_dups=True)
3509-
right = other._reindex_with_indexers({0: [join_index, iridx],
3537+
right = other._reindex_with_indexers({0: [join_index, iridx],
35103538
1: [join_columns, cridx]},
35113539
copy=copy, fill_value=fill_value,
35123540
allow_dups=True)
@@ -3610,7 +3638,7 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
36103638
try_cast=False, raise_on_error=True):
36113639

36123640
if isinstance(cond, NDFrame):
3613-
cond = cond.reindex(**self._construct_axes_dict())
3641+
cond, _ = cond.align(self, join='right', broadcast_axis=1)
36143642
else:
36153643
if not hasattr(cond, 'shape'):
36163644
raise ValueError('where requires an ndarray like object for '

pandas/core/panel.py

+3
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,9 @@ def _needs_reindex_multi(self, axes, method, level):
628628
""" don't allow a multi reindex on Panel or above ndim """
629629
return False
630630

631+
def align(self, other, **kwargs):
632+
raise NotImplementedError
633+
631634
def dropna(self, axis=0, how='any', inplace=False):
632635
"""
633636
Drop 2D from panel, holding passed axis constant

pandas/core/series.py

+8
Original file line numberDiff line numberDiff line change
@@ -2165,6 +2165,14 @@ def _needs_reindex_multi(self, axes, method, level):
21652165
"""
21662166
return False
21672167

2168+
@Appender(generic._shared_docs['align'] % _shared_doc_kwargs)
2169+
def align(self, other, join='outer', axis=None, level=None, copy=True,
2170+
fill_value=None, method=None, limit=None, fill_axis=0,
2171+
broadcast_axis=None):
2172+
return super(Series, self).align(other, join=join, axis=axis, level=level, copy=copy,
2173+
fill_value=fill_value, method=method, limit=limit,
2174+
fill_axis=fill_axis, broadcast_axis=broadcast_axis)
2175+
21682176
@Appender(generic._shared_docs['rename'] % _shared_doc_kwargs)
21692177
def rename(self, index=None, **kwargs):
21702178
return super(Series, self).rename(index=index, **kwargs)

pandas/tests/test_frame.py

+35
Original file line numberDiff line numberDiff line change
@@ -10066,6 +10066,34 @@ def test_align(self):
1006610066
self.assertRaises(ValueError, self.frame.align, af.ix[0, :3],
1006710067
join='inner', axis=2)
1006810068

10069+
# align dataframe to series with broadcast or not
10070+
idx = self.frame.index
10071+
s = Series(range(len(idx)), index=idx)
10072+
10073+
left, right = self.frame.align(s, axis=0)
10074+
tm.assert_index_equal(left.index, self.frame.index)
10075+
tm.assert_index_equal(right.index, self.frame.index)
10076+
self.assertTrue(isinstance(right, Series))
10077+
10078+
left, right = self.frame.align(s, broadcast_axis=1)
10079+
tm.assert_index_equal(left.index, self.frame.index)
10080+
expected = {}
10081+
for c in self.frame.columns:
10082+
expected[c] = s
10083+
expected = DataFrame(expected, index=self.frame.index,
10084+
columns=self.frame.columns)
10085+
assert_frame_equal(right, expected)
10086+
10087+
# GH 9558
10088+
df = DataFrame({'a':[1,2,3], 'b':[4,5,6]})
10089+
result = df[df['a'] == 2]
10090+
expected = DataFrame([[2, 5]], index=[1], columns=['a', 'b'])
10091+
assert_frame_equal(result, expected)
10092+
10093+
result = df.where(df['a'] == 2, 0)
10094+
expected = DataFrame({'a':[0, 2, 0], 'b':[0, 5, 0]})
10095+
assert_frame_equal(result, expected)
10096+
1006910097
def _check_align(self, a, b, axis, fill_axis, how, method, limit=None):
1007010098
aa, ab = a.align(b, axis=axis, join=how, method=method, limit=limit,
1007110099
fill_axis=fill_axis)
@@ -10311,6 +10339,13 @@ def _check_set(df, cond, check_dtypes = True):
1031110339
cond = (df >= 0)[1:]
1031210340
_check_set(df, cond)
1031310341

10342+
# GH 10218
10343+
# test DataFrame.where with Series slicing
10344+
df = DataFrame({'a': range(3), 'b': range(4, 7)})
10345+
result = df.where(df['a'] == 1)
10346+
expected = df[df['a'] == 1].reindex(df.index)
10347+
assert_frame_equal(result, expected)
10348+
1031410349
def test_where_bug(self):
1031510350

1031610351
# GH 2793

0 commit comments

Comments
 (0)