Skip to content

Commit 1d70ea3

Browse files
committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH #9736)
1 parent 529cd3d commit 1d70ea3

File tree

4 files changed

+117
-53
lines changed

4 files changed

+117
-53
lines changed

doc/source/whatsnew/v0.16.1.txt

+3
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,6 @@ Bug Fixes
161161
- Changed caching in ``AbstractHolidayCalendar`` to be at the instance level rather than at the class level as the latter can result in unexpected behaviour. (:issue:`9552`)
162162

163163
- Fixed latex output for multi-indexed dataframes (:issue:`9778`)
164+
165+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)
166+

pandas/core/generic.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -3397,19 +3397,29 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
33973397
else:
33983398
other = self._constructor(other, **self._construct_axes_dict())
33993399

3400+
if axis is None:
3401+
axis = 0
3402+
align = True
3403+
else:
3404+
align = False
3405+
3406+
block_axis = self._get_block_manager_axis(axis)
3407+
34003408
if inplace:
34013409
# we may have different type blocks come out of putmask, so
34023410
# reconstruct the block manager
34033411

34043412
self._check_inplace_setting(other)
3405-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3406-
inplace=True)
3413+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3414+
inplace=True, axis=block_axis,
3415+
transpose=self._AXIS_REVERSED)
34073416
self._update_inplace(new_data)
34083417

34093418
else:
3410-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3419+
new_data = self._data.where(other=other, cond=cond, align=align,
34113420
raise_on_error=raise_on_error,
3412-
try_cast=try_cast)
3421+
try_cast=try_cast, axis=block_axis,
3422+
transpose=self._AXIS_REVERSED)
34133423

34143424
return self._constructor(new_data).__finalize__(self)
34153425

pandas/core/internals.py

+54-49
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,8 @@ def _is_empty_indexer(indexer):
627627

628628
return [self]
629629

630-
def putmask(self, mask, new, align=True, inplace=False):
630+
def putmask(self, mask, new, align=True, inplace=False,
631+
axis=0, transpose=False):
631632
""" putmask the data to the block; it is possible that we may create a
632633
new dtype of block
633634
@@ -639,37 +640,55 @@ def putmask(self, mask, new, align=True, inplace=False):
639640
new : a ndarray/object
640641
align : boolean, perform alignment on other/cond, default is True
641642
inplace : perform inplace modification, default is False
643+
axis : int
644+
transpose : boolean
645+
Set to True if self is stored with axes reversed
642646
643647
Returns
644648
-------
645-
a new block(s), the result of the putmask
649+
a list of new blocks, the result of the putmask
646650
"""
647651

648652
new_values = self.values if inplace else self.values.copy()
649653

650-
# may need to align the new
651654
if hasattr(new, 'reindex_axis'):
652-
new = new.values.T
655+
new = new.values
653656

654-
# may need to align the mask
655657
if hasattr(mask, 'reindex_axis'):
656-
mask = mask.values.T
658+
mask = mask.values
657659

658660
# if we are passed a scalar None, convert it here
659661
if not is_list_like(new) and isnull(new) and not self.is_object:
660662
new = self.fill_value
661663

662664
if self._can_hold_element(new):
665+
if transpose:
666+
new_values = new_values.T
667+
663668
new = self._try_cast(new)
664669

665-
# pseudo-broadcast
666-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
667-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
670+
# If the default repeat behavior in np.putmask would go in the wrong
671+
# direction, then explictly repeat and reshape new instead
672+
if getattr(new, 'ndim', 0) >= 1:
673+
if self.ndim - 1 == new.ndim and axis == 1:
674+
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
668675

669676
np.putmask(new_values, mask, new)
670677

671678
# maybe upcast me
672679
elif mask.any():
680+
if transpose:
681+
mask = mask.T
682+
if isinstance(new, np.ndarray):
683+
new = new.T
684+
axis = new_values.ndim - axis - 1
685+
686+
# Pseudo-broadcast
687+
if getattr(new, 'ndim', 0) >= 1:
688+
if self.ndim - 1 == new.ndim:
689+
new_shape = list(new.shape)
690+
new_shape.insert(axis, 1)
691+
new = new.reshape(tuple(new_shape))
673692

674693
# need to go column by column
675694
new_blocks = []
@@ -680,14 +699,15 @@ def putmask(self, mask, new, align=True, inplace=False):
680699

681700
# need a new block
682701
if m.any():
683-
684-
n = new[i] if isinstance(
685-
new, np.ndarray) else np.array(new)
702+
if isinstance(new, np.ndarray):
703+
n = np.squeeze(new[i % new.shape[0]])
704+
else:
705+
n = np.array(new)
686706

687707
# type of the new block
688708
dtype, _ = com._maybe_promote(n.dtype)
689709

690-
# we need to exiplicty astype here to make a copy
710+
# we need to explicitly astype here to make a copy
691711
n = n.astype(dtype)
692712

693713
nv = _putmask_smart(v, m, n)
@@ -713,8 +733,10 @@ def putmask(self, mask, new, align=True, inplace=False):
713733
if inplace:
714734
return [self]
715735

716-
return [make_block(new_values,
717-
placement=self.mgr_locs, fastpath=True)]
736+
if transpose:
737+
new_values = new_values.T
738+
739+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
718740

719741
def interpolate(self, method='pad', axis=0, index=None,
720742
values=None, inplace=False, limit=None,
@@ -998,7 +1020,7 @@ def handle_error():
9981020
fastpath=True, placement=self.mgr_locs)]
9991021

