Skip to content

Commit 100f0b7

Browse files
committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH pandas-dev#9736)
1 parent c97238c commit 100f0b7

File tree

4 files changed

+185
-57
lines changed

4 files changed

+185
-57
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ Bug Fixes
8888

8989
- Bug in ``ExcelReader`` when worksheet is empty (:issue:`6403`)
9090

91+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)
9192

9293
- Bug in ``Table.select_column`` where name is not preserved (:issue:`10392`)
9394

pandas/core/generic.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -3634,19 +3634,31 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
36343634
else:
36353635
other = self._constructor(other, **self._construct_axes_dict())
36363636

3637+
if axis is None:
3638+
axis = 0
3639+
3640+
if self.ndim == getattr(other, 'ndim', 0):
3641+
align = True
3642+
else:
3643+
align = (self._get_axis_number(axis) == 1)
3644+
3645+
block_axis = self._get_block_manager_axis(axis)
3646+
36373647
if inplace:
36383648
# we may have different type blocks come out of putmask, so
36393649
# reconstruct the block manager
36403650

36413651
self._check_inplace_setting(other)
3642-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3643-
inplace=True)
3652+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3653+
inplace=True, axis=block_axis,
3654+
transpose=self._AXIS_REVERSED)
36443655
self._update_inplace(new_data)
36453656

36463657
else:
3647-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3658+
new_data = self._data.where(other=other, cond=cond, align=align,
36483659
raise_on_error=raise_on_error,
3649-
try_cast=try_cast)
3660+
try_cast=try_cast, axis=block_axis,
3661+
transpose=self._AXIS_REVERSED)
36503662

36513663
return self._constructor(new_data).__finalize__(self)
36523664

pandas/core/internals.py

+64-53
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@ def _is_empty_indexer(indexer):
632632

633633
return [self]
634634

635-
def putmask(self, mask, new, align=True, inplace=False):
635+
def putmask(self, mask, new, align=True, inplace=False,
636+
axis=0, transpose=False):
636637
""" putmask the data to the block; it is possible that we may create a
637638
new dtype of block
638639
@@ -644,37 +645,55 @@ def putmask(self, mask, new, align=True, inplace=False):
644645
new : a ndarray/object
645646
align : boolean, perform alignment on other/cond, default is True
646647
inplace : perform inplace modification, default is False
648+
axis : int
649+
transpose : boolean
650+
Set to True if self is stored with axes reversed
647651
648652
Returns
649653
-------
650-
a new block(s), the result of the putmask
654+
a list of new blocks, the result of the putmask
651655
"""
652656

653657
new_values = self.values if inplace else self.values.copy()
654658

655-
# may need to align the new
656659
if hasattr(new, 'reindex_axis'):
657-
new = new.values.T
660+
new = new.values
658661

659-
# may need to align the mask
660662
if hasattr(mask, 'reindex_axis'):
661-
mask = mask.values.T
663+
mask = mask.values
662664

663665
# if we are passed a scalar None, convert it here
664666
if not is_list_like(new) and isnull(new) and not self.is_object:
665667
new = self.fill_value
666668

667669
if self._can_hold_element(new):
670+
if transpose:
671+
new_values = new_values.T
672+
668673
new = self._try_cast(new)
669674

670-
# pseudo-broadcast
671-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
672-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
675+
# If the default repeat behavior in np.putmask would go in the wrong
676+
# direction, then explictly repeat and reshape new instead
677+
if getattr(new, 'ndim', 0) >= 1:
678+
if self.ndim - 1 == new.ndim and axis == 1:
679+
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
673680

674681
np.putmask(new_values, mask, new)
675682

676683
# maybe upcast me
677684
elif mask.any():
685+
if transpose:
686+
mask = mask.T
687+
if isinstance(new, np.ndarray):
688+
new = new.T
689+
axis = new_values.ndim - axis - 1
690+
691+
# Pseudo-broadcast
692+
if getattr(new, 'ndim', 0) >= 1:
693+
if self.ndim - 1 == new.ndim:
694+
new_shape = list(new.shape)
695+
new_shape.insert(axis, 1)
696+
new = new.reshape(tuple(new_shape))
678697

679698
# need to go column by column
680699
new_blocks = []
@@ -685,14 +704,15 @@ def putmask(self, mask, new, align=True, inplace=False):
685704

686705
# need a new block
687706
if m.any():
688-
689-
n = new[i] if isinstance(
690-
new, np.ndarray) else np.array(new)
707+
if isinstance(new, np.ndarray):
708+
n = np.squeeze(new[i % new.shape[0]])
709+
else:
710+
n = np.array(new)
691711

692712
# type of the new block
693713
dtype, _ = com._maybe_promote(n.dtype)
694714

695-
# we need to exiplicty astype here to make a copy
715+
# we need to explicitly astype here to make a copy
696716
n = n.astype(dtype)
697717

