Skip to content

Commit 741b2fa

Browse files
committed
Merge pull request pandas-dev#7506 from jreback/where
BUG: Bug in DataFrame.where with a symmetric shaped frame and a passed other of a DataFrame
2 parents d0fa20a + adca809 commit 741b2fa

File tree

3 files changed

+43
-14
lines changed

3 files changed

+43
-14
lines changed

doc/source/v0.14.1.txt

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ Experimental
165165

166166
Bug Fixes
167167
~~~~~~~~~
168+
- Bug in ``DataFrame.where`` with a symmetric shaped frame and a passed other of a DataFrame (:issue:`7506`)
168169

169170

170171
- Bug in timeops with non-aligned Series (:issue:`7500`)

pandas/core/internals.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -921,9 +921,13 @@ def where(self, other, cond, align=True, raise_on_error=True,
921921
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
922922
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
923923

924+
# if its symmetric are ok, no reshaping needed (GH 7506)
925+
if (values.shape[0] == np.array(values.shape)).all():
926+
pass
927+
924928
# pseodo broadcast (its a 2d vs 1d say and where needs it in a
925929
# specific direction)
926-
if (other.ndim >= 1 and values.ndim - 1 == other.ndim and
930+
elif (other.ndim >= 1 and values.ndim - 1 == other.ndim and
927931
values.shape[0] != other.shape[0]):
928932
other = _block_shape(other).T
929933
else:
@@ -941,9 +945,11 @@ def where(self, other, cond, align=True, raise_on_error=True,
941945
# may need to undo transpose of values
942946
if hasattr(values, 'ndim'):
943947
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
948+
944949
values = values.T
945950
is_transposed = not is_transposed
946951

952+
947953
# our where function
948954
def func(c, v, o):
949955
if c.ravel().all():

pandas/tests/test_frame.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -5564,27 +5564,27 @@ def test_to_csv_from_csv(self):
55645564
with ensure_clean(pname) as path:
55655565

55665566
self.frame['A'][:5] = nan
5567-
5567+
55685568
self.frame.to_csv(path)
55695569
self.frame.to_csv(path, columns=['A', 'B'])
55705570
self.frame.to_csv(path, header=False)
55715571
self.frame.to_csv(path, index=False)
5572-
5572+
55735573
# test roundtrip
55745574
self.tsframe.to_csv(path)
55755575
recons = DataFrame.from_csv(path)
5576-
5576+
55775577
assert_frame_equal(self.tsframe, recons)
5578-
5578+
55795579
self.tsframe.to_csv(path, index_label='index')
55805580
recons = DataFrame.from_csv(path, index_col=None)
55815581
assert(len(recons.columns) == len(self.tsframe.columns) + 1)
5582-
5582+
55835583
# no index
55845584
self.tsframe.to_csv(path, index=False)
55855585
recons = DataFrame.from_csv(path, index_col=None)
55865586
assert_almost_equal(self.tsframe.values, recons.values)
5587-
5587+
55885588
# corner case
55895589
dm = DataFrame({'s1': Series(lrange(3), lrange(3)),
55905590
's2': Series(lrange(2), lrange(2))})
@@ -5600,24 +5600,24 @@ def test_to_csv_from_csv(self):
56005600
df.to_csv(path)
56015601
result = DataFrame.from_csv(path)
56025602
assert_frame_equal(result, df)
5603-
5603+
56045604
midx = MultiIndex.from_tuples([('A', 1, 2), ('A', 1, 2), ('B', 1, 2)])
56055605
df = DataFrame(np.random.randn(3, 3), index=midx,
56065606
columns=['x', 'y', 'z'])
56075607
df.to_csv(path)
56085608
result = DataFrame.from_csv(path, index_col=[0, 1, 2],
56095609
parse_dates=False)
56105610
assert_frame_equal(result, df, check_names=False) # TODO from_csv names index ['Unnamed: 1', 'Unnamed: 2'] should it ?
5611-
5611+
56125612
# column aliases
56135613
col_aliases = Index(['AA', 'X', 'Y', 'Z'])
56145614
self.frame2.to_csv(path, header=col_aliases)
56155615
rs = DataFrame.from_csv(path)
56165616
xp = self.frame2.copy()
56175617
xp.columns = col_aliases
5618-
5618+
56195619
assert_frame_equal(xp, rs)
5620-
5620+
56215621
self.assertRaises(ValueError, self.frame2.to_csv, path,
56225622
header=['AA', 'X'])
56235623

@@ -5881,7 +5881,7 @@ def test_to_csv_from_csv_w_some_infs(self):
58815881
with ensure_clean() as path:
58825882
self.frame.to_csv(path)
58835883
recons = DataFrame.from_csv(path)
5884-
5884+
58855885
assert_frame_equal(self.frame, recons, check_names=False) # TODO to_csv drops column name
58865886
assert_frame_equal(np.isinf(self.frame), np.isinf(recons), check_names=False)
58875887

@@ -5940,11 +5940,11 @@ def test_to_csv_multiindex(self):
59405940

59415941
frame.to_csv(path, header=False)
59425942
frame.to_csv(path, columns=['A', 'B'])
5943-
5943+
59445944
# round trip
59455945
frame.to_csv(path)
59465946
df = DataFrame.from_csv(path, index_col=[0, 1], parse_dates=False)
5947-
5947+
59485948
assert_frame_equal(frame, df, check_names=False) # TODO to_csv drops column name
59495949
self.assertEqual(frame.index.names, df.index.names)
59505950
self.frame.index = old_index # needed if setUP becomes a classmethod
@@ -9155,6 +9155,28 @@ def test_where_bug(self):
91559155
result.where(result > 2, np.nan, inplace=True)
91569156
assert_frame_equal(result, expected)
91579157

9158+
# transpositional issue
9159+
# GH7506
9160+
a = DataFrame({ 0 : [1,2], 1 : [3,4], 2 : [5,6]})
9161+
b = DataFrame({ 0 : [np.nan,8], 1:[9,np.nan], 2:[np.nan,np.nan]})
9162+
do_not_replace = b.isnull() | (a > b)
9163+
9164+
expected = a.copy()
9165+
expected[~do_not_replace] = b
9166+
9167+
result = a.where(do_not_replace,b)
9168+
assert_frame_equal(result,expected)
9169+
9170+
a = DataFrame({ 0 : [4,6], 1 : [1,0]})
9171+
b = DataFrame({ 0 : [np.nan,3],1:[3,np.nan]})
9172+
do_not_replace = b.isnull() | (a > b)
9173+
9174+
expected = a.copy()
9175+
expected[~do_not_replace] = b
9176+
9177+
result = a.where(do_not_replace,b)
9178+
assert_frame_equal(result,expected)
9179+
91589180
def test_where_datetime(self):
91599181

91609182
# GH 3311

0 commit comments

Comments
 (0)