Skip to content

Commit 31f5662

Browse files
committed
ENH: rolling_corr/rolling_cov can take DF/DF and DF/Series combos for easy multi-application, GH #462
1 parent 6aa80f9 commit 31f5662

File tree

3 files changed

+108
-40
lines changed

3 files changed

+108
-40
lines changed

RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ pandas 0.6.1
6161
#453)
6262
- Optimize scalar value lookups in the general case by 25% or more in Series
6363
and DataFrame
64+
- Can pass DataFrame/DataFrame and DataFrame/Series to
65+
rolling_corr/rolling_cov (GH #462)
6466

6567
**Bug fixes**
6668

@@ -80,6 +82,9 @@ pandas 0.6.1
8082
- Bug fix in left join Cython code with duplicate monotonic labels
8183
- Fix bug when unstacking multiple levels described in #451
8284
- Exclude NA values in dtype=object arrays, regression from 0.5.0 (GH #469)
85+
- Use Cython map_infer function in DataFrame.applymap to properly infer
86+
output type, handle tuple return values and other things that were breaking
87+
(GH #465)
8388

8489
Thanks
8590
------

pandas/stats/moments.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,47 @@ def rolling_count(arg, window, time_rule=None):
4646
return return_hook(result)
4747

4848
def rolling_cov(arg1, arg2, window, min_periods=None, time_rule=None):
49-
X, Y = _prep_binary(arg1, arg2)
50-
mean = lambda x: rolling_mean(x, window, min_periods, time_rule)
51-
count = rolling_count(X + Y, window, time_rule)
52-
bias_adj = count / (count - 1)
53-
return (mean(X * Y) - mean(X) * mean(Y)) * bias_adj
49+
def _get_cov(X, Y):
50+
mean = lambda x: rolling_mean(x, window, min_periods, time_rule)
51+
count = rolling_count(X + Y, window, time_rule)
52+
bias_adj = count / (count - 1)
53+
return (mean(X * Y) - mean(X) * mean(Y)) * bias_adj
54+
return _flex_binary_moment(arg1, arg2, _get_cov)
5455

5556
def rolling_corr(arg1, arg2, window, min_periods=None, time_rule=None):
56-
X, Y = _prep_binary(arg1, arg2)
57-
num = rolling_cov(X, Y, window, min_periods, time_rule)
58-
den = (rolling_std(X, window, min_periods, time_rule) *
59-
rolling_std(Y, window, min_periods, time_rule))
60-
return num / den
57+
def _get_corr(a, b):
58+
num = rolling_cov(a, b, window, min_periods, time_rule)
59+
den = (rolling_std(a, window, min_periods, time_rule) *
60+
rolling_std(b, window, min_periods, time_rule))
61+
return num / den
62+
return _flex_binary_moment(arg1, arg2, _get_corr)
63+
64+
def _flex_binary_moment(arg1, arg2, f):
65+
if isinstance(arg1, np.ndarray) and isinstance(arg2, np.ndarray):
66+
X, Y = _prep_binary(arg1, arg2)
67+
return f(X, Y)
68+
elif isinstance(arg1, DataFrame):
69+
results = {}
70+
if isinstance(arg2, DataFrame):
71+
X, Y = arg1.align(arg2, join='outer')
72+
X = X + 0 * Y
73+
Y = Y + 0 * X
74+
res_columns = arg1.columns.union(arg2.columns)
75+
for col in res_columns:
76+
if col in X and col in Y:
77+
results[col] = f(X[col], Y[col])
78+
else:
79+
res_columns = arg1.columns
80+
X, Y = arg1.align(arg2, axis=0, join='outer')
81+
results = {}
82+
83+
for col in res_columns:
84+
results[col] = f(X[col], Y)
85+
86+
return DataFrame(results, index=X.index, columns=res_columns)
87+
else:
88+
return _flex_binary_moment(arg2, arg1, f)
89+
6190

6291
def _rolling_moment(arg, window, func, minp, axis=0, time_rule=None):
6392
"""
@@ -219,7 +248,7 @@ def _prep_binary(arg1, arg2):
219248
220249
Returns
221250
-------
222-
y : type of input argument
251+
%s
223252
"""
224253

225254

@@ -256,18 +285,29 @@ def _prep_binary(arg1, arg2):
256285
y : type of input argument
257286
"""
258287

288+
_type_of_input = "y : type of input argument"
289+
290+
_flex_retval = """y : type depends on inputs
291+
DataFrame / DataFrame -> DataFrame (matches on columns)
292+
DataFrame / Series -> Computes result for each column
293+
Series / Series -> Series"""
294+
259295
_unary_arg = "arg : Series, DataFrame"
296+
297+
_binary_arg_flex = """arg1 : Series, DataFrame, or ndarray
298+
arg2 : Series, DataFrame, or ndarray"""
299+
260300
_binary_arg = """arg1 : Series, DataFrame, or ndarray
261-
arg2 : type of arg1"""
301+
arg2 : Series, DataFrame, or ndarray"""
262302

263303
_bias_doc = r"""bias : boolean, default False
264304
Use a standard estimation bias correction
265305
"""
266306

267307
rolling_cov.__doc__ = _doc_template % ("Unbiased moving covariance",
268-
_binary_arg)
308+
_binary_arg_flex, _flex_retval)
269309
rolling_corr.__doc__ = _doc_template % ("Moving sample correlation",
270-
_binary_arg)
310+
_binary_arg_flex, _flex_retval)
271311

272312
ewma.__doc__ = _ewm_doc % ("Exponentially-weighted moving average",
273313
_unary_arg, "")
@@ -314,7 +354,7 @@ def call_cython(arg, window, minp):
314354
return _rolling_moment(arg, window, call_cython, min_periods,
315355
time_rule=time_rule)
316356

317-
f.__doc__ = _doc_template % (desc, _unary_arg)
357+
f.__doc__ = _doc_template % (desc, _unary_arg, _type_of_input)
318358

319359
return f
320360

pandas/stats/tests/test_moments.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pandas.core.api import Series, DataFrame, DateRange
99
from pandas.util.testing import assert_almost_equal
1010
import pandas.core.datetools as datetools
11-
import pandas.stats.moments as moments
11+
import pandas.stats.moments as mom
1212
import pandas.util.testing as tm
1313

1414
N, K = 100, 10
@@ -31,25 +31,25 @@ def setUp(self):
3131
columns=np.arange(K))
3232

3333
def test_rolling_sum(self):
34-
self._check_moment_func(moments.rolling_sum, np.sum)
34+
self._check_moment_func(mom.rolling_sum, np.sum)
3535

3636
def test_rolling_count(self):
3737
counter = lambda x: np.isfinite(x).astype(float).sum()
38-
self._check_moment_func(moments.rolling_count, counter,
38+
self._check_moment_func(mom.rolling_count, counter,
3939
has_min_periods=False,
4040
preserve_nan=False)
4141

4242
def test_rolling_mean(self):
43-
self._check_moment_func(moments.rolling_mean, np.mean)
43+
self._check_moment_func(mom.rolling_mean, np.mean)
4444

4545
def test_rolling_median(self):
46-
self._check_moment_func(moments.rolling_median, np.median)
46+
self._check_moment_func(mom.rolling_median, np.median)
4747

4848
def test_rolling_min(self):
49-
self._check_moment_func(moments.rolling_min, np.min)
49+
self._check_moment_func(mom.rolling_min, np.min)
5050

5151
def test_rolling_max(self):
52-
self._check_moment_func(moments.rolling_max, np.max)
52+
self._check_moment_func(mom.rolling_max, np.max)
5353

5454
def test_rolling_quantile(self):
5555
qs = [.1, .5, .9]
@@ -62,7 +62,7 @@ def scoreatpercentile(a, per):
6262

6363
for q in qs:
6464
def f(x, window, min_periods=None, time_rule=None):
65-
return moments.rolling_quantile(x, window, q,
65+
return mom.rolling_quantile(x, window, q,
6666
min_periods=min_periods,
6767
time_rule=time_rule)
6868
def alt(x):
@@ -72,34 +72,34 @@ def alt(x):
7272

7373
def test_rolling_apply(self):
7474
def roll_mean(x, window, min_periods=None, time_rule=None):
75-
return moments.rolling_apply(x, window,
75+
return mom.rolling_apply(x, window,
7676
lambda x: x[np.isfinite(x)].mean(),
7777
min_periods=min_periods,
7878
time_rule=time_rule)
7979
self._check_moment_func(roll_mean, np.mean)
8080

8181
def test_rolling_std(self):
82-
self._check_moment_func(moments.rolling_std,
82+
self._check_moment_func(mom.rolling_std,
8383
lambda x: np.std(x, ddof=1))
8484

8585
def test_rolling_var(self):
86-
self._check_moment_func(moments.rolling_var,
86+
self._check_moment_func(mom.rolling_var,
8787
lambda x: np.var(x, ddof=1))
8888

8989
def test_rolling_skew(self):
9090
try:
9191
from scipy.stats import skew
9292
except ImportError:
9393
raise nose.SkipTest('no scipy')
94-
self._check_moment_func(moments.rolling_skew,
94+
self._check_moment_func(mom.rolling_skew,
9595
lambda x: skew(x, bias=False))
9696

9797
def test_rolling_kurt(self):
9898
try:
9999
from scipy.stats import kurtosis
100100
except ImportError:
101101
raise nose.SkipTest('no scipy')
102-
self._check_moment_func(moments.rolling_kurt,
102+
self._check_moment_func(mom.rolling_kurt,
103103
lambda x: kurtosis(x, bias=False))
104104

105105
def _check_moment_func(self, func, static_comp, window=50,
@@ -186,21 +186,21 @@ def _check_structures(self, func, static_comp,
186186
trunc_frame.apply(static_comp))
187187

188188
def test_ewma(self):
189-
self._check_ew(moments.ewma)
189+
self._check_ew(mom.ewma)
190190

191191
def test_ewmvar(self):
192-
self._check_ew(moments.ewmvar)
192+
self._check_ew(mom.ewmvar)
193193

194194
def test_ewmvol(self):
195-
self._check_ew(moments.ewmvol)
195+
self._check_ew(mom.ewmvol)
196196

197197
def test_ewma_span_com_args(self):
198-
A = moments.ewma(self.arr, com=9.5)
199-
B = moments.ewma(self.arr, span=20)
198+
A = mom.ewma(self.arr, com=9.5)
199+
B = mom.ewma(self.arr, span=20)
200200
assert_almost_equal(A, B)
201201

202-
self.assertRaises(Exception, moments.ewma, self.arr, com=9.5, span=20)
203-
self.assertRaises(Exception, moments.ewma, self.arr)
202+
self.assertRaises(Exception, mom.ewma, self.arr, com=9.5, span=20)
203+
self.assertRaises(Exception, mom.ewma, self.arr)
204204

205205
def _check_ew(self, func):
206206
self._check_ew_ndarray(func)
@@ -233,14 +233,14 @@ def test_rolling_cov(self):
233233
A = self.series
234234
B = A + randn(len(A))
235235

236-
result = moments.rolling_cov(A, B, 50, min_periods=25)
236+
result = mom.rolling_cov(A, B, 50, min_periods=25)
237237
assert_almost_equal(result[-1], np.cov(A[-50:], B[-50:])[0, 1])
238238

239239
def test_rolling_corr(self):
240240
A = self.series
241241
B = A + randn(len(A))
242242

243-
result = moments.rolling_corr(A, B, 50, min_periods=25)
243+
result = mom.rolling_corr(A, B, 50, min_periods=25)
244244
assert_almost_equal(result[-1], np.corrcoef(A[-50:], B[-50:])[0, 1])
245245

246246
# test for correct bias correction
@@ -249,14 +249,37 @@ def test_rolling_corr(self):
249249
a[:5] = np.nan
250250
b[:10] = np.nan
251251

252-
result = moments.rolling_corr(a, b, len(a), min_periods=1)
252+
result = mom.rolling_corr(a, b, len(a), min_periods=1)
253253
assert_almost_equal(result[-1], a.corr(b))
254254

255+
def test_flex_binary_frame(self):
256+
def _check(method):
257+
series = self.frame[1]
258+
259+
res = method(series, self.frame, 10)
260+
res2 = method(self.frame, series, 10)
261+
exp = self.frame.apply(lambda x: method(series, x, 10))
262+
263+
tm.assert_frame_equal(res, exp)
264+
tm.assert_frame_equal(res2, exp)
265+
266+
frame2 = self.frame.copy()
267+
frame2.values[:] = np.random.randn(*frame2.shape)
268+
269+
res3 = method(self.frame, frame2, 10)
270+
exp = DataFrame(dict((k, method(self.frame[k], frame2[k], 10))
271+
for k in self.frame))
272+
tm.assert_frame_equal(res3, exp)
273+
274+
methods = [mom.rolling_corr, mom.rolling_cov]
275+
for meth in methods:
276+
_check(meth)
277+
255278
def test_ewmcov(self):
256-
self._check_binary_ew(moments.ewmcov)
279+
self._check_binary_ew(mom.ewmcov)
257280

258281
def test_ewmcorr(self):
259-
self._check_binary_ew(moments.ewmcorr)
282+
self._check_binary_ew(mom.ewmcorr)
260283

261284
def _check_binary_ew(self, func):
262285
A = Series(randn(50), index=np.arange(50))

0 commit comments

Comments
 (0)