Skip to content

Commit 226876a

Browse files
BUG: EWM silently failed float32 (pandas-dev#42650)
* BUG: EWM silently failed float32
* added tests
* resolved mypy error
* added constant data in test
* added pytest.fixture & whatsnew
* parametrized expected df; removed float16
* added test for float32
* added tests on select_dtypes
1 parent 397432c commit 226876a

File tree

5 files changed

+96
-0
lines changed

5 files changed

+96
-0
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ Groupby/resample/rolling
270270
- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`)
271271
- Bug in :meth:`DataFrame.groupby.rolling.var` would calculate the rolling variance only on the first group (:issue:`42442`)
272272
- Bug in :meth:`GroupBy.shift` that would return the grouping columns if ``fill_value`` was not None (:issue:`41556`)
273+
- Bug in :meth:`pandas.DataFrame.ewm`, where non-float64 dtypes were silently failing (:issue:`42452`)
273274

274275
Reshaping
275276
^^^^^^^^^

pandas/core/frame.py

+5
Original file line numberDiff line numberDiff line change
@@ -4280,6 +4280,11 @@ def check_int_infer_dtype(dtypes):
42804280
# error: Argument 1 to "append" of "list" has incompatible type
42814281
# "Type[signedinteger[Any]]"; expected "Type[signedinteger[Any]]"
42824282
converted_dtypes.append(np.int64) # type: ignore[arg-type]
4283+
elif dtype == "float" or dtype is float:
4284+
# GH#42452 : np.dtype("float") coerces to np.float64 from Numpy 1.20
4285+
converted_dtypes.extend(
4286+
[np.float64, np.float32] # type: ignore[list-item]
4287+
)
42834288
else:
42844289
# error: Argument 1 to "append" of "list" has incompatible type
42854290
# "Union[dtype[Any], ExtensionDtype]"; expected

pandas/tests/frame/methods/test_select_dtypes.py

+34
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,37 @@ def test_select_dtypes_numeric_nullable_string(self, nullable_string_dtype):
407407
df = DataFrame(arr)
408408
is_selected = df.select_dtypes(np.number).shape == df.shape
409409
assert not is_selected
410+
411+
@pytest.mark.parametrize(
    "expected, float_dtypes",
    [
        # Builtin ``float`` selects both the float64 and float32 columns.
        (
            DataFrame(
                {"A": range(3), "B": range(5, 8), "C": range(10, 7, -1)}
            ).astype(dtype={"A": float, "B": np.float64, "C": np.float32}),
            float,
        ),
        # The string alias "float" is expected to behave like the builtin.
        (
            DataFrame(
                {"A": range(3), "B": range(5, 8), "C": range(10, 7, -1)}
            ).astype(dtype={"A": float, "B": np.float64, "C": np.float32}),
            "float",
        ),
        # An explicit width selects only the columns of that width.
        (DataFrame({"C": range(10, 7, -1)}, dtype=np.float32), np.float32),
        (
            DataFrame({"A": range(3), "B": range(5, 8)}).astype(
                dtype={"A": float, "B": np.float64}
            ),
            np.float64,
        ),
    ],
)
def test_select_dtypes_float_dtype(self, expected, float_dtypes):
    """Check which float columns ``select_dtypes(include=...)`` returns.

    The input frame has columns "A" (``float``), "B" (``np.float64``) and
    "C" (``np.float32``); ``expected`` is the sub-frame that should be
    selected when ``float_dtypes`` is passed as ``include``.
    """
    # GH#42452
    frame = DataFrame(
        {"A": range(3), "B": range(5, 8), "C": range(10, 7, -1)},
    ).astype({"A": float, "B": np.float64, "C": np.float32})
    selected = frame.select_dtypes(include=float_dtypes)
    tm.assert_frame_equal(selected, expected)

pandas/tests/window/test_ewm.py

+48
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,51 @@ def test_ewma_times_adjust_false_raises():
181181
Series(range(1)).ewm(
182182
0.1, adjust=False, times=date_range("2000", freq="D", periods=1)
183183
)
184+
185+
186+
@pytest.mark.parametrize(
    "func, expected",
    [
        (
            "mean",
            DataFrame(
                {
                    0: range(5),
                    1: range(4, 9),
                    2: [7.428571, 9, 10.571429, 12.142857, 13.714286],
                },
                dtype=float,
            ),
        ),
        (
            "std",
            DataFrame(
                {
                    0: [np.nan] * 5,
                    1: [4.242641] * 5,
                    2: [4.6291, 5.196152, 5.781745, 6.380775, 6.989788],
                }
            ),
        ),
        (
            "var",
            DataFrame(
                {
                    0: [np.nan] * 5,
                    1: [18.0] * 5,
                    2: [21.428571, 27, 33.428571, 40.714286, 48.857143],
                }
            ),
        ),
    ],
)
def test_float_dtype_ewma(func, expected, float_dtype):
    """Check ``ewm(alpha=0.5, axis=1)`` aggregations on a float-typed frame.

    ``func`` names the aggregation called on the ewm object ("mean", "std"
    or "var") and ``expected`` is the frame it should produce.
    ``float_dtype`` is not parametrized here; presumably it is a fixture
    defined elsewhere that supplies the float dtypes under test.
    """
    # GH#42452
    frame = DataFrame(
        {0: range(5), 1: range(6, 11), 2: range(10, 20, 2)}, dtype=float_dtype
    )
    # Look the aggregation up by name so one test body covers every case.
    aggregate = getattr(frame.ewm(alpha=0.5, axis=1), func)
    result = aggregate()

    tm.assert_frame_equal(result, expected)

pandas/tests/window/test_rolling.py

+8
Original file line numberDiff line numberDiff line change
@@ -1424,3 +1424,11 @@ def test_rolling_zero_window():
14241424
result = s.rolling(0).min()
14251425
expected = Series([np.nan])
14261426
tm.assert_series_equal(result, expected)
1427+
1428+
1429+
def test_rolling_float_dtype(float_dtype):
    """Check a column-wise rolling sum (window 2, ``axis=1``) on float data.

    ``float_dtype`` is presumably a fixture defined elsewhere that supplies
    the float dtypes under test. The result dtype is not compared
    (``check_dtype=False``); only the values are.
    """
    # GH#42452
    frame = DataFrame({"A": range(5), "B": range(10, 15)}, dtype=float_dtype)
    row_sums = frame.rolling(2, axis=1).sum()

    # Column "A" is all-NaN in the expected output; column "B" holds A + B.
    expected = DataFrame(
        {"A": [np.nan] * 5, "B": range(10, 20, 2)}, dtype=float_dtype
    )
    tm.assert_frame_equal(row_sums, expected, check_dtype=False)

0 commit comments

Comments
 (0)