Skip to content

Commit 931e0e5

Browse files
evanpw authored and jreback committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH pandas-dev#9736)
1 parent 5314a5f commit 931e0e5

File tree

4 files changed

+188
-57
lines changed

4 files changed

+188
-57
lines changed

doc/source/whatsnew/v0.17.0.txt

+4
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,10 @@ Bug Fixes
602602
- Bug in ``Index`` construction with a mixed list of tuples (:issue:`10697`)
603603
- Bug in ``DataFrame.reset_index`` when index contains `NaT`. (:issue:`10388`)
604604
- Bug in ``ExcelReader`` when worksheet is empty (:issue:`6403`)
605+
606+
607+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)
608+
605609
- Bug in ``Table.select_column`` where name is not preserved (:issue:`10392`)
606610
- Bug in ``offsets.generate_range`` where ``start`` and ``end`` have finer precision than ``offset`` (:issue:`9907`)
607611
- Bug in ``pd.rolling_*`` where ``Series.name`` would be lost in the output (:issue:`10565`)

pandas/core/generic.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -3665,19 +3665,31 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
36653665
else:
36663666
other = self._constructor(other, **self._construct_axes_dict())
36673667

3668+
if axis is None:
3669+
axis = 0
3670+
3671+
if self.ndim == getattr(other, 'ndim', 0):
3672+
align = True
3673+
else:
3674+
align = (self._get_axis_number(axis) == 1)
3675+
3676+
block_axis = self._get_block_manager_axis(axis)
3677+
36683678
if inplace:
36693679
# we may have different type blocks come out of putmask, so
36703680
# reconstruct the block manager
36713681

36723682
self._check_inplace_setting(other)
3673-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3674-
inplace=True)
3683+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3684+
inplace=True, axis=block_axis,
3685+
transpose=self._AXIS_REVERSED)
36753686
self._update_inplace(new_data)
36763687

36773688
else:
3678-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3689+
new_data = self._data.where(other=other, cond=cond, align=align,
36793690
raise_on_error=raise_on_error,
3680-
try_cast=try_cast)
3691+
try_cast=try_cast, axis=block_axis,
3692+
transpose=self._AXIS_REVERSED)
36813693

36823694
return self._constructor(new_data).__finalize__(self)
36833695

pandas/core/internals.py

+64-53
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@ def _is_empty_indexer(indexer):
632632

633633
return [self]
634634

635-
def putmask(self, mask, new, align=True, inplace=False):
635+
def putmask(self, mask, new, align=True, inplace=False,
636+
axis=0, transpose=False):
636637
""" putmask the data to the block; it is possible that we may create a
637638
new dtype of block
638639
@@ -644,37 +645,55 @@ def putmask(self, mask, new, align=True, inplace=False):
644645
new : a ndarray/object
645646
align : boolean, perform alignment on other/cond, default is True
646647
inplace : perform inplace modification, default is False
648+
axis : int
649+
transpose : boolean
650+
Set to True if self is stored with axes reversed
647651
648652
Returns
649653
-------
650-
a new block(s), the result of the putmask
654+
a list of new blocks, the result of the putmask
651655
"""
652656

653657
new_values = self.values if inplace else self.values.copy()
654658

655-
# may need to align the new
656659
if hasattr(new, 'reindex_axis'):
657-
new = new.values.T
660+
new = new.values
658661

659-
# may need to align the mask
660662
if hasattr(mask, 'reindex_axis'):
661-
mask = mask.values.T
663+
mask = mask.values
662664

663665
# if we are passed a scalar None, convert it here
664666
if not is_list_like(new) and isnull(new) and not self.is_object:
665667
new = self.fill_value
666668

667669
if self._can_hold_element(new):
670+
if transpose:
671+
new_values = new_values.T
672+
668673
new = self._try_cast(new)
669674

670-
# pseudo-broadcast
671-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
672-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
675+
# If the default repeat behavior in np.putmask would go in the wrong
676+
# direction, then explicitly repeat and reshape new instead
677+
if getattr(new, 'ndim', 0) >= 1:
678+
if self.ndim - 1 == new.ndim and axis == 1:
679+
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
673680

674681
np.putmask(new_values, mask, new)
675682

676683
# maybe upcast me
677684
elif mask.any():
685+
if transpose:
686+
mask = mask.T
687+
if isinstance(new, np.ndarray):
688+
new = new.T
689+
axis = new_values.ndim - axis - 1
690+
691+
# Pseudo-broadcast
692+
if getattr(new, 'ndim', 0) >= 1:
693+
if self.ndim - 1 == new.ndim:
694+
new_shape = list(new.shape)
695+
new_shape.insert(axis, 1)
696+
new = new.reshape(tuple(new_shape))
678697

679698
# need to go column by column
680699
new_blocks = []
@@ -685,14 +704,15 @@ def putmask(self, mask, new, align=True, inplace=False):
685704

