Skip to content

Commit e4866ac

Browse files
committed
BUG: DataFrame.where does not respect axis parameter when shape is symmetric (GH pandas-dev#9736)
1 parent 7516ec7 commit e4866ac

File tree

4 files changed

+186
-57
lines changed

4 files changed

+186
-57
lines changed

doc/source/whatsnew/v0.16.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ Bug Fixes
6868
- Bug in getting timezone data with ``dateutil`` on various platforms ( :issue:`9059`, :issue:`8639`, :issue:`9663`, :issue:`10121`)
6969
- Bug in displaying datetimes with mixed frequencies uniformly; display 'ms' datetimes to the proper precision. (:issue:`10170`)
7070

71+
- Bug causing ``DataFrame.where`` to not respect the ``axis`` parameter when the frame has a symmetric shape. (:issue:`9736`)
72+
7173
- Bug where ``Series`` arithmetic methods may incorrectly hold names (:issue:`10068`)
7274

7375
- Bug where ``DatetimeIndex`` and ``TimedeltaIndex`` names are lost after timedelta arithmetic (:issue:`9926`)

pandas/core/generic.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -3557,19 +3557,31 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
35573557
else:
35583558
other = self._constructor(other, **self._construct_axes_dict())
35593559

3560+
if axis is None:
3561+
axis = 0
3562+
3563+
if self.ndim == getattr(other, 'ndim', 0):
3564+
align = True
3565+
else:
3566+
align = (self._get_axis_number(axis) == 1)
3567+
3568+
block_axis = self._get_block_manager_axis(axis)
3569+
35603570
if inplace:
35613571
# we may have different type blocks come out of putmask, so
35623572
# reconstruct the block manager
35633573

35643574
self._check_inplace_setting(other)
3565-
new_data = self._data.putmask(mask=cond, new=other, align=axis is None,
3566-
inplace=True)
3575+
new_data = self._data.putmask(mask=cond, new=other, align=align,
3576+
inplace=True, axis=block_axis,
3577+
transpose=self._AXIS_REVERSED)
35673578
self._update_inplace(new_data)
35683579

35693580
else:
3570-
new_data = self._data.where(other=other, cond=cond, align=axis is None,
3581+
new_data = self._data.where(other=other, cond=cond, align=align,
35713582
raise_on_error=raise_on_error,
3572-
try_cast=try_cast)
3583+
try_cast=try_cast, axis=block_axis,
3584+
transpose=self._AXIS_REVERSED)
35733585

35743586
return self._constructor(new_data).__finalize__(self)
35753587

pandas/core/internals.py

+64-53
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@ def _is_empty_indexer(indexer):
632632

633633
return [self]
634634

635-
def putmask(self, mask, new, align=True, inplace=False):
635+
def putmask(self, mask, new, align=True, inplace=False,
636+
axis=0, transpose=False):
636637
""" putmask the data to the block; it is possible that we may create a
637638
new dtype of block
638639
@@ -644,37 +645,55 @@ def putmask(self, mask, new, align=True, inplace=False):
644645
new : a ndarray/object
645646
align : boolean, perform alignment on other/cond, default is True
646647
inplace : perform inplace modification, default is False
648+
axis : int
649+
transpose : boolean
650+
Set to True if self is stored with axes reversed
647651
648652
Returns
649653
-------
650-
a new block(s), the result of the putmask
654+
a list of new blocks, the result of the putmask
651655
"""
652656

653657
new_values = self.values if inplace else self.values.copy()
654658

655-
# may need to align the new
656659
if hasattr(new, 'reindex_axis'):
657-
new = new.values.T
660+
new = new.values
658661

659-
# may need to align the mask
660662
if hasattr(mask, 'reindex_axis'):
661-
mask = mask.values.T
663+
mask = mask.values
662664

663665
# if we are passed a scalar None, convert it here
664666
if not is_list_like(new) and isnull(new) and not self.is_object:
665667
new = self.fill_value
666668

667669
if self._can_hold_element(new):
670+
if transpose:
671+
new_values = new_values.T
672+
668673
new = self._try_cast(new)
669674

670-
# pseudo-broadcast
671-
if isinstance(new, np.ndarray) and new.ndim == self.ndim - 1:
672-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
675+
# If the default repeat behavior in np.putmask would go in the wrong
676+
# direction, then explicitly repeat and reshape new instead
677+
if getattr(new, 'ndim', 0) >= 1:
678+
if self.ndim - 1 == new.ndim and axis == 1:
679+
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
673680

674681
np.putmask(new_values, mask, new)
675682

676683
# maybe upcast me
677684
elif mask.any():
685+
if transpose:
686+
mask = mask.T
687+
if isinstance(new, np.ndarray):
688+
new = new.T
689+
axis = new_values.ndim - axis - 1
690+
691+
# Pseudo-broadcast
692+
if getattr(new, 'ndim', 0) >= 1:
693+
if self.ndim - 1 == new.ndim:
694+
new_shape = list(new.shape)
695+
new_shape.insert(axis, 1)
696+
new = new.reshape(tuple(new_shape))
678697

679698
# need to go column by column
680699
new_blocks = []
@@ -685,14 +704,15 @@ def putmask(self, mask, new, align=True, inplace=False):
685704

686705
# need a new block
687706
if m.any():
688-
689-
n = new[i] if isinstance(
690-
new, np.ndarray) else np.array(new)
707+
if isinstance(new, np.ndarray):
708+
n = np.squeeze(new[i % new.shape[0]])
709+
else:
710+
n = np.array(new)
691711

692712
# type of the new block
693713
dtype, _ = com._maybe_promote(n.dtype)
694714

695-
# we need to exiplicty astype here to make a copy
715+
# we need to explicitly astype here to make a copy
696716
n = n.astype(dtype)
697717

698718
nv = _putmask_smart(v, m, n)
@@ -718,8 +738,10 @@ def putmask(self, mask, new, align=True, inplace=False):
718738
if inplace:
719739
return [self]
720740

721-
return [make_block(new_values,
722-
placement=self.mgr_locs, fastpath=True)]
741+
if transpose:
742+
new_values = new_values.T
743+
744+
return [make_block(new_values, placement=self.mgr_locs, fastpath=True)]
723745

724746
def interpolate(self, method='pad', axis=0, index=None,
725747
values=None, inplace=False, limit=None,
@@ -1003,7 +1025,7 @@ def handle_error():
10031025
fastpath=True, placement=self.mgr_locs)]
10041026