10001022
def where(self, other, cond, align=True, raise_on_error=True,
1001-
try_cast=False):
1023+
try_cast=False, axis=0, transpose=False):
10021024
"""
10031025
evaluate the block; return result block(s) from the result
10041026
@@ -1009,50 +1031,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10091031
align : boolean, perform alignment on other/cond
10101032
raise_on_error : if True, raise when I can't perform the function,
10111033
False by default (and just return the data that we had coming in)
1034+
axis : int
1035+
transpose : boolean
1036+
Set to True if self is stored with axes reversed
10121037
10131038
Returns
10141039
-------
10151040
a new block(s), the result of the func
10161041
"""
10171042

10181043
values = self.values
1044+
if transpose:
1045+
values = values.T
10191046

1020-
# see if we can align other
10211047
if hasattr(other, 'reindex_axis'):
10221048
other = other.values
10231049

1024-
# make sure that we can broadcast
1025-
is_transposed = False
1026-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1027-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1028-
1029-
# if its symmetric are ok, no reshaping needed (GH 7506)
1030-
if (values.shape[0] == np.array(values.shape)).all():
1031-
pass
1032-
1033-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1034-
# specific direction)
1035-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1036-
values.shape[0] != other.shape[0]):
1037-
other = _block_shape(other).T
1038-
else:
1039-
values = values.T
1040-
is_transposed = True
1041-
1042-
# see if we can align cond
1043-
if not hasattr(cond, 'shape'):
1044-
raise ValueError(
1045-
"where must have a condition that is ndarray like")
1046-
10471050
if hasattr(cond, 'reindex_axis'):
10481051
cond = cond.values
10491052

1050-
# may need to undo transpose of values
1051-
if hasattr(values, 'ndim'):
1052-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1053+
# If the default broadcasting would go in the wrong direction, then
1054+
# explictly reshape other instead
1055+
if getattr(other, 'ndim', 0) >= 1:
1056+
if values.ndim - 1 == other.ndim and axis == 1:
1057+
other = other.reshape(tuple(other.shape + (1,)))
10531058

1054-
values = values.T
1055-
is_transposed = not is_transposed
1059+
if not hasattr(cond, 'shape'):
1060+
raise ValueError("where must have a condition that is ndarray like")
10561061

10571062
other = _maybe_convert_string_to_object(other)
10581063

@@ -1085,15 +1090,14 @@ def func(c, v, o):
10851090
raise TypeError('Could not compare [%s] with block values'
10861091
% repr(other))
10871092

1088-
if is_transposed:
1093+
if transpose:
10891094
result = result.T
10901095

10911096
# try to cast if requested
10921097
if try_cast:
10931098
result = self._try_cast_result(result)
10941099

1095-
return make_block(result,
1096-
ndim=self.ndim, placement=self.mgr_locs)
1100+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
10971101

10981102
# might need to separate out blocks
10991103
axis = cond.ndim - 1
@@ -1723,7 +1727,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17231727

17241728
return self.make_block_same_class(new_values, new_mgr_locs)
17251729

1726-
def putmask(self, mask, new, align=True, inplace=False):
1730+
def putmask(self, mask, new, align=True, inplace=False,
1731+
axis=0, transpose=False):
17271732
""" putmask the data to the block; it is possible that we may create a
17281733
new dtype of block
17291734

pandas/tests/test_frame.py

+46
Original file line numberDiff line numberDiff line change
@@ -9838,6 +9838,52 @@ def test_where_complex(self):
98389838
df[df.abs() >= 5] = np.nan
98399839
assert_frame_equal(df,expected)
98409840

9841+
def test_where_axis(self):
9842+
# GH 9736
9843+
df = DataFrame(np.random.randn(2, 2))
9844+
mask = DataFrame([[False, False], [False, False]])
9845+
s = Series([0, 1])
9846+
9847+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
9848+
result = df.where(mask, s, axis='index')
9849+
assert_frame_equal(result, expected)
9850+
9851+
result = df.copy()
9852+
result.where(mask, s, axis='index', inplace=True)
9853+
assert_frame_equal(result, expected)
9854+
9855+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
9856+
result = df.where(mask, s, axis='columns')
9857+
assert_frame_equal(result, expected)
9858+
9859+
result = df.copy()
9860+
result.where(mask, s, axis='columns', inplace=True)
9861+
assert_frame_equal(result, expected)
9862+
9863+
# Upcast needed
9864+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
9865+
mask = DataFrame([[False, False], [False, False]])
9866+
s = Series([0, np.nan])
9867+
9868+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
9869+
result = df.where(mask, s, axis='index')
9870+
assert_frame_equal(result, expected)
9871+
9872+
result = df.copy()
9873+
result.where(mask, s, axis='index', inplace=True)
9874+
assert_frame_equal(result, expected)
9875+
9876+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
9877+
result = df.where(mask, s, axis='columns')
9878+
assert_frame_equal(result, expected)
9879+
9880+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
9881+
1 : np.array([np.nan, np.nan], dtype='float64')})
9882+
result = df.copy()
9883+
result.where(mask, s, axis='columns', inplace=True)
9884+
assert_frame_equal(result, expected)
9885+
9886+
98419887
def test_mask(self):
98429888
df = DataFrame(np.random.randn(5, 3))
98439889
cond = df > 0

0 commit comments

Comments
 (0)