Skip to content

Commit 7e36845

Browse files
committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH pandas-dev#9736)
1 parent 990972b commit 7e36845

File tree

4 files changed

+117
-53
lines changed

4 files changed

+117
-53
lines changed

doc/source/whatsnew/v0.16.1.txt

+3
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,6 @@ Bug Fixes
247247

248248

249249
- Bug in hiding ticklabels with subplots and shared axes when adding a new plot to an existing grid of axes (:issue:`9158`)
250+
251+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)
252+

pandas/core/generic.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -3409,19 +3409,29 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
34093409
else:
34103410
other = self._constructor(other, **self._construct_axes_dict())
34113411

3412+
if axis is None:
3413+
axis = 0
3414+
align = True
3415+
else:
3416+
align = False
3417+
3418+
block_axis = self._get_block_manager_axis(axis)
3419+
34123420
if inplace:
34133421
# we may have different type blocks come out of putmask, so
34143422
# reconstruct the block manager
34153423

34163424
self._check_inplace_setting(other)
3417-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3418-
inplace=True)
3425+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3426+
inplace=True, axis=block_axis,
3427+
transpose=self._AXIS_REVERSED)
34193428
self._update_inplace(new_data)
34203429

34213430
else:
3422-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3431+
new_data = self._data.where(other=other, cond=cond, align=align,
34233432
raise_on_error=raise_on_error,
3424-
try_cast=try_cast)
3433+
try_cast=try_cast, axis=block_axis,
3434+
transpose=self._AXIS_REVERSED)
34253435

34263436
return self._constructor(new_data).__finalize__(self)
34273437

pandas/core/internals.py

+54-49
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@ def _is_empty_indexer(indexer):
632632

633633
return [self]
634634

635-
def putmask(self, mask, new, align=True, inplace=False):
635+
def putmask(self, mask, new, align=True, inplace=False,
636+
axis=0, transpose=False):
636637
""" putmask the data to the block; it is possible that we may create a
637638
new dtype of block
638639
@@ -644,37 +645,55 @@ def putmask(self, mask, new, align=True, inplace=False):
644645
new : a ndarray/object
645646
align : boolean, perform alignment on other/cond, default is True
646647
inplace : perform inplace modification, default is False
648+
axis : int
649+
transpose : boolean
650+
Set to True if self is stored with axes reversed
647651
648652
Returns
649653
-------
650-
a new block(s), the result of the putmask
654+
a list of new blocks, the result of the putmask
651655
"""
652656

653657
new_values = self.values if inplace else self.values.copy()
654658

655-
# may need to align the new
656659
if hasattr(new, 'reindex_axis'):
657-
new = new.values.T
660+
new = new.values
658661

659-
# may need to align the mask
660662
if hasattr(mask, 'reindex_axis'):
661-
mask = mask.values.T
663+
mask = mask.values
662664

663665
# if we are passed a scalar None, convert it here
664666
if not is_list_like(new) and isnull(new) and not self.is_object:
665667
new = self.fill_value
666668

667669
if self._can_hold_element(new):
670+
if transpose:
671+
new_values = new_values.T
672+
668673
new = self._try_cast(new)
669674

670-
# pseudo-broadcast
671-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
672-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
675+
# If the default repeat behavior in np.putmask would go in the wrong
676+
# direction, then explicitly repeat and reshape new instead
677+
if getattr(new, 'ndim', 0) >= 1:
678+
if self.ndim - 1 == new.ndim and axis == 1:
679+
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
673680

674681
np.putmask(new_values, mask, new)
675682

676683
# maybe upcast me
677684
elif mask.any():
685+
if transpose:
686+
mask = mask.T
687+
if isinstance(new, np.ndarray):
688+
new = new.T
689+
axis = new_values.ndim - axis - 1
690+
691+
# Pseudo-broadcast
692+
if getattr(new, 'ndim', 0) >= 1:
693+
if self.ndim - 1 == new.ndim:
694+
new_shape = list(new.shape)
695+
new_shape.insert(axis, 1)
696+
new = new.reshape(tuple(new_shape))
678697

679698
# need to go column by column
680699
new_blocks = []
@@ -685,14 +704,15 @@ def putmask(self, mask, new, align=True, inplace=False):
685704

686705
# need a new block
687706
if m.any():
688-
689-
n = new[i] if isinstance(
690-
new, np.ndarray) else np.array(new)
707+
if isinstance(new, np.ndarray):
708+
n = np.squeeze(new[i % new.shape[0]])
709+
else:
710+
n = np.array(new)
691711

692712
# type of the new block
693713
dtype, _ = com._maybe_promote(n.dtype)
694714

695-
# we need to exiplicty astype here to make a copy
715+
# we need to explicitly astype here to make a copy
696716
n = n.astype(dtype)
697717

698718
nv = _putmask_smart(v, m, n)
@@ -718,8 +738,10 @@ def putmask(self, mask, new, align=True, inplace=False):
718738
if inplace:
719739
return [self]
720740

721-
return [make_block(new_values,
722-
placement=self.mgr_locs, fastpath=True)]
741+
if transpose:
742+
new_values = new_values.T
743+
744+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
723745

724746
def interpolate(self, method='pad', axis=0, index=None,
725747
values=None, inplace=False, limit=None,
@@ -1003,7 +1025,7 @@ def handle_error():
10031025
fastpath=True, placement=self.mgr_locs)]
10041026