10051027
def where(self, other, cond, align=True, raise_on_error=True,
1006-
try_cast=False):
1028+
try_cast=False, axis=0, transpose=False):
10071029
"""
10081030
evaluate the block; return result block(s) from the result
10091031
@@ -1014,50 +1036,33 @@ def where(self, other, cond, align=True, raise_on_error=True,
10141036
align : boolean, perform alignment on other/cond
10151037
raise_on_error : if True, raise when I can't perform the function,
10161038
False by default (and just return the data that we had coming in)
1039+
axis : int
1040+
transpose : boolean
1041+
Set to True if self is stored with axes reversed
10171042
10181043
Returns
10191044
-------
10201045
a new block(s), the result of the func
10211046
"""
10221047

10231048
values = self.values
1049+
if transpose:
1050+
values = values.T
10241051

1025-
# see if we can align other
10261052
if hasattr(other, 'reindex_axis'):
10271053
other = other.values
10281054

1029-
# make sure that we can broadcast
1030-
is_transposed = False
1031-
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
1032-
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
1033-
1034-
# if its symmetric are ok, no reshaping needed (GH 7506)
1035-
if (values.shape[0] == np.array(values.shape)).all():
1036-
pass
1037-
1038-
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
1039-
# specific direction)
1040-
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
1041-
values.shape[0] != other.shape[0]):
1042-
other = _block_shape(other).T
1043-
else:
1044-
values = values.T
1045-
is_transposed = True
1046-
1047-
# see if we can align cond
1048-
if not hasattr(cond, 'shape'):
1049-
raise ValueError(
1050-
"where must have a condition that is ndarray like")
1051-
10521055
if hasattr(cond, 'reindex_axis'):
10531056
cond = cond.values
10541057