686705
# need a new block
687706
if m.any():
688-
689-
n = new[i] if isinstance(
690-
new, np.ndarray) else np.array(new)
707+
if isinstance(new, np.ndarray):
708+
n = np.squeeze(new[i % new.shape[0]])
709+
else:
710+
n = np.array(new)
691711

692712
# type of the new block
693713
dtype, _ = com._maybe_promote(n.dtype)
694714

695-
# we need to exiplicty astype here to make a copy
715+
# we need to explicitly astype here to make a copy
696716
n = n.astype(dtype)
697717

698718
nv = _putmask_smart(v, m, n)
@@ -718,8 +738,10 @@ def putmask(self, mask, new, align=True, inplace=False):
718738
if inplace:
719739
return [self]
720740

721-
return [make_block(new_values,
722-
placement=self.mgr_locs, fastpath=True)]
741+
if transpose:
742+
new_values = new_values.T
743+
744+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
723745

724746
def interpolate(self, method='pad', axis=0, index=None,
725747
values=None, inplace=False, limit=None,
@@ -1003,7 +1025,7 @@ def handle_error():
10031025
fastpath=True, placement=self.mgr_locs)]
10041026

10051027
def where(self, other, cond, align=True, raise_on_error=True,
1006-
try_cast=False):
1028+
try_cast=False, axis=0, transpose=False):
10071029
"""
10081030
evaluate the block; return result block(s) from the result
10091031
@@ -1014,50 +1036,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10141036
align : boolean, perform alignment on other/cond
10151037
raise_on_error : if True, raise when I can't perform the function,
10161038
False by default (and just return the data that we had coming in)
1039+
axis : int
1040+
transpose : boolean
1041+
Set to True if self is stored with axes reversed
10171042
10181043
Returns
10191044
-------
10201045
a new block(s), the result of the func
10211046
"""
10221047

10231048
values = self.values
1049+
if transpose:
1050+
values = values.T
10241051

1025-
# see if we can align other
10261052
if hasattr(other, 'reindex_axis'):
10271053
other = other.values
10281054

1029-
# make sure that we can broadcast
1030-
is_transposed = False
1031-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1032-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1033-
1034-
# if its symmetric are ok, no reshaping needed (GH 7506)
1035-
if (values.shape[0] == np.array(values.shape)).all():
1036-
pass
1037-
1038-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1039-
# specific direction)
1040-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1041-
values.shape[0] != other.shape[0]):
1042-
other = _block_shape(other).T
1043-
else:
1044-
values = values.T
1045-
is_transposed = True
1046-
1047-
# see if we can align cond
1048-
if not hasattr(cond, 'shape'):
1049-
raise ValueError(
1050-
"where must have a condition that is ndarray like")
1051-
10521055
if hasattr(cond, 'reindex_axis'):
10531056
cond = cond.values
10541057

1055-
# may need to undo transpose of values
1056-
if hasattr(values, 'ndim'):
1057-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1058+
# If the default broadcasting would go in the wrong direction, then
1059+
# explicitly reshape other instead
1060+
if getattr(other, 'ndim', 0) >= 1:
1061+
if values.ndim - 1 == other.ndim and axis == 1:
1062+
other = other.reshape(tuple(other.shape + (1,)))
10581063

1059-
values = values.T
1060-
is_transposed = not is_transposed
1064+
if not hasattr(cond, 'shape'):
1065+
raise ValueError("where must have a condition that is ndarray like")
10611066

10621067
other = _maybe_convert_string_to_object(other)
10631068

@@ -1090,15 +1095,14 @@ def func(c, v, o):
10901095
raise TypeError('Could not compare [%s] with block values'
10911096
% repr(other))
10921097

1093-
if is_transposed:
1098+
if transpose:
10941099
result = result.T
10951100

10961101
# try to cast if requested
10971102
if try_cast:
10981103
result = self._try_cast_result(result)
10991104

1100-
return make_block(result,
1101-
ndim=self.ndim, placement=self.mgr_locs)
1105+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
11021106

11031107
# might need to separate out blocks
11041108
axis = cond.ndim - 1
@@ -1744,7 +1748,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17441748

17451749
return self.make_block_same_class(new_values, new_mgr_locs)
17461750

