Skip to content

Commit a4d1ea3

Browse files
committed
ENH: corrwith excludes object data by default, address GH #144
1 parent dca3c5c commit a4d1ea3

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

pandas/core/frame.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,20 +2214,23 @@ def corrwith(self, other, axis=0, drop=False):
22142214
-------
22152215
correls : Series
22162216
"""
2217-
com_index = self._intersect_index(other)
2218-
com_cols = self._intersect_columns(other)
2217+
this = self._get_numeric_data()
2218+
other = other._get_numeric_data()
2219+
2220+
com_index = this._intersect_index(other)
2221+
com_cols = this._intersect_columns(other)
22192222

22202223
# feels hackish
22212224
if axis == 0:
22222225
result_index = com_index
22232226
if not drop:
2224-
result_index = self.columns.union(other.columns)
2227+
result_index = this.columns.union(other.columns)
22252228
else:
22262229
result_index = com_cols
22272230
if not drop:
2228-
result_index = self.index.union(other.index)
2231+
result_index = this.index.union(other.index)
22292232

2230-
left = self.reindex(index=com_index, columns=com_cols)
2233+
left = this.reindex(index=com_index, columns=com_cols)
22312234
right = other.reindex(index=com_index, columns=com_cols)
22322235

22332236
# mask missing values
@@ -2692,6 +2695,15 @@ def _get_numeric_columns(self):
26922695

26932696
return cols
26942697

2698+
def _get_numeric_data(self):
2699+
if self._is_mixed_type:
2700+
return self.ix[:, self._get_numeric_columns()]
2701+
else:
2702+
if self.values.dtype != np.object_:
2703+
return self
2704+
else:
2705+
return self.ix[:, []]
2706+
26952707
def clip(self, upper=None, lower=None):
26962708
"""
26972709
Trim values at input threshold(s)

pandas/tests/test_frame.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,22 @@ def test_corrwith(self):
17041704
for row in index[:4]:
17051705
assert_almost_equal(correls[row], df1.ix[row].corr(df2.ix[row]))
17061706

1707+
def test_corrwith_with_objects(self):
1708+
df1 = tm.makeTimeDataFrame()
1709+
df2 = tm.makeTimeDataFrame()
1710+
cols = ['A', 'B', 'C', 'D']
1711+
1712+
df1['obj'] = 'foo'
1713+
df2['obj'] = 'bar'
1714+
1715+
result = df1.corrwith(df2)
1716+
expected = df1.ix[:, cols].corrwith(df2.ix[:, cols])
1717+
assert_series_equal(result, expected)
1718+
1719+
result = df1.corrwith(df2, axis=1)
1720+
expected = df1.ix[:, cols].corrwith(df2.ix[:, cols], axis=1)
1721+
assert_series_equal(result, expected)
1722+
17071723
def test_dropEmptyRows(self):
17081724
N = len(self.frame.index)
17091725
mat = randn(N)

0 commit comments

Comments
 (0)