10051027
def where(self, other, cond, align=True, raise_on_error=True,
1006-
try_cast=False):
1028+
try_cast=False, axis=0, transpose=False):
10071029
"""
10081030
evaluate the block; return result block(s) from the result
10091031
@@ -1014,50 +1036,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10141036
align : boolean, perform alignment on other/cond
10151037
raise_on_error : if True, raise when I can't perform the function,
10161038
False by default (and just return the data that we had coming in)
1039+
axis : int
1040+
transpose : boolean
1041+
Set to True if self is stored with axes reversed
10171042
10181043
Returns
10191044
-------
10201045
a new block(s), the result of the func
10211046
"""
10221047

10231048
values = self.values
1049+
if transpose:
1050+
values = values.T
10241051

1025-
# see if we can align other
10261052
if hasattr(other, 'reindex_axis'):
10271053
other = other.values
10281054

1029-
# make sure that we can broadcast
1030-
is_transposed = False
1031-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1032-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1033-
1034-
# if its symmetric are ok, no reshaping needed (GH 7506)
1035-
if (values.shape[0] == np.array(values.shape)).all():
1036-
pass
1037-
1038-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1039-
# specific direction)
1040-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1041-
values.shape[0] != other.shape[0]):
1042-
other = _block_shape(other).T
1043-
else:
1044-
values = values.T
1045-
is_transposed = True
1046-
1047-
# see if we can align cond
1048-
if not hasattr(cond, 'shape'):
1049-
raise ValueError(
1050-
"where must have a condition that is ndarray like")
1051-
10521055
if hasattr(cond, 'reindex_axis'):
10531056
cond = cond.values
10541057

1055-
# may need to undo transpose of values
1056-
if hasattr(values, 'ndim'):
1057-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1058+
# If the default broadcasting would go in the wrong direction, then
1059+
# explicitly reshape other instead
1060+
if getattr(other, 'ndim', 0) >= 1:
1061+
if values.ndim - 1 == other.ndim and axis == 1:
1062+
other = other.reshape(tuple(other.shape + (1,)))
10581063

1059-
values = values.T
1060-
is_transposed = not is_transposed
1064+
if not hasattr(cond, 'shape'):
1065+
raise ValueError("where must have a condition that is ndarray like")
10611066

10621067
other = _maybe_convert_string_to_object(other)
10631068

@@ -1090,15 +1095,14 @@ def func(c, v, o):
10901095
raise TypeError('Could not compare [%s] with block values'
10911096
% repr(other))
10921097

1093-
if is_transposed:
1098+
if transpose:
10941099
result = result.T
10951100

10961101
# try to cast if requested
10971102
if try_cast:
10981103
result = self._try_cast_result(result)
10991104

1100-
return make_block(result,
1101-
ndim=self.ndim, placement=self.mgr_locs)
1105+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
11021106

11031107
# might need to separate out blocks
11041108
axis = cond.ndim - 1
@@ -1730,7 +1734,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17301734

17311735
return self.make_block_same_class(new_values, new_mgr_locs)
17321736

1733-
def putmask(self, mask, new, align=True, inplace=False):
1737+
def putmask(self, mask, new, align=True, inplace=False,
1738+
axis=0, transpose=False):
17341739
""" putmask the data to the block; it is possible that we may create a
17351740
new dtype of block
17361741

pandas/tests/test_frame.py

+46
Original file line numberDiff line numberDiff line change
@@ -9874,6 +9874,52 @@ def test_where_complex(self):
98749874
df[df.abs() >= 5] = np.nan
98759875
assert_frame_equal(df,expected)
98769876

9877+
def test_where_axis(self):
9878+
# GH 9736
9879+
df = DataFrame(np.random.randn(2, 2))
9880+
mask = DataFrame([[False, False], [False, False]])
9881+
s = Series([0, 1])
9882+
9883+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
9884+
result = df.where(mask, s, axis='index')
9885+
assert_frame_equal(result, expected)
9886+
9887+
result = df.copy()
9888+
result.where(mask, s, axis='index', inplace=True)
9889+
assert_frame_equal(result, expected)
9890+
9891+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
9892+
result = df.where(mask, s, axis='columns')
9893+
assert_frame_equal(result, expected)
9894+
9895+
result = df.copy()
9896+
result.where(mask, s, axis='columns', inplace=True)
9897+
assert_frame_equal(result, expected)
9898+
9899+
# Upcast needed
9900+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
9901+
mask = DataFrame([[False, False], [False, False]])
9902+
s = Series([0, np.nan])
9903+
9904+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
9905+
result = df.where(mask, s, axis='index')
9906+
assert_frame_equal(result, expected)
9907+
9908+
result = df.copy()
9909+
result.where(mask, s, axis='index', inplace=True)
9910+
assert_frame_equal(result, expected)
9911+
9912+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
9913+
result = df.where(mask, s, axis='columns')
9914+
assert_frame_equal(result, expected)
9915+
9916+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
9917+
1 : np.array([np.nan, np.nan], dtype='float64')})
9918+
result = df.copy()
9919+
result.where(mask, s, axis='columns', inplace=True)
9920+
assert_frame_equal(result, expected)
9921+
9922+
98779923
def test_mask(self):
98789924
df = DataFrame(np.random.randn(5, 3))
98799925
cond = df > 0

0 commit comments

Comments
 (0)