Skip to content

Commit ef10baf

Browse files
committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH #9736)
1 parent 9e4e447 commit ef10baf

File tree

4 files changed

+116
-53
lines changed

4 files changed

+116
-53
lines changed

doc/source/whatsnew/v0.16.1.txt

+2
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,5 @@ Bug Fixes
9696

9797
- Fixed bug where ``DataFrame.plot()`` raised an error when both ``color`` and ``style`` keywords were passed and there was no color symbol in the style strings (:issue:`9671`)
9898
- Bug in ``read_csv`` and ``read_table`` when using ``skip_rows`` parameter if blank lines are present. (:issue:`9832`)
99+
100+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)

pandas/core/generic.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -3393,19 +3393,29 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
33933393
else:
33943394
other = self._constructor(other, **self._construct_axes_dict())
33953395

3396+
if axis is None:
3397+
axis = 0
3398+
align = True
3399+
else:
3400+
align = False
3401+
3402+
block_axis = self._get_block_manager_axis(axis)
3403+
33963404
if inplace:
33973405
# we may have different type blocks come out of putmask, so
33983406
# reconstruct the block manager
33993407

34003408
self._check_inplace_setting(other)
3401-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3402-
inplace=True)
3409+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3410+
inplace=True, axis=block_axis,
3411+
transpose=self._AXIS_REVERSED)
34033412
self._update_inplace(new_data)
34043413

34053414
else:
3406-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3415+
new_data = self._data.where(other=other, cond=cond, align=align,
34073416
raise_on_error=raise_on_error,
3408-
try_cast=try_cast)
3417+
try_cast=try_cast, axis=block_axis,
3418+
transpose=self._AXIS_REVERSED)
34093419

34103420
return self._constructor(new_data).__finalize__(self)
34113421

pandas/core/internals.py

+54-49
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,8 @@ def _is_empty_indexer(indexer):
626626

627627
return [self]
628628

629-
def putmask(self, mask, new, align=True, inplace=False):
629+
def putmask(self, mask, new, align=True, inplace=False,
630+
axis=0, transpose=False):
630631
""" putmask the data to the block; it is possible that we may create a
631632
new dtype of block
632633
@@ -638,37 +639,55 @@ def putmask(self, mask, new, align=True, inplace=False):
638639
new : a ndarray/object
639640
align : boolean, perform alignment on other/cond, default is True
640641
inplace : perform inplace modification, default is False
642+
axis : int
643+
transpose : boolean
644+
Set to True if self is stored with axes reversed
641645
642646
Returns
643647
-------
644-
a new block(s), the result of the putmask
648+
a list of new blocks, the result of the putmask
645649
"""
646650

647651
new_values = self.values if inplace else self.values.copy()
648652

649-
# may need to align the new
650653
if hasattr(new, 'reindex_axis'):
651-
new = new.values.T
654+
new = new.values
652655

653-
# may need to align the mask
654656
if hasattr(mask, 'reindex_axis'):
655-
mask = mask.values.T
657+
mask = mask.values
656658

657659
# if we are passed a scalar None, convert it here
658660
if not is_list_like(new) and isnull(new) and not self.is_object:
659661
new = self.fill_value
660662

661663
if self._can_hold_element(new):
664+
if transpose:
665+
new_values = new_values.T
666+
662667
new = self._try_cast(new)
663668

664-
# pseudo-broadcast
665-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
666-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
669+
# If the default repeat behavior in np.putmask would go in the wrong
670+
# direction, then explictly repeat and reshape new instead
671+
if getattr(new, 'ndim', 0) >= 1:
672+
if self.ndim - 1 == new.ndim and axis == 1:
673+
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
667674

668675
np.putmask(new_values, mask, new)
669676

670677
# maybe upcast me
671678
elif mask.any():
679+
if transpose:
680+
mask = mask.T
681+
if isinstance(new, np.ndarray):
682+
new = new.T
683+
axis = new_values.ndim - axis - 1
684+
685+
# Pseudo-broadcast
686+
if getattr(new, 'ndim', 0) >= 1:
687+
if self.ndim - 1 == new.ndim:
688+
new_shape = list(new.shape)
689+
new_shape.insert(axis, 1)
690+
new = new.reshape(tuple(new_shape))
672691

673692
# need to go column by column
674693
new_blocks = []
@@ -679,14 +698,15 @@ def putmask(self, mask, new, align=True, inplace=False):
679698

680699
# need a new block
681700
if m.any():
682-
683-
n = new[i] if isinstance(
684-
new, np.ndarray) else np.array(new)
701+
if isinstance(new, np.ndarray):
702+
n = np.squeeze(new[i % new.shape[0]])
703+
else:
704+
n = np.array(new)
685705

686706
# type of the new block
687707
dtype, _ = com._maybe_promote(n.dtype)
688708

689-
# we need to exiplicty astype here to make a copy
709+
# we need to explicitly astype here to make a copy
690710
n = n.astype(dtype)
691711

692712
nv = _putmask_smart(v, m, n)
@@ -712,8 +732,10 @@ def putmask(self, mask, new, align=True, inplace=False):
712732
if inplace:
713733
return [self]
714734

715-
return [make_block(new_values,
716-
placement=self.mgr_locs, fastpath=True)]
735+
if transpose:
736+
new_values = new_values.T
737+
738+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
717739

718740
def interpolate(self, method='pad', axis=0, index=None,
719741
values=None, inplace=False, limit=None,
@@ -997,7 +1019,7 @@ def handle_error():
9971019
fastpath=True, placement=self.mgr_locs)]
9981020

9991021
def where(self, other, cond, align=True, raise_on_error=True,
1000-
try_cast=False):
1022+
try_cast=False, axis=0, transpose=False):
10011023
"""
10021024
evaluate the block; return result block(s) from the result
10031025
@@ -1008,50 +1030,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10081030
align : boolean, perform alignment on other/cond
10091031
raise_on_error : if True, raise when I can't perform the function,
10101032
False by default (and just return the data that we had coming in)
1033+
axis : int
1034+
transpose : boolean
1035+
Set to True if self is stored with axes reversed
10111036
10121037
Returns
10131038
-------
10141039
a new block(s), the result of the func
10151040
"""
10161041

10171042
values = self.values
1043+
if transpose:
1044+
values = values.T
10181045

1019-
# see if we can align other
10201046
if hasattr(other, 'reindex_axis'):
10211047
other = other.values
10221048

1023-
# make sure that we can broadcast
1024-
is_transposed = False
1025-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1026-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1027-
1028-
# if its symmetric are ok, no reshaping needed (GH 7506)
1029-
if (values.shape[0] == np.array(values.shape)).all():
1030-
pass
1031-
1032-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1033-
# specific direction)
1034-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1035-
values.shape[0] != other.shape[0]):
1036-
other = _block_shape(other).T
1037-
else:
1038-
values = values.T
1039-
is_transposed = True
1040-
1041-
# see if we can align cond
1042-
if not hasattr(cond, 'shape'):
1043-
raise ValueError(
1044-
"where must have a condition that is ndarray like")
1045-
10461049
if hasattr(cond, 'reindex_axis'):
10471050
cond = cond.values
10481051

1049-
# may need to undo transpose of values
1050-
if hasattr(values, 'ndim'):
1051-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1052+
# If the default broadcasting would go in the wrong direction, then
1053+
# explictly reshape other instead
1054+
if getattr(other, 'ndim', 0) >= 1:
1055+
if values.ndim - 1 == other.ndim and axis == 1:
1056+
other = other.reshape(tuple(other.shape + (1,)))
10521057

1053-
values = values.T
1054-
is_transposed = not is_transposed
1058+
if not hasattr(cond, 'shape'):
1059+
raise ValueError("where must have a condition that is ndarray like")
10551060

10561061
other = _maybe_convert_string_to_object(other)
10571062

@@ -1084,15 +1089,14 @@ def func(c, v, o):
10841089
raise TypeError('Could not compare [%s] with block values'
10851090
% repr(other))
10861091

1087-
if is_transposed:
1092+
if transpose:
10881093
result = result.T
10891094

10901095
# try to cast if requested
10911096
if try_cast:
10921097
result = self._try_cast_result(result)
10931098

1094-
return make_block(result,
1095-
ndim=self.ndim, placement=self.mgr_locs)
1099+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
10961100

10971101
# might need to separate out blocks
10981102
axis = cond.ndim - 1
@@ -1721,7 +1725,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17211725

17221726
return self.make_block_same_class(new_values, new_mgr_locs)
17231727

1724-
def putmask(self, mask, new, align=True, inplace=False):
1728+
def putmask(self, mask, new, align=True, inplace=False,
1729+
axis=0, transpose=False):
17251730
""" putmask the data to the block; it is possible that we may create a
17261731
new dtype of block
17271732

pandas/tests/test_frame.py

+46
Original file line numberDiff line numberDiff line change
@@ -9814,6 +9814,52 @@ def test_where_complex(self):
98149814
df[df.abs() >= 5] = np.nan
98159815
assert_frame_equal(df,expected)
98169816

9817+
def test_where_axis(self):
9818+
# GH 9736
9819+
df = DataFrame(np.random.randn(2, 2))
9820+
mask = DataFrame([[False, False], [False, False]])
9821+
s = Series([0, 1])
9822+
9823+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
9824+
result = df.where(mask, s, axis='index')
9825+
assert_frame_equal(result, expected)
9826+
9827+
result = df.copy()
9828+
result.where(mask, s, axis='index', inplace=True)
9829+
assert_frame_equal(result, expected)
9830+
9831+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
9832+
result = df.where(mask, s, axis='columns')
9833+
assert_frame_equal(result, expected)
9834+
9835+
result = df.copy()
9836+
result.where(mask, s, axis='columns', inplace=True)
9837+
assert_frame_equal(result, expected)
9838+
9839+
# Upcast needed
9840+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
9841+
mask = DataFrame([[False, False], [False, False]])
9842+
s = Series([0, np.nan])
9843+
9844+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
9845+
result = df.where(mask, s, axis='index')
9846+
assert_frame_equal(result, expected)
9847+
9848+
result = df.copy()
9849+
result.where(mask, s, axis='index', inplace=True)
9850+
assert_frame_equal(result, expected)
9851+
9852+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
9853+
result = df.where(mask, s, axis='columns')
9854+
assert_frame_equal(result, expected)
9855+
9856+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
9857+
1 : np.array([np.nan, np.nan], dtype='float64')})
9858+
result = df.copy()
9859+
result.where(mask, s, axis='columns', inplace=True)
9860+
assert_frame_equal(result, expected)
9861+
9862+
98179863
def test_mask(self):
98189864
df = DataFrame(np.random.randn(5, 3))
98199865
cond = df > 0

0 commit comments

Comments
 (0)