1055-
# may need to undo transpose of values
1056-
if hasattr(values, 'ndim'):
1057-
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
1058+
# If the default broadcasting would go in the wrong direction, then
1059+
# explicitly reshape other instead
1060+
if getattr(other, 'ndim', 0) >= 1:
1061+
if values.ndim - 1 == other.ndim and axis == 1:
1062+
other = other.reshape(tuple(other.shape + (1,)))
10581063

1059-
values = values.T
1060-
is_transposed = not is_transposed
1064+
if not hasattr(cond, 'shape'):
1065+
raise ValueError("where must have a condition that is ndarray like")
10611066

10621067
other = _maybe_convert_string_to_object(other)
10631068

@@ -1090,15 +1095,14 @@ def func(c, v, o):
10901095
raise TypeError('Could not compare [%s] with block values'
10911096
% repr(other))
10921097

1093-
if is_transposed:
1098+
if transpose:
10941099
result = result.T
10951100

10961101
# try to cast if requested
10971102
if try_cast:
10981103
result = self._try_cast_result(result)
10991104

1100-
return make_block(result,
1101-
ndim=self.ndim, placement=self.mgr_locs)
1105+
return make_block(result, ndim=self.ndim, placement=self.mgr_locs)
11021106

11031107
# might need to separate out blocks
11041108
axis = cond.ndim - 1
@@ -1730,7 +1734,8 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17301734

17311735
return self.make_block_same_class(new_values, new_mgr_locs)
17321736

