Skip to content

Commit a25a664

Browse files
committed
BUG: DataFrame.where does not handle Series slice correctly (fixes pandas-dev#10218)
1 parent 42ca8cd commit a25a664

File tree

6 files changed

+90
-7
lines changed

6 files changed

+90
-7
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -785,3 +785,4 @@ Bug Fixes
785785
- Bug in ``read_msgpack`` where encoding is not respected (:issue:`10580`)
786786
- Bug preventing access to the first index when using ``iloc`` with a list containing the appropriate negative integer (:issue:`10547`, :issue:`10779`)
787787
- Bug in ``TimedeltaIndex`` formatter causing error while trying to save ``DataFrame`` with ``TimedeltaIndex`` using ``to_csv`` (:issue:`10833`)
788+
- BUG in ``DataFrame.where`` when handling Series slicing (:issue:`10218`, :issue:`9558`)

pandas/core/frame.py

+8
Original file line numberDiff line numberDiff line change
@@ -2596,6 +2596,14 @@ def _reindex_multi(self, axes, copy, fill_value):
25962596
copy=copy,
25972597
fill_value=fill_value)
25982598

2599+
@Appender(_shared_docs['align'] % _shared_doc_kwargs)
2600+
def align(self, other, join='outer', axis=None, level=None, copy=True,
2601+
fill_value=None, method=None, limit=None, fill_axis=0,
2602+
broadcast_axis=None):
2603+
return super(DataFrame, self).align(other, join=join, axis=axis, level=level, copy=copy,
2604+
fill_value=fill_value, method=method, limit=limit,
2605+
fill_axis=fill_axis, broadcast_axis=broadcast_axis)
2606+
25992607
@Appender(_shared_docs['reindex'] % _shared_doc_kwargs)
26002608
def reindex(self, index=None, columns=None, **kwargs):
26012609
return super(DataFrame, self).reindex(index=index, columns=columns,

pandas/core/generic.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -3447,8 +3447,7 @@ def last(self, offset):
34473447
start = self.index.searchsorted(start_date, side='right')
34483448
return self.ix[start:]
34493449

3450-
def align(self, other, join='outer', axis=None, level=None, copy=True,
3451-
fill_value=None, method=None, limit=None, fill_axis=0):
3450+
_shared_docs['align'] = (
34523451
"""
34533452
Align two object on their axes with the
34543453
specified join method for each axis Index
@@ -3470,17 +3469,46 @@ def align(self, other, join='outer', axis=None, level=None, copy=True,
34703469
"compatible" value
34713470
method : str, default None
34723471
limit : int, default None
3473-
fill_axis : {0, 1}, default 0
3472+
fill_axis : %(axes_single_arg)s, default 0
34743473
Filling axis, method and limit
3474+
broadcast_axis : %(axes_single_arg)s, default None
3475+
Broadcast values along this axis, if aligning two objects of
3476+
different dimensions
3477+
3478+
.. versionadded:: 0.17.0
34753479
34763480
Returns
34773481
-------
3478-
(left, right) : (type of input, type of other)
3482+
(left, right) : (%(klass)s, type of other)
34793483
Aligned objects
34803484
"""
3485+
)
3486+
3487+
@Appender(_shared_docs['align'] % _shared_doc_kwargs)
3488+
def align(self, other, join='outer', axis=None, level=None, copy=True,
3489+
fill_value=None, method=None, limit=None, fill_axis=0,
3490+
broadcast_axis=None):
34813491
from pandas import DataFrame, Series
34823492
method = com._clean_fill_method(method)
34833493

3494+
if broadcast_axis == 1 and self.ndim != other.ndim:
3495+
if isinstance(self, Series):
3496+
# this means other is a DataFrame, and we need to broadcast self
3497+
df = DataFrame(dict((c, self) for c in other.columns),
3498+
**other._construct_axes_dict())
3499+
return df._align_frame(other, join=join, axis=axis, level=level,
3500+
copy=copy, fill_value=fill_value,
3501+
method=method, limit=limit,
3502+
fill_axis=fill_axis)
3503+
elif isinstance(other, Series):
3504+
# this means self is a DataFrame, and we need to broadcast other
3505+
df = DataFrame(dict((c, other) for c in self.columns),
3506+
**self._construct_axes_dict())
3507+
return self._align_frame(df, join=join, axis=axis, level=level,
3508+
copy=copy, fill_value=fill_value,
3509+
method=method, limit=limit,
3510+
fill_axis=fill_axis)
3511+
34843512
if axis is not None:
34853513
axis = self._get_axis_number(axis)
34863514
if isinstance(other, DataFrame):
@@ -3516,11 +3544,11 @@ def _align_frame(self, other, join='outer', axis=None, level=None,
35163544
self.columns.join(other.columns, how=join, level=level,
35173545
return_indexers=True)
35183546

3519-
left = self._reindex_with_indexers({0: [join_index, ilidx],
3547+
left = self._reindex_with_indexers({0: [join_index, ilidx],
35203548
1: [join_columns, clidx]},
35213549
copy=copy, fill_value=fill_value,
35223550
allow_dups=True)
3523-
right = other._reindex_with_indexers({0: [join_index, iridx],
3551+
right = other._reindex_with_indexers({0: [join_index, iridx],
35243552
1: [join_columns, cridx]},
35253553
copy=copy, fill_value=fill_value,
35263554
allow_dups=True)
@@ -3624,7 +3652,7 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
36243652
try_cast=False, raise_on_error=True):
36253653

36263654
if isinstance(cond, NDFrame):
3627-
cond = cond.reindex(**self._construct_axes_dict())
3655+
cond, _ = cond.align(self, join='right', broadcast_axis=1)
36283656
else:
36293657
if not hasattr(cond, 'shape'):
36303658
raise ValueError('where requires an ndarray like object for '

pandas/core/panel.py

+3
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,9 @@ def _needs_reindex_multi(self, axes, method, level):
628628
""" don't allow a multi reindex on Panel or above ndim """
629629
return False
630630

631+
def align(self, other, **kwargs):
632+
raise NotImplementedError
633+
631634
def dropna(self, axis=0, how='any', inplace=False):
632635
"""
633636
Drop 2D from panel, holding passed axis constant

pandas/core/series.py

+8
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,14 @@ def _needs_reindex_multi(self, axes, method, level):
21642164
"""
21652165
return False
21662166

2167+
@Appender(generic._shared_docs['align'] % _shared_doc_kwargs)
2168+
def align(self, other, join='outer', axis=None, level=None, copy=True,
2169+
fill_value=None, method=None, limit=None, fill_axis=0,
2170+
broadcast_axis=None):
2171+
return super(Series, self).align(other, join=join, axis=axis, level=level, copy=copy,
2172+
fill_value=fill_value, method=method, limit=limit,
2173+
fill_axis=fill_axis, broadcast_axis=broadcast_axis)
2174+
21672175
@Appender(generic._shared_docs['rename'] % _shared_doc_kwargs)
21682176
def rename(self, index=None, **kwargs):
21692177
return super(Series, self).rename(index=index, **kwargs)

pandas/tests/test_frame.py

+35
Original file line numberDiff line numberDiff line change
@@ -10065,6 +10065,34 @@ def test_align(self):
1006510065
self.assertRaises(ValueError, self.frame.align, af.ix[0, :3],
1006610066
join='inner', axis=2)
1006710067

10068+
# align dataframe to series with broadcast or not
10069+
idx = self.frame.index
10070+
s = Series(range(len(idx)), index=idx)
10071+
10072+
left, right = self.frame.align(s, axis=0)
10073+
tm.assert_index_equal(left.index, self.frame.index)
10074+
tm.assert_index_equal(right.index, self.frame.index)
10075+
self.assertTrue(isinstance(right, Series))
10076+
10077+
left, right = self.frame.align(s, broadcast_axis=1)
10078+
tm.assert_index_equal(left.index, self.frame.index)
10079+
expected = {}
10080+
for c in self.frame.columns:
10081+
expected[c] = s
10082+
expected = DataFrame(expected, index=self.frame.index,
10083+
columns=self.frame.columns)
10084+
assert_frame_equal(right, expected)
10085+
10086+
# GH 9558
10087+
df = DataFrame({'a':[1,2,3], 'b':[4,5,6]})
10088+
result = df[df['a'] == 2]
10089+
expected = DataFrame([[2, 5]], index=[1], columns=['a', 'b'])
10090+
assert_frame_equal(result, expected)
10091+
10092+
result = df.where(df['a'] == 2, 0)
10093+
expected = DataFrame({'a':[0, 2, 0], 'b':[0, 5, 0]})
10094+
assert_frame_equal(result, expected)
10095+
1006810096
def _check_align(self, a, b, axis, fill_axis, how, method, limit=None):
1006910097
aa, ab = a.align(b, axis=axis, join=how, method=method, limit=limit,
1007010098
fill_axis=fill_axis)
@@ -10310,6 +10338,13 @@ def _check_set(df, cond, check_dtypes = True):
1031010338
cond = (df >= 0)[1:]
1031110339
_check_set(df, cond)
1031210340

10341+
# GH 10218
10342+
# test DataFrame.where with Series slicing
10343+
df = DataFrame({'a': range(3), 'b': range(4, 7)})
10344+
result = df.where(df['a'] == 1)
10345+
expected = df[df['a'] == 1].reindex(df.index)
10346+
assert_frame_equal(result, expected)
10347+
1031310348
def test_where_bug(self):
1031410349

1031510350
# GH 2793

0 commit comments

Comments
 (0)