698718
nv = _putmask_smart(v, m, n)
@@ -718,8 +738,10 @@ def putmask(self, mask, new, align=True, inplace=False):
718738
if inplace:
719739
return [self]
720740

721-
return [make_block(new_values,
722-
placement=self.mgr_locs, fastpath=True)]
741+
if transpose:
742+
new_values = new_values.T
743+
744+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
723745

724746
def interpolate(self, method='pad', axis=0, index=None,
725747
values=None, inplace=False, limit=None,
@@ -1003,7 +1025,7 @@ def handle_error():
10031025
fastpath=True, placement=self.mgr_locs)]
10041026

10051027
def where(self, other, cond, align=True, raise_on_error=True,
1006-
try_cast=False):
1028+
try_cast=False, axis=0, transpose=False):
10071029
"""
10081030
evaluate the block; return result block(s) from the result
10091031
@@ -1014,50 +1036,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10141036
align : boolean, perform alignment on other/cond
10151037
raise_on_error : if True, raise when I can't perform the function,
10161038
False by default (and just return the data that we had coming in)
1039+
axis : int
1040+
transpose : boolean
1041+
Set to True if self is stored with axes reversed
10171042
10181043
Returns
10191044
-------
10201045
a new block(s), the result of the func
10211046
"""
10221047

10231048
values = self.values
1049+
if transpose:
1050+
values = values.T
10241051

1025-
# see if we can align other
10261052
if hasattr(other, 'reindex_axis'):
10271053
other = other.values
10281054

1029-
# make sure that we can broadcast
1030-
is_transposed = False
1031-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1032-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1033-
1034-
# if its symmetric are ok, no reshaping needed (GH 7506)
1035-
if (values.shape[0] == np.array(values.shape)).all():
1036-
pass
1037-
1038-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1039-
# specific direction)
1040-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1041-
values.shape[0] != other.shape[0]):
1042-
other = _block_shape(other).T
1043-
else:
1044-
values = values.T
1045-
is_transposed = True
1046-
1047-
# see if we can align cond
1048-
if not hasattr(cond, 'shape'):
1049-
raise ValueError(
1050-
"where must have a condition that is ndarray like")
1051-
10521055
if hasattr(cond, 'reindex_axis'):
10531056
cond = cond.values
10541057

1055-
# may need to undo transpose of values
1056-
if hasattr(values, 'ndim'):
1057-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1058+
# If the default broadcasting would go in the wrong direction, then
1059+
# explictly reshape other instead
1060+
if getattr(other, 'ndim', 0) >= 1:
1061+
if values.ndim - 1 == other.ndim and axis == 1:
1062+
other = other.reshape(tuple(other.shape + (1,)))
10581063

1059-
values = values.T
1060-
is_transposed = not is_transposed
1064+
if not hasattr(cond, 'shape'):
1065+
raise ValueError("where must have a condition that is ndarray like")
10611066

10621067
other = _maybe_convert_string_to_object(other)
10631068

@@ -1090,15 +1095,14 @@ def func(c, v, o):
10901095
raise TypeError('Could not compare [%s] with block values'
10911096
% repr(other))
10921097

1093-
if is_transposed:
1098+
if transpose:
10941099
result = result.T
10951100

10961101
# try to cast if requested
10971102
if try_cast:
10981103
result = self._try_cast_result(result)
10991104

1100-
return make_block(result,
1101-
ndim=self.ndim, placement=self.mgr_locs)
1105+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
11021106

11031107
# might need to separate out blocks
11041108
axis = cond.ndim - 1
@@ -1733,7 +1737,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17331737

17341738
return self.make_block_same_class(new_values, new_mgr_locs)
17351739