1747-
def putmask(self, mask, new, align=True, inplace=False):
1751+
def putmask(self, mask, new, align=True, inplace=False,
1752+
axis=0, transpose=False):
17481753
""" putmask the data to the block; it is possible that we may create a
17491754
new dtype of block
17501755
@@ -2436,12 +2441,18 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, **kwargs):
24362441
else:
24372442
kwargs['filter'] = filter_locs
24382443

2439-
if f == 'where' and kwargs.get('align', True):
2444+
if f == 'where':
24402445
align_copy = True
2441-
align_keys = ['other', 'cond']
2442-
elif f == 'putmask' and kwargs.get('align', True):
2446+
if kwargs.get('align', True):
2447+
align_keys = ['other', 'cond']
2448+
else:
2449+
align_keys = ['cond']
2450+
elif f == 'putmask':
24432451
align_copy = False
2444-
align_keys = ['new', 'mask']
2452+
if kwargs.get('align', True):
2453+
align_keys = ['new', 'mask']
2454+
else:
2455+
align_keys = ['mask']
24452456
elif f == 'eval':
24462457
align_copy = False
24472458
align_keys = ['other']

pandas/tests/test_frame.py

+104
Original file line numberDiff line numberDiff line change
@@ -10394,6 +10394,110 @@ def test_where_complex(self):
1039410394
df[df.abs() >= 5] = np.nan
1039510395
assert_frame_equal(df,expected)
1039610396

10397+
def test_where_axis(self):
10398+
# GH 9736
10399+
df = DataFrame(np.random.randn(2, 2))
10400+
mask = DataFrame([[False, False], [False, False]])
10401+
s = Series([0, 1])
10402+
10403+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
10404+
result = df.where(mask, s, axis='index')
10405+
assert_frame_equal(result, expected)
10406+
10407+
result = df.copy()
10408+
result.where(mask, s, axis='index', inplace=True)
10409+
assert_frame_equal(result, expected)
10410+
10411+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
10412+
result = df.where(mask, s, axis='columns')
10413+
assert_frame_equal(result, expected)
10414+
10415+
result = df.copy()
10416+
result.where(mask, s, axis='columns', inplace=True)
10417+
assert_frame_equal(result, expected)
10418+
10419+
# Upcast needed
10420+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
10421+
mask = DataFrame([[False, False], [False, False]])
10422+
s = Series([0, np.nan])
10423+
10424+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
10425+
result = df.where(mask, s, axis='index')
10426+
assert_frame_equal(result, expected)
10427+
10428+
result = df.copy()
10429+
result.where(mask, s, axis='index', inplace=True)
10430+
assert_frame_equal(result, expected)
10431+
10432+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
10433+
result = df.where(mask, s, axis='columns')
10434+
assert_frame_equal(result, expected)
10435+
10436+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
10437+
1 : np.array([np.nan, np.nan], dtype='float64')})
10438+
result = df.copy()
10439+
result.where(mask, s, axis='columns', inplace=True)
10440+
assert_frame_equal(result, expected)
10441+
10442+
# Multiple dtypes (=> multiple Blocks)
10443+
df = pd.concat([DataFrame(np.random.randn(10, 2)),
10444+
DataFrame(np.random.randint(0, 10, size=(10, 2)))],
10445+
ignore_index=True, axis=1)
10446+
mask = DataFrame(False, columns=df.columns, index=df.index)
10447+
s1 = Series(1, index=df.columns)
10448+
s2 = Series(2, index=df.index)
10449+
10450+
result = df.where(mask, s1, axis='columns')
10451+
expected = DataFrame(1.0, columns=df.columns, index=df.index)
10452+
expected[2] = expected[2].astype(int)
10453+
expected[3] = expected[3].astype(int)
10454+
assert_frame_equal(result, expected)
10455+
10456+
result = df.copy()
10457+
result.where(mask, s1, axis='columns', inplace=True)
10458+
assert_frame_equal(result, expected)
10459+
10460+
result = df.where(mask, s2, axis='index')
10461+
expected = DataFrame(2.0, columns=df.columns, index=df.index)
10462+
expected[2] = expected[2].astype(int)
10463+
expected[3] = expected[3].astype(int)
10464+
assert_frame_equal(result, expected)
10465+
10466+
result = df.copy()
10467+
result.where(mask, s2, axis='index', inplace=True)
10468+
assert_frame_equal(result, expected)
10469+
10470+
# DataFrame vs DataFrame
10471+
d1 = df.copy().drop(1, axis=0)
10472+
expected = df.copy()
10473+
expected.loc[1, :] = np.nan
10474+
10475+
result = df.where(mask, d1)
10476+
assert_frame_equal(result, expected)
10477+
result = df.where(mask, d1, axis='index')
10478+
assert_frame_equal(result, expected)
10479+
result = df.copy()
10480+
result.where(mask, d1, inplace=True)
10481+
assert_frame_equal(result, expected)
10482+
result = df.copy()
10483+
result.where(mask, d1, inplace=True, axis='index')
10484+
assert_frame_equal(result, expected)
10485+
10486+
d2 = df.copy().drop(1, axis=1)
10487+
expected = df.copy()
10488+
expected.loc[:, 1] = np.nan
10489+
10490+
result = df.where(mask, d2)
10491+
assert_frame_equal(result, expected)
10492+
result = df.where(mask, d2, axis='columns')
10493+
assert_frame_equal(result, expected)
10494+
result = df.copy()
10495+
result.where(mask, d2, inplace=True)
10496+
assert_frame_equal(result, expected)
10497+
result = df.copy()
10498+
result.where(mask, d2, inplace=True, axis='columns')
10499+
assert_frame_equal(result, expected)
10500+
1039710501
def test_mask(self):
1039810502
df = DataFrame(np.random.randn(5, 3))
1039910503
cond = df > 0

0 commit comments

Comments
 (0)