From 737c9f079a0ae2f11e72fe96c31c39c44fde2f3f Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 27 Aug 2020 00:17:54 -0700 Subject: [PATCH 1/2] REF: window/test_dtypes.py with pytest idioms --- pandas/tests/window/conftest.py | 30 +++ pandas/tests/window/test_dtypes.py | 315 ++++++++--------------------- 2 files changed, 117 insertions(+), 228 deletions(-) diff --git a/pandas/tests/window/conftest.py b/pandas/tests/window/conftest.py index eb8252d5731be..1d8cd636e5ede 100644 --- a/pandas/tests/window/conftest.py +++ b/pandas/tests/window/conftest.py @@ -308,3 +308,33 @@ def which(request): def halflife_with_times(request): """Halflife argument for EWM when times is specified.""" return request.param + + +@pytest.fixture( + params=[ + "object", + "category", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "m8[ns]", + "M8[ns]", + pytest.param( + "datetime64[ns, UTC]", + marks=pytest.mark.skip( + "direct creation of extension dtype datetime64[ns, UTC] is not supported ATM" + ), + ), + ] +) +def dtypes(request): + """Dtypes for window tests""" + return request.param diff --git a/pandas/tests/window/test_dtypes.py b/pandas/tests/window/test_dtypes.py index 0aa5bf019ff5e..245b48b351684 100644 --- a/pandas/tests/window/test_dtypes.py +++ b/pandas/tests/window/test_dtypes.py @@ -1,5 +1,3 @@ -from itertools import product - import numpy as np import pytest @@ -10,234 +8,95 @@ # gh-12373 : rolling functions error on float32 data # make sure rolling functions works for different dtypes # -# NOTE that these are yielded tests and so _create_data -# is explicitly called. -# # further note that we are only checking rolling for fully dtype # compliance (though both expanding and ewm inherit) -class Dtype: - window = 2 - - funcs = { - "count": lambda v: v.count(), - "max": lambda v: v.max(), - "min": lambda v: v.min(), - "sum": lambda v: v.sum(), - "mean": lambda v: v.mean(), - "std": lambda v: v.std(), - "var": lambda v: v.var(), - "median": lambda v: v.median(), - } - - def get_expects(self): - expects = { - "sr1": { - "count": Series([1, 2, 2, 2, 2], dtype="float64"), - "max": Series([np.nan, 1, 2, 3, 4], dtype="float64"), - "min": Series([np.nan, 0, 1, 2, 3], dtype="float64"), - "sum": Series([np.nan, 1, 3, 5, 7], dtype="float64"), - "mean": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"), - "std": Series([np.nan] + [np.sqrt(0.5)] * 4, dtype="float64"), - "var": Series([np.nan, 0.5, 0.5, 0.5, 0.5], dtype="float64"), - "median": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"), +def get_dtype(dtype, coerce_int=None): + if coerce_int is False and "int" in dtype: + return None + if dtype != "category": + return np.dtype(dtype) + return dtype + + +@pytest.mark.parametrize( + "method, data, expected_data, coerce_int", + [ + ("count", np.arange(5), [1, 2, 2, 2, 2], True), + ("count", np.arange(10, 0, -2), [1, 2, 2, 2, 2], True), + ("count", [0, 1, 2, np.nan, 4], [1, 2, 2, 1, 1], False), + ("max", np.arange(5), [np.nan, 1, 2, 3, 4], True), + ("max", np.arange(10, 0, -2), [np.nan, 10, 8, 6, 4], True), + ("max", [0, 1, 2, np.nan, 4], [np.nan, 1, 2, np.nan, np.nan], False), + ("min", np.arange(5), [np.nan, 0, 1, 2, 3], True), + ("min", np.arange(10, 0, -2), [np.nan, 8, 6, 4, 2], True), + ("min", [0, 1, 2, np.nan, 4], [np.nan, 0, 1, np.nan, np.nan], False), + ("sum", np.arange(5), [np.nan, 1, 3, 5, 7], True), + ("sum", np.arange(10, 0, -2), [np.nan, 18, 14, 10, 6], True), + ("sum", [0, 1, 2, np.nan, 4], [np.nan, 1, 3, np.nan, np.nan], False), + ("mean", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True), + ("mean", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True), + ("mean", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 1.5, np.nan, np.nan], False), + ("std", np.arange(5), [np.nan] + [np.sqrt(0.5)] * 4, True), + ("std", np.arange(10, 0, -2), [np.nan] + [np.sqrt(2)] * 4, True), + ( + "std", + [0, 1, 2, np.nan, 4], + [np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2, + False, + ), + ("var", np.arange(5), [np.nan, 0.5, 0.5, 0.5, 0.5], True), + ("var", np.arange(10, 0, -2), [np.nan, 2, 2, 2, 2], True), + ("var", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 0.5, np.nan, np.nan], False), + ("median", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True), + ("median", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True), + ("median", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 1.5, np.nan, np.nan], False), + ], +) +def test_series_dtypes(method, data, expected_data, coerce_int, dtypes): + s = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) + if dtypes in ("m8[ns]", "M8[ns]") and method != "count": + msg = "No numeric types to aggregate" + with pytest.raises(DataError, match=msg): + getattr(s.rolling(2), method)() + else: + result = getattr(s.rolling(2), method)() + expected = Series(expected_data, dtype="float64") + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "method, expected_data", + [ + ("count", {0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])}), + ("max", {0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])}), + ("min", {0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])}), + ( + "sum", + {0: Series([np.nan, 2, 6, 10, 14]), 1: Series([np.nan, 4, 8, 12, 16])}, + ), + ("mean", {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}), + ( + "std", + { + 0: Series([np.nan] + [np.sqrt(2)] * 4), + 1: Series([np.nan] + [np.sqrt(2)] * 4), }, - "sr2": { - "count": Series([1, 2, 2, 2, 2], dtype="float64"), - "max": Series([np.nan, 10, 8, 6, 4], dtype="float64"), - "min": Series([np.nan, 8, 6, 4, 2], dtype="float64"), - "sum": Series([np.nan, 18, 14, 10, 6], dtype="float64"), - "mean": Series([np.nan, 9, 7, 5, 3], dtype="float64"), - "std": Series([np.nan] + [np.sqrt(2)] * 4, dtype="float64"), - "var": Series([np.nan, 2, 2, 2, 2], dtype="float64"), - "median": Series([np.nan, 9, 7, 5, 3], dtype="float64"), - }, - "sr3": { - "count": Series([1, 2, 2, 1, 1], dtype="float64"), - "max": Series([np.nan, 1, 2, np.nan, np.nan], dtype="float64"), - "min": Series([np.nan, 0, 1, np.nan, np.nan], dtype="float64"), - "sum": Series([np.nan, 1, 3, np.nan, np.nan], dtype="float64"), - "mean": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"), - "std": Series( - [np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2, dtype="float64" - ), - "var": Series([np.nan, 0.5, 0.5, np.nan, np.nan], dtype="float64"), - "median": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"), - }, - "df": { - "count": DataFrame( - {0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])}, - dtype="float64", - ), - "max": DataFrame( - {0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])}, - dtype="float64", - ), - "min": DataFrame( - {0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])}, - dtype="float64", - ), - "sum": DataFrame( - { - 0: Series([np.nan, 2, 6, 10, 14]), - 1: Series([np.nan, 4, 8, 12, 16]), - }, - dtype="float64", - ), - "mean": DataFrame( - {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}, - dtype="float64", - ), - "std": DataFrame( - { - 0: Series([np.nan] + [np.sqrt(2)] * 4), - 1: Series([np.nan] + [np.sqrt(2)] * 4), - }, - dtype="float64", - ), - "var": DataFrame( - {0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])}, - dtype="float64", - ), - "median": DataFrame( - {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}, - dtype="float64", - ), - }, - } - return expects - - def _create_dtype_data(self, dtype): - sr1 = Series(np.arange(5), dtype=dtype) - sr2 = Series(np.arange(10, 0, -2), dtype=dtype) - sr3 = sr1.copy() - sr3[3] = np.NaN - df = DataFrame(np.arange(10).reshape((5, 2)), dtype=dtype) - - data = {"sr1": sr1, "sr2": sr2, "sr3": sr3, "df": df} - - return data - - def _create_data(self): - self.data = self._create_dtype_data(self.dtype) - self.expects = self.get_expects() - - def test_dtypes(self): - self._create_data() - for f_name, d_name in product(self.funcs.keys(), self.data.keys()): - - f = self.funcs[f_name] - d = self.data[d_name] - exp = self.expects[d_name][f_name] - self.check_dtypes(f, f_name, d, d_name, exp) - - def check_dtypes(self, f, f_name, d, d_name, exp): - roll = d.rolling(window=self.window) - result = f(roll) - tm.assert_almost_equal(result, exp) - - -class TestDtype_object(Dtype): - dtype = object - - -class Dtype_integer(Dtype): - pass - - -class TestDtype_int8(Dtype_integer): - dtype = np.int8 - - -class TestDtype_int16(Dtype_integer): - dtype = np.int16 - - -class TestDtype_int32(Dtype_integer): - dtype = np.int32 - - -class TestDtype_int64(Dtype_integer): - dtype = np.int64 - - -class Dtype_uinteger(Dtype): - pass - - -class TestDtype_uint8(Dtype_uinteger): - dtype = np.uint8 - - -class TestDtype_uint16(Dtype_uinteger): - dtype = np.uint16 - - -class TestDtype_uint32(Dtype_uinteger): - dtype = np.uint32 - - -class TestDtype_uint64(Dtype_uinteger): - dtype = np.uint64 - - -class Dtype_float(Dtype): - pass - - -class TestDtype_float16(Dtype_float): - dtype = np.float16 - - -class TestDtype_float32(Dtype_float): - dtype = np.float32 - - -class TestDtype_float64(Dtype_float): - dtype = np.float64 - - -class TestDtype_category(Dtype): - dtype = "category" - include_df = False - - def _create_dtype_data(self, dtype): - sr1 = Series(range(5), dtype=dtype) - sr2 = Series(range(10, 0, -2), dtype=dtype) - - data = {"sr1": sr1, "sr2": sr2} - - return data - - -class DatetimeLike(Dtype): - def check_dtypes(self, f, f_name, d, d_name, exp): - - roll = d.rolling(window=self.window) - if f_name == "count": - result = f(roll) - tm.assert_almost_equal(result, exp) - - else: - msg = "No numeric types to aggregate" - with pytest.raises(DataError, match=msg): - f(roll) - - -class TestDtype_timedelta(DatetimeLike): - dtype = np.dtype("m8[ns]") - - -class TestDtype_datetime(DatetimeLike): - dtype = np.dtype("M8[ns]") - - -class TestDtype_datetime64UTC(DatetimeLike): - dtype = "datetime64[ns, UTC]" - - def _create_data(self): - pytest.skip( - "direct creation of extension dtype " - "datetime64[ns, UTC] is not supported ATM" - ) + ), + ("var", {0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])}), + ("median", {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}), + ], +) +def test_dataframe_dtypes(method, expected_data, dtypes): + if dtypes == "category": + pytest.skip("Category dataframe testing not implemented.") + df = DataFrame(np.arange(10).reshape((5, 2)), dtype=get_dtype(dtypes)) + if dtypes in ("m8[ns]", "M8[ns]") and method != "count": + msg = "No numeric types to aggregate" + with pytest.raises(DataError, match=msg): + getattr(df.rolling(2), method)() + else: + result = getattr(df.rolling(2), method)() + expected = DataFrame(expected_data, dtype="float64") + tm.assert_frame_equal(result, expected) From a3611b0e94b13c084271492810a226654168179d Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 27 Aug 2020 00:21:20 -0700 Subject: [PATCH 2/2] Lint --- pandas/tests/window/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/tests/window/conftest.py b/pandas/tests/window/conftest.py index 1d8cd636e5ede..7f03fa2a5ea0d 100644 --- a/pandas/tests/window/conftest.py +++ b/pandas/tests/window/conftest.py @@ -330,7 +330,8 @@ def halflife_with_times(request): pytest.param( "datetime64[ns, UTC]", marks=pytest.mark.skip( - "direct creation of extension dtype datetime64[ns, UTC] is not supported ATM" + "direct creation of extension dtype datetime64[ns, UTC] " + "is not supported ATM" ), ), ]