1736-
def putmask(self, mask, new, align=True, inplace=False):
1740+
def putmask(self, mask, new, align=True, inplace=False,
1741+
axis=0, transpose=False):
17371742
""" putmask the data to the block; it is possible that we may create a
17381743
new dtype of block
17391744
@@ -2425,12 +2430,18 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, **kwargs):
24252430
else:
24262431
kwargs['filter'] = filter_locs
24272432

2428-
if f == 'where' and kwargs.get('align', True):
2433+
if f == 'where':
24292434
align_copy = True
2430-
align_keys = ['other', 'cond']
2431-
elif f == 'putmask' and kwargs.get('align', True):
2435+
if kwargs.get('align', True):
2436+
align_keys = ['other', 'cond']
2437+
else:
2438+
align_keys = ['cond']
2439+
elif f == 'putmask':
24322440
align_copy = False
2433-
align_keys = ['new', 'mask']
2441+
if kwargs.get('align', True):
2442+
align_keys = ['new', 'mask']
2443+
else:
2444+
align_keys = ['mask']
24342445
elif f == 'eval':
24352446
align_copy = False
24362447
align_keys = ['other']

pandas/tests/test_frame.py

+104
Original file line numberDiff line numberDiff line change
@@ -10046,6 +10046,110 @@ def test_where_complex(self):
1004610046
df[df.abs() >= 5] = np.nan
1004710047
assert_frame_equal(df,expected)
1004810048

10049+
def test_where_axis(self):
10050+
# GH 9736
10051+
df = DataFrame(np.random.randn(2, 2))
10052+
mask = DataFrame([[False, False], [False, False]])
10053+
s = Series([0, 1])
10054+
10055+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
10056+
result = df.where(mask, s, axis='index')
10057+
assert_frame_equal(result, expected)
10058+
10059+
result = df.copy()
10060+
result.where(mask, s, axis='index', inplace=True)
10061+
assert_frame_equal(result, expected)
10062+
10063+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
10064+
result = df.where(mask, s, axis='columns')
10065+
assert_frame_equal(result, expected)
10066+
10067+
result = df.copy()
10068+
result.where(mask, s, axis='columns', inplace=True)
10069+
assert_frame_equal(result, expected)
10070+
10071+
# Upcast needed
10072+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
10073+
mask = DataFrame([[False, False], [False, False]])
10074+
s = Series([0, np.nan])
10075+
10076+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
10077+
result = df.where(mask, s, axis='index')
10078+
assert_frame_equal(result, expected)
10079+
10080+
result = df.copy()
10081+
result.where(mask, s, axis='index', inplace=True)
10082+
assert_frame_equal(result, expected)
10083+
10084+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
10085+
result = df.where(mask, s, axis='columns')
10086+
assert_frame_equal(result, expected)
10087+
10088+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
10089+
1 : np.array([np.nan, np.nan], dtype='float64')})
10090+
result = df.copy()
10091+
result.where(mask, s, axis='columns', inplace=True)
10092+
assert_frame_equal(result, expected)
10093+
10094+
# Multiple dtypes (=> multiple Blocks)
10095+
df = pd.concat([DataFrame(np.random.randn(10, 2)),
10096+
DataFrame(np.random.randint(0, 10, size=(10, 2)))],
10097+
ignore_index=True, axis=1)
10098+
mask = DataFrame(False, columns=df.columns, index=df.index)
10099+
s1 = Series(1, index=df.columns)
10100+
s2 = Series(2, index=df.index)
10101+
10102+
result = df.where(mask, s1, axis='columns')
10103+
expected = DataFrame(1.0, columns=df.columns, index=df.index)
10104+
expected[2] = expected[2].astype(int)
10105+
expected[3] = expected[3].astype(int)
10106+
assert_frame_equal(result, expected)
10107+
10108+
result = df.copy()
10109+
result.where(mask, s1, axis='columns', inplace=True)
10110+
assert_frame_equal(result, expected)
10111+
10112+
result = df.where(mask, s2, axis='index')
10113+
expected = DataFrame(2.0, columns=df.columns, index=df.index)
10114+
expected[2] = expected[2].astype(int)
10115+
expected[3] = expected[3].astype(int)
10116+
assert_frame_equal(result, expected)
10117+
10118+
result = df.copy()
10119+
result.where(mask, s2, axis='index', inplace=True)
10120+
assert_frame_equal(result, expected)
10121+
10122+
# DataFrame vs DataFrame
10123+
d1 = df.copy().drop(1, axis=0)
10124+
expected = df.copy()
10125+
expected.loc[1, :] = np.nan
10126+
10127+
result = df.where(mask, d1)
10128+
assert_frame_equal(result, expected)
10129+
result = df.where(mask, d1, axis='index')
10130+
assert_frame_equal(result, expected)
10131+
result = df.copy()
10132+
result.where(mask, d1, inplace=True)
10133+
assert_frame_equal(result, expected)
10134+
result = df.copy()
10135+
result.where(mask, d1, inplace=True, axis='index')
10136+
assert_frame_equal(result, expected)
10137+
10138+
d2 = df.copy().drop(1, axis=1)
10139+
expected = df.copy()
10140+
expected.loc[:, 1] = np.nan
10141+
10142+
result = df.where(mask, d2)
10143+
assert_frame_equal(result, expected)
10144+
result = df.where(mask, d2, axis='columns')
10145+
assert_frame_equal(result, expected)
10146+
result = df.copy()
10147+
result.where(mask, d2, inplace=True)
10148+
assert_frame_equal(result, expected)
10149+
result = df.copy()
10150+
result.where(mask, d2, inplace=True, axis='columns')
10151+
assert_frame_equal(result, expected)
10152+
1004910153
def test_mask(self):
1005010154
df = DataFrame(np.random.randn(5, 3))
1005110155
cond = df > 0

0 commit comments

Comments
 (0)