1733-
def putmask(self, mask, new, align=True, inplace=False):
1737+
def putmask(self, mask, new, align=True, inplace=False,
1738+
axis=0, transpose=False):
17341739
""" putmask the data to the block; it is possible that we may create a
17351740
new dtype of block
17361741
@@ -2422,12 +2427,18 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, **kwargs):
24222427
else:
24232428
kwargs['filter'] = filter_locs
24242429

2425-
if f == 'where' and kwargs.get('align', True):
2430+
if f == 'where':
24262431
align_copy = True
2427-
align_keys = ['other', 'cond']
2428-
elif f == 'putmask' and kwargs.get('align', True):
2432+
if kwargs.get('align', True):
2433+
align_keys = ['other', 'cond']
2434+
else:
2435+
align_keys = ['cond']
2436+
elif f == 'putmask':
24292437
align_copy = False
2430-
align_keys = ['new', 'mask']
2438+
if kwargs.get('align', True):
2439+
align_keys = ['new', 'mask']
2440+
else:
2441+
align_keys = ['mask']
24312442
elif f == 'eval':
24322443
align_copy = False
24332444
align_keys = ['other']

pandas/tests/test_frame.py

+104
Original file line numberDiff line numberDiff line change
@@ -9957,6 +9957,110 @@ def test_where_complex(self):
99579957
df[df.abs() >= 5] = np.nan
99589958
assert_frame_equal(df,expected)
99599959

9960+
def test_where_axis(self):
9961+
# GH 9736
9962+
df = DataFrame(np.random.randn(2, 2))
9963+
mask = DataFrame([[False, False], [False, False]])
9964+
s = Series([0, 1])
9965+
9966+
expected = DataFrame([[0, 0], [1, 1]], dtype='float64')
9967+
result = df.where(mask, s, axis='index')
9968+
assert_frame_equal(result, expected)
9969+
9970+
result = df.copy()
9971+
result.where(mask, s, axis='index', inplace=True)
9972+
assert_frame_equal(result, expected)
9973+
9974+
expected = DataFrame([[0, 1], [0, 1]], dtype='float64')
9975+
result = df.where(mask, s, axis='columns')
9976+
assert_frame_equal(result, expected)
9977+
9978+
result = df.copy()
9979+
result.where(mask, s, axis='columns', inplace=True)
9980+
assert_frame_equal(result, expected)
9981+
9982+
# Upcast needed
9983+
df = DataFrame([[1, 2], [3, 4]], dtype='int64')
9984+
mask = DataFrame([[False, False], [False, False]])
9985+
s = Series([0, np.nan])
9986+
9987+
expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype='float64')
9988+
result = df.where(mask, s, axis='index')
9989+
assert_frame_equal(result, expected)
9990+
9991+
result = df.copy()
9992+
result.where(mask, s, axis='index', inplace=True)
9993+
assert_frame_equal(result, expected)
9994+
9995+
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
9996+
result = df.where(mask, s, axis='columns')
9997+
assert_frame_equal(result, expected)
9998+
9999+
expected = DataFrame({0 : np.array([0, 0], dtype='int64'),
10000+
1 : np.array([np.nan, np.nan], dtype='float64')})
10001+
result = df.copy()
10002+
result.where(mask, s, axis='columns', inplace=True)
10003+
assert_frame_equal(result, expected)
10004+
10005+
# Multiple dtypes (=> multiple Blocks)
10006+
df = pd.concat([DataFrame(np.random.randn(10, 2)),
10007+
DataFrame(np.random.randint(0, 10, size=(10, 2)))],
10008+
ignore_index=True, axis=1)
10009+
mask = DataFrame(False, columns=df.columns, index=df.index)
10010+
s1 = Series(1, index=df.columns)
10011+
s2 = Series(2, index=df.index)
10012+
10013+
result = df.where(mask, s1, axis='columns')
10014+
expected = DataFrame(1.0, columns=df.columns, index=df.index)
10015+
expected[2] = expected[2].astype(int)
10016+
expected[3] = expected[3].astype(int)
10017+
assert_frame_equal(result, expected)
10018+
10019+
result = df.copy()
10020+
result.where(mask, s1, axis='columns', inplace=True)
10021+
assert_frame_equal(result, expected)
10022+
10023+
result = df.where(mask, s2, axis='index')
10024+
expected = DataFrame(2.0, columns=df.columns, index=df.index)
10025+
expected[2] = expected[2].astype(int)
10026+
expected[3] = expected[3].astype(int)
10027+
assert_frame_equal(result, expected)
10028+
10029+
result = df.copy()
10030+
result.where(mask, s2, axis='index', inplace=True)
10031+
assert_frame_equal(result, expected)
10032+
10033+
# DataFrame vs DataFrame
10034+
d1 = df.copy().drop(1, axis=0)
10035+
expected = df.copy()
10036+
expected.loc[1, :] = np.nan
10037+
10038+
result = df.where(mask, d1)
10039+
assert_frame_equal(result, expected)
10040+
result = df.where(mask, d1, axis='index')
10041+
assert_frame_equal(result, expected)
10042+
result = df.copy()
10043+
result.where(mask, d1, inplace=True)
10044+
assert_frame_equal(result, expected)
10045+
result = df.copy()
10046+
result.where(mask, d1, inplace=True, axis='index')
10047+
assert_frame_equal(result, expected)
10048+
10049+
d2 = df.copy().drop(1, axis=1)
10050+
expected = df.copy()
10051+
expected.loc[:, 1] = np.nan
10052+
10053+
result = df.where(mask, d2)
10054+
assert_frame_equal(result, expected)
10055+
result = df.where(mask, d2, axis='columns')
10056+
assert_frame_equal(result, expected)
10057+
result = df.copy()
10058+
result.where(mask, d2, inplace=True)
10059+
assert_frame_equal(result, expected)
10060+
result = df.copy()
10061+
result.where(mask, d2, inplace=True, axis='columns')
10062+
assert_frame_equal(result, expected)
10063+
996010064
def test_mask(self):
996110065
df = DataFrame(np.random.randn(5, 3))
996210066
cond = df > 0

0 commit comments

Comments
 (0)