Skip to content

Commit aae512c

Browse files
mroeschke authored and TLouf committed
ENH: Improve error message in corr/cov for Rolling/Expanding/EWM when other isn't a DataFrame/Series (pandas-dev#41741)
1 parent abfc78f commit aae512c

File tree

5 files changed

+15
-45
lines changed

5 files changed

+15
-45
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ Other enhancements
233233
- Add keyword ``sort`` to :func:`pivot_table` to allow non-sorting of the result (:issue:`39143`)
234234
- Add keyword ``dropna`` to :meth:`DataFrame.value_counts` to allow counting rows that include ``NA`` values (:issue:`41325`)
235235
- :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`)
236+
- Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`)
236237

237238
.. ---------------------------------------------------------------------------
238239

pandas/core/window/common.py

+10-34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Common utility functions for rolling operations"""
22
from collections import defaultdict
33
from typing import cast
4-
import warnings
54

65
import numpy as np
76

@@ -15,17 +14,7 @@
1514

1615
def flex_binary_moment(arg1, arg2, f, pairwise=False):
1716

18-
if not (
19-
isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
20-
and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
21-
):
22-
raise TypeError(
23-
"arguments to moment function must be of type np.ndarray/Series/DataFrame"
24-
)
25-
26-
if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
27-
arg2, (np.ndarray, ABCSeries)
28-
):
17+
if isinstance(arg1, ABCSeries) and isinstance(arg2, ABCSeries):
2918
X, Y = prep_binary(arg1, arg2)
3019
return f(X, Y)
3120

@@ -43,31 +32,25 @@ def dataframe_from_int_dict(data, frame_template):
4332
if pairwise is False:
4433
if arg1 is arg2:
4534
# special case in order to handle duplicate column names
46-
for i, col in enumerate(arg1.columns):
35+
for i in range(len(arg1.columns)):
4736
results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
4837
return dataframe_from_int_dict(results, arg1)
4938
else:
5039
if not arg1.columns.is_unique:
5140
raise ValueError("'arg1' columns are not unique")
5241
if not arg2.columns.is_unique:
5342
raise ValueError("'arg2' columns are not unique")
54-
with warnings.catch_warnings(record=True):
55-
warnings.simplefilter("ignore", RuntimeWarning)
56-
X, Y = arg1.align(arg2, join="outer")
57-
X = X + 0 * Y
58-
Y = Y + 0 * X
59-
60-
with warnings.catch_warnings(record=True):
61-
warnings.simplefilter("ignore", RuntimeWarning)
62-
res_columns = arg1.columns.union(arg2.columns)
43+
X, Y = arg1.align(arg2, join="outer")
44+
X, Y = prep_binary(X, Y)
45+
res_columns = arg1.columns.union(arg2.columns)
6346
for col in res_columns:
6447
if col in X and col in Y:
6548
results[col] = f(X[col], Y[col])
6649
return DataFrame(results, index=X.index, columns=res_columns)
6750
elif pairwise is True:
6851
results = defaultdict(dict)
69-
for i, k1 in enumerate(arg1.columns):
70-
for j, k2 in enumerate(arg2.columns):
52+
for i in range(len(arg1.columns)):
53+
for j in range(len(arg2.columns)):
7154
if j < i and arg2 is arg1:
7255
# Symmetric case
7356
results[i][j] = results[j][i]
@@ -85,10 +68,10 @@ def dataframe_from_int_dict(data, frame_template):
8568
result = concat(
8669
[
8770
concat(
88-
[results[i][j] for j, c in enumerate(arg2.columns)],
71+
[results[i][j] for j in range(len(arg2.columns))],
8972
ignore_index=True,
9073
)
91-
for i, c in enumerate(arg1.columns)
74+
for i in range(len(arg1.columns))
9275
],
9376
ignore_index=True,
9477
axis=1,
@@ -135,13 +118,10 @@ def dataframe_from_int_dict(data, frame_template):
135118
)
136119

137120
return result
138-
139-
else:
140-
raise ValueError("'pairwise' is not True/False")
141121
else:
142122
results = {
143123
i: f(*prep_binary(arg1.iloc[:, i], arg2))
144-
for i, col in enumerate(arg1.columns)
124+
for i in range(len(arg1.columns))
145125
}
146126
return dataframe_from_int_dict(results, arg1)
147127

@@ -165,11 +145,7 @@ def zsqrt(x):
165145

166146

167147
def prep_binary(arg1, arg2):
168-
if not isinstance(arg2, type(arg1)):
169-
raise Exception("Input arrays must be of the same type!")
170-
171148
# mask out values, this also makes a common index...
172149
X = arg1 + 0 * arg2
173150
Y = arg2 + 0 * arg1
174-
175151
return X, Y

pandas/core/window/rolling.py

+2
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def _apply_pairwise(
472472
other = target
473473
# only default unset
474474
pairwise = True if pairwise is None else pairwise
475+
elif not isinstance(other, (ABCDataFrame, ABCSeries)):
476+
raise ValueError("other must be a DataFrame or Series")
475477

476478
return flex_binary_moment(target, other, func, pairwise=bool(pairwise))
477479

pandas/tests/window/moments/test_moments_consistency_ewm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def test_different_input_array_raise_exception(name):
6464
A = Series(np.random.randn(50), index=np.arange(50))
6565
A[:10] = np.NaN
6666

67-
msg = "Input arrays must be of the same type!"
67+
msg = "other must be a DataFrame or Series"
6868
# exception raised is ValueError
69-
with pytest.raises(Exception, match=msg):
69+
with pytest.raises(ValueError, match=msg):
7070
getattr(A.ewm(com=20, min_periods=5), name)(np.random.randn(50))
7171

7272

pandas/tests/window/moments/test_moments_consistency_rolling.py

-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Series,
1414
)
1515
import pandas._testing as tm
16-
from pandas.core.window.common import flex_binary_moment
1716

1817

1918
def _rolling_consistency_cases():
@@ -133,14 +132,6 @@ def test_rolling_corr_with_zero_variance(window):
133132
assert s.rolling(window=window).corr(other=other).isna().all()
134133

135134

136-
def test_flex_binary_moment():
137-
# GH3155
138-
# don't blow the stack
139-
msg = "arguments to moment function must be of type np.ndarray/Series/DataFrame"
140-
with pytest.raises(TypeError, match=msg):
141-
flex_binary_moment(5, 6, None)
142-
143-
144135
def test_corr_sanity():
145136
# GH 3155
146137
df = DataFrame(

0 commit comments

Comments
 (0)