From f19d5536d8abd8bdb0d39bbd7f9332eccba76059 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 27 Dec 2021 20:40:48 -0800 Subject: [PATCH 1/3] More pytest idioms in tests/generic --- pandas/tests/generic/test_frame.py | 13 +- pandas/tests/generic/test_generic.py | 239 +++++++++++++-------------- pandas/tests/generic/test_series.py | 17 +- 3 files changed, 128 insertions(+), 141 deletions(-) diff --git a/pandas/tests/generic/test_frame.py b/pandas/tests/generic/test_frame.py index 3588cd56d1060..ec1eed3f13ae4 100644 --- a/pandas/tests/generic/test_frame.py +++ b/pandas/tests/generic/test_frame.py @@ -12,13 +12,10 @@ date_range, ) import pandas._testing as tm -from pandas.tests.generic.test_generic import Generic +from pandas.tests.generic.test_generic import check_metadata -class TestDataFrame(Generic): - _typ = DataFrame - _comparator = lambda self, x, y: tm.assert_frame_equal(x, y) - +class TestDataFrame: @pytest.mark.parametrize("func", ["_set_axis_name", "rename_axis"]) def test_set_axis_name(self, func): df = DataFrame([[1, 2], [3, 4]]) @@ -76,7 +73,7 @@ def test_metadata_propagation_indiv_groupby(self): } ) result = df.groupby("A").sum() - self.check_metadata(df, result) + check_metadata(df, result) def test_metadata_propagation_indiv_resample(self): # resample @@ -85,7 +82,7 @@ def test_metadata_propagation_indiv_resample(self): index=date_range("20130101", periods=1000, freq="s"), ) result = df.resample("1T") - self.check_metadata(df, result) + check_metadata(df, result) def test_metadata_propagation_indiv(self, monkeypatch): # merging with override @@ -148,7 +145,7 @@ def test_deepcopy_empty(self): empty_frame = DataFrame(data=[], index=[], columns=["A"]) empty_frame_copy = deepcopy(empty_frame) - self._compare(empty_frame_copy, empty_frame) + tm.assert_frame_equal(empty_frame_copy, empty_frame) # formerly in Generic but only test DataFrame diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 0eaf5fc6d6e1a..4b00b0894cd48 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -18,108 +18,109 @@ # Generic types test cases -class Generic: - @property - def _ndim(self): - return self._typ._AXIS_LEN - - def _axes(self): - """return the axes for my object typ""" - return self._typ._AXIS_ORDERS - - def _construct(self, shape, value=None, dtype=None, **kwargs): - """ - construct an object for the given shape - if value is specified use that if its a scalar - if value is an array, repeat it as needed - """ - if isinstance(shape, int): - shape = tuple([shape] * self._ndim) - if value is not None: - if is_scalar(value): - if value == "empty": - arr = None - dtype = np.float64 - - # remove the info axis - kwargs.pop(self._typ._info_axis_name, None) - else: - arr = np.empty(shape, dtype=dtype) - arr.fill(value) - else: - fshape = np.prod(shape) - arr = value.ravel() - new_shape = fshape / arr.shape[0] - if fshape % arr.shape[0] != 0: - raise Exception("invalid value passed in _construct") +@pytest.fixture(params=[DataFrame, Series]) +def constructor(request): + return request.param + - arr = np.repeat(arr, new_shape).reshape(shape) +def check_metadata(x, y=None): + for m in x._metadata: + v = getattr(x, m, None) + if y is None: + assert v is None + else: + assert v == getattr(y, m, None) + + +def construct(box, shape, value=None, dtype=None, **kwargs): + """ + construct an object for the given shape + if value is specified use that if its a scalar + if value is an array, repeat it as needed + """ + if isinstance(shape, int): + shape = tuple([shape] * box._AXIS_LEN) + if value is not None: + if is_scalar(value): + if value == "empty": + arr = None + dtype = np.float64 + + # remove the info axis + kwargs.pop(box._info_axis_name, None) + else: + arr = np.empty(shape, dtype=dtype) + arr.fill(value) else: - arr = np.random.randn(*shape) - return self._typ(arr, dtype=dtype, **kwargs) + fshape = np.prod(shape) + arr = value.ravel() + new_shape = fshape / arr.shape[0] + if fshape % arr.shape[0] != 0: + raise Exception("invalid value passed in construct") - def _compare(self, result, expected): - self._comparator(result, expected) + arr = np.repeat(arr, new_shape).reshape(shape) + else: + arr = np.random.randn(*shape) + return box(arr, dtype=dtype, **kwargs) - def test_rename(self): + +class Generic: + @pytest.mark.parametrize( + "func", + [ + str.lower, + {x: x.lower() for x in list("ABCD")}, + Series({x: x.lower() for x in list("ABCD")}), + ], + ) + def test_rename(self, constructor, func): # single axis idx = list("ABCD") - # relabeling values passed into self.rename - args = [ - str.lower, - {x: x.lower() for x in idx}, - Series({x: x.lower() for x in idx}), - ] - for axis in self._axes(): + for axis in constructor._AXIS_ORDERS: kwargs = {axis: idx} - obj = self._construct(4, **kwargs) - - for arg in args: - # rename a single axis - result = obj.rename(**{axis: arg}) - expected = obj.copy() - setattr(expected, axis, list("abcd")) - self._compare(result, expected) + obj = construct(4, **kwargs) - # multiple axes at once + # rename a single axis + result = obj.rename(**{axis: func}) + expected = obj.copy() + setattr(expected, axis, list("abcd")) + tm.assert_equal(result, expected) - def test_get_numeric_data(self): + def test_get_numeric_data(self, constructor): n = 4 kwargs = { - self._typ._get_axis_name(i): list(range(n)) for i in range(self._ndim) + constructor._get_axis_name(i): list(range(n)) + for i in range(constructor._AXIS_LEN) } # get the numeric data - o = self._construct(n, **kwargs) + o = construct(n, **kwargs) result = o._get_numeric_data() - self._compare(result, o) + tm.assert_equal(result, o) # non-inclusion result = o._get_bool_data() - expected = self._construct(n, value="empty", **kwargs) + expected = construct(n, value="empty", **kwargs) if isinstance(o, DataFrame): # preserve columns dtype expected.columns = o.columns[:0] - self._compare(result, expected) + tm.assert_equal(result, expected) # get the bool data arr = np.array([True, True, False, True]) - o = self._construct(n, value=arr, **kwargs) + o = construct(n, value=arr, **kwargs) result = o._get_numeric_data() - self._compare(result, o) + tm.assert_equal(result, o) - # _get_numeric_data is includes _get_bool_data, so can't test for - # non-inclusion - - def test_nonzero(self): + def test_nonzero(self, constructor): # GH 4633 # look at the boolean/nonzero behavior for objects - obj = self._construct(shape=4) - msg = f"The truth value of a {self._typ.__name__} is ambiguous" + obj = construct(constructor, shape=4) + msg = f"The truth value of a {constructor.__name__} is ambiguous" with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -127,7 +128,7 @@ def test_nonzero(self): with pytest.raises(ValueError, match=msg): bool(obj) - obj = self._construct(shape=4, value=1) + obj = construct(constructor, shape=4, value=1) with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -135,7 +136,7 @@ def test_nonzero(self): with pytest.raises(ValueError, match=msg): bool(obj) - obj = self._construct(shape=4, value=np.nan) + obj = construct(constructor, shape=4, value=np.nan) with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -144,14 +145,14 @@ def test_nonzero(self): bool(obj) # empty - obj = self._construct(shape=0) + obj = construct(constructor, shape=0) with pytest.raises(ValueError, match=msg): bool(obj) # invalid behaviors - obj1 = self._construct(shape=4, value=1) - obj2 = self._construct(shape=4, value=1) + obj1 = construct(constructor, shape=4, value=1) + obj2 = construct(constructor, shape=4, value=1) with pytest.raises(ValueError, match=msg): if obj1: @@ -164,16 +165,16 @@ def test_nonzero(self): with pytest.raises(ValueError, match=msg): not obj1 - def test_constructor_compound_dtypes(self): + def test_constructor_compound_dtypes(self, constructor): # see gh-5191 # Compound dtypes should raise NotImplementedError. def f(dtype): - return self._construct(shape=3, value=1, dtype=dtype) + return construct(constructor, shape=3, value=1, dtype=dtype) msg = ( "compound dtypes are not implemented " - f"in the {self._typ.__name__} constructor" + f"in the {constructor.__name__} constructor" ) with pytest.raises(NotImplementedError, match=msg): @@ -184,20 +185,12 @@ def f(dtype): f("float64") f("M8[ns]") - def check_metadata(self, x, y=None): - for m in x._metadata: - v = getattr(x, m, None) - if y is None: - assert v is None - else: - assert v == getattr(y, m, None) - - def test_metadata_propagation(self): + def test_metadata_propagation(self, constructor): # check that the metadata matches up on the resulting ops - o = self._construct(shape=3) + o = construct(constructor, shape=3) o.name = "foo" - o2 = self._construct(shape=3) + o2 = construct(constructor, shape=3) o2.name = "bar" # ---------- @@ -207,23 +200,23 @@ def test_metadata_propagation(self): # simple ops with scalars for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: result = getattr(o, op)(1) - self.check_metadata(o, result) + check_metadata(o, result) # ops with like for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: result = getattr(o, op)(o) - self.check_metadata(o, result) + check_metadata(o, result) # simple boolean for op in ["__eq__", "__le__", "__ge__"]: v1 = getattr(o, op)(o) - self.check_metadata(o, v1) - self.check_metadata(o, v1 & v1) - self.check_metadata(o, v1 | v1) + check_metadata(o, v1) + check_metadata(o, v1 & v1) + check_metadata(o, v1 | v1) # combine_first result = o.combine_first(o2) - self.check_metadata(o, result) + check_metadata(o, result) # --------------------------- # non-preserving (by default) @@ -231,7 +224,7 @@ def test_metadata_propagation(self): # add non-like result = o + o2 - self.check_metadata(result) + check_metadata(result) # simple boolean for op in ["__eq__", "__le__", "__ge__"]: @@ -239,27 +232,27 @@ def test_metadata_propagation(self): # this is a name matching op v1 = getattr(o, op)(o) v2 = getattr(o, op)(o2) - self.check_metadata(v2) - self.check_metadata(v1 & v2) - self.check_metadata(v1 | v2) + check_metadata(v2) + check_metadata(v1 & v2) + check_metadata(v1 | v2) - def test_size_compat(self): + def test_size_compat(self, constructor): # GH8846 # size property should be defined - o = self._construct(shape=10) + o = construct(constructor, shape=10) assert o.size == np.prod(o.shape) assert o.size == 10 ** len(o.axes) - def test_split_compat(self): + def test_split_compat(self, constructor): # xref GH8846 - o = self._construct(shape=10) + o = construct(constructor, shape=10) assert len(np.array_split(o, 5)) == 5 assert len(np.array_split(o, 2)) == 2 # See gh-12301 - def test_stat_unexpected_keyword(self): - obj = self._construct(5) + def test_stat_unexpected_keyword(self, constructor): + obj = construct(constructor, 5) starwars = "Star Wars" errmsg = "unexpected keyword" @@ -273,18 +266,18 @@ def test_stat_unexpected_keyword(self): obj.any(epic=starwars) # logical_function @pytest.mark.parametrize("func", ["sum", "cumsum", "any", "var"]) - def test_api_compat(self, func): + def test_api_compat(self, func, constructor): # GH 12021 # compat for __name__, __qualname__ - obj = self._construct(5) + obj = (constructor, 5) f = getattr(obj, func) assert f.__name__ == func assert f.__qualname__.endswith(func) - def test_stat_non_defaults_args(self): - obj = self._construct(5) + def test_stat_non_defaults_args(self, constructor): + obj = construct(constructor, 5) out = np.array([0]) errmsg = "the 'out' parameter is not supported" @@ -297,34 +290,34 @@ def test_stat_non_defaults_args(self): with pytest.raises(ValueError, match=errmsg): obj.any(out=out) # logical_function - def test_truncate_out_of_bounds(self): + def test_truncate_out_of_bounds(self, constructor): # GH11382 # small - shape = [2000] + ([1] * (self._ndim - 1)) - small = self._construct(shape, dtype="int8", value=1) - self._compare(small.truncate(), small) - self._compare(small.truncate(before=0, after=3e3), small) - self._compare(small.truncate(before=-1, after=2e3), small) + shape = [2000] + ([1] * (constructor._AXIS_LEN - 1)) + small = construct(constructor, shape, dtype="int8", value=1) + tm.assert_equal(small.truncate(), small) + tm.assert_equal(small.truncate(before=0, after=3e3), small) + tm.assert_equal(small.truncate(before=-1, after=2e3), small) # big - shape = [2_000_000] + ([1] * (self._ndim - 1)) - big = self._construct(shape, dtype="int8", value=1) - self._compare(big.truncate(), big) - self._compare(big.truncate(before=0, after=3e6), big) - self._compare(big.truncate(before=-1, after=2e6), big) + shape = [2_000_000] + ([1] * (constructor._AXIS_LEN - 1)) + big = construct(constructor, shape, dtype="int8", value=1) + tm.assert_equal(big.truncate(), big) + tm.assert_equal(big.truncate(before=0, after=3e6), big) + tm.assert_equal(big.truncate(before=-1, after=2e6), big) @pytest.mark.parametrize( "func", [copy, deepcopy, lambda x: x.copy(deep=False), lambda x: x.copy(deep=True)], ) @pytest.mark.parametrize("shape", [0, 1, 2]) - def test_copy_and_deepcopy(self, shape, func): + def test_copy_and_deepcopy(self, constructor, shape, func): # GH 15444 - obj = self._construct(shape) + obj = construct(constructor, shape) obj_copy = func(obj) assert obj_copy is not obj - self._compare(obj_copy, obj) + tm.assert_equal(obj_copy, obj) class TestNDFrame: diff --git a/pandas/tests/generic/test_series.py b/pandas/tests/generic/test_series.py index 784ced96286a6..2ad3e95708d17 100644 --- a/pandas/tests/generic/test_series.py +++ b/pandas/tests/generic/test_series.py @@ -10,13 +10,10 @@ date_range, ) import pandas._testing as tm -from pandas.tests.generic.test_generic import Generic +from pandas.tests.generic.test_generic import check_metadata -class TestSeries(Generic): - _typ = Series - _comparator = lambda self, x, y: tm.assert_series_equal(x, y) - +class TestSeries: @pytest.mark.parametrize("func", ["rename_axis", "_set_axis_name"]) def test_set_axis_name_mi(self, func): ser = Series( @@ -41,7 +38,7 @@ def test_set_axis_name_raises(self): def test_get_bool_data_preserve_dtype(self): ser = Series([True, False, True]) result = ser._get_bool_data() - self._compare(result, ser) + tm.assert_series_equal(result, ser) def test_nonzero_single_element(self): @@ -101,13 +98,13 @@ def test_metadata_propagation_indiv_resample(self): name="foo", ) result = ts.resample("1T").mean() - self.check_metadata(ts, result) + check_metadata(ts, result) result = ts.resample("1T").min() - self.check_metadata(ts, result) + check_metadata(ts, result) result = ts.resample("1T").apply(lambda x: x.sum()) - self.check_metadata(ts, result) + check_metadata(ts, result) def test_metadata_propagation_indiv(self, monkeypatch): # check that the metadata matches up on the resulting ops @@ -118,7 +115,7 @@ def test_metadata_propagation_indiv(self, monkeypatch): ser2.name = "bar" result = ser.T - self.check_metadata(ser, result) + check_metadata(ser, result) def finalize(self, other, method=None, **kwargs): for name in self._metadata: From cb0ce7e8f4ff7ce48f5ae6747008e740c3ea27dd Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 27 Dec 2021 20:42:52 -0800 Subject: [PATCH 2/3] Use frame_or_series --- pandas/tests/generic/test_generic.py | 77 +++++++++++++--------------- 1 file changed, 36 insertions(+), 41 deletions(-) diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 4b00b0894cd48..803641e63e718 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -18,11 +18,6 @@ # Generic types test cases -@pytest.fixture(params=[DataFrame, Series]) -def constructor(request): - return request.param - - def check_metadata(x, y=None): for m in x._metadata: v = getattr(x, m, None) @@ -73,12 +68,12 @@ class Generic: Series({x: x.lower() for x in list("ABCD")}), ], ) - def test_rename(self, constructor, func): + def test_rename(self, frame_or_series, func): # single axis idx = list("ABCD") - for axis in constructor._AXIS_ORDERS: + for axis in frame_or_series._AXIS_ORDERS: kwargs = {axis: idx} obj = construct(4, **kwargs) @@ -88,12 +83,12 @@ def test_rename(self, constructor, func): setattr(expected, axis, list("abcd")) tm.assert_equal(result, expected) - def test_get_numeric_data(self, constructor): + def test_get_numeric_data(self, frame_or_series): n = 4 kwargs = { - constructor._get_axis_name(i): list(range(n)) - for i in range(constructor._AXIS_LEN) + frame_or_series._get_axis_name(i): list(range(n)) + for i in range(frame_or_series._AXIS_LEN) } # get the numeric data @@ -115,12 +110,12 @@ def test_get_numeric_data(self, constructor): result = o._get_numeric_data() tm.assert_equal(result, o) - def test_nonzero(self, constructor): + def test_nonzero(self, frame_or_series): # GH 4633 # look at the boolean/nonzero behavior for objects - obj = construct(constructor, shape=4) - msg = f"The truth value of a {constructor.__name__} is ambiguous" + obj = construct(frame_or_series, shape=4) + msg = f"The truth value of a {frame_or_series.__name__} is ambiguous" with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -128,7 +123,7 @@ def test_nonzero(self, constructor): with pytest.raises(ValueError, match=msg): bool(obj) - obj = construct(constructor, shape=4, value=1) + obj = construct(frame_or_series, shape=4, value=1) with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -136,7 +131,7 @@ def test_nonzero(self, constructor): with pytest.raises(ValueError, match=msg): bool(obj) - obj = construct(constructor, shape=4, value=np.nan) + obj = construct(frame_or_series, shape=4, value=np.nan) with pytest.raises(ValueError, match=msg): bool(obj == 0) with pytest.raises(ValueError, match=msg): @@ -145,14 +140,14 @@ def test_nonzero(self, constructor): bool(obj) # empty - obj = construct(constructor, shape=0) + obj = construct(frame_or_series, shape=0) with pytest.raises(ValueError, match=msg): bool(obj) # invalid behaviors - obj1 = construct(constructor, shape=4, value=1) - obj2 = construct(constructor, shape=4, value=1) + obj1 = construct(frame_or_series, shape=4, value=1) + obj2 = construct(frame_or_series, shape=4, value=1) with pytest.raises(ValueError, match=msg): if obj1: @@ -165,16 +160,16 @@ def test_nonzero(self, constructor): with pytest.raises(ValueError, match=msg): not obj1 - def test_constructor_compound_dtypes(self, constructor): + def test_frame_or_series_compound_dtypes(self, frame_or_series): # see gh-5191 # Compound dtypes should raise NotImplementedError. def f(dtype): - return construct(constructor, shape=3, value=1, dtype=dtype) + return construct(frame_or_series, shape=3, value=1, dtype=dtype) msg = ( "compound dtypes are not implemented " - f"in the {constructor.__name__} constructor" + f"in the {frame_or_series.__name__} frame_or_series" ) with pytest.raises(NotImplementedError, match=msg): @@ -185,12 +180,12 @@ def f(dtype): f("float64") f("M8[ns]") - def test_metadata_propagation(self, constructor): + def test_metadata_propagation(self, frame_or_series): # check that the metadata matches up on the resulting ops - o = construct(constructor, shape=3) + o = construct(frame_or_series, shape=3) o.name = "foo" - o2 = construct(constructor, shape=3) + o2 = construct(frame_or_series, shape=3) o2.name = "bar" # ---------- @@ -236,23 +231,23 @@ def test_metadata_propagation(self, constructor): check_metadata(v1 & v2) check_metadata(v1 | v2) - def test_size_compat(self, constructor): + def test_size_compat(self, frame_or_series): # GH8846 # size property should be defined - o = construct(constructor, shape=10) + o = construct(frame_or_series, shape=10) assert o.size == np.prod(o.shape) assert o.size == 10 ** len(o.axes) - def test_split_compat(self, constructor): + def test_split_compat(self, frame_or_series): # xref GH8846 - o = construct(constructor, shape=10) + o = construct(frame_or_series, shape=10) assert len(np.array_split(o, 5)) == 5 assert len(np.array_split(o, 2)) == 2 # See gh-12301 - def test_stat_unexpected_keyword(self, constructor): - obj = construct(constructor, 5) + def test_stat_unexpected_keyword(self, frame_or_series): + obj = construct(frame_or_series, 5) starwars = "Star Wars" errmsg = "unexpected keyword" @@ -266,18 +261,18 @@ def test_stat_unexpected_keyword(self, constructor): obj.any(epic=starwars) # logical_function @pytest.mark.parametrize("func", ["sum", "cumsum", "any", "var"]) - def test_api_compat(self, func, constructor): + def test_api_compat(self, func, frame_or_series): # GH 12021 # compat for __name__, __qualname__ - obj = (constructor, 5) + obj = (frame_or_series, 5) f = getattr(obj, func) assert f.__name__ == func assert f.__qualname__.endswith(func) - def test_stat_non_defaults_args(self, constructor): - obj = construct(constructor, 5) + def test_stat_non_defaults_args(self, frame_or_series): + obj = construct(frame_or_series, 5) out = np.array([0]) errmsg = "the 'out' parameter is not supported" @@ -290,19 +285,19 @@ def test_stat_non_defaults_args(self, constructor): with pytest.raises(ValueError, match=errmsg): obj.any(out=out) # logical_function - def test_truncate_out_of_bounds(self, constructor): + def test_truncate_out_of_bounds(self, frame_or_series): # GH11382 # small - shape = [2000] + ([1] * (constructor._AXIS_LEN - 1)) - small = construct(constructor, shape, dtype="int8", value=1) + shape = [2000] + ([1] * (frame_or_series._AXIS_LEN - 1)) + small = construct(frame_or_series, shape, dtype="int8", value=1) tm.assert_equal(small.truncate(), small) tm.assert_equal(small.truncate(before=0, after=3e3), small) tm.assert_equal(small.truncate(before=-1, after=2e3), small) # big - shape = [2_000_000] + ([1] * (constructor._AXIS_LEN - 1)) - big = construct(constructor, shape, dtype="int8", value=1) + shape = [2_000_000] + ([1] * (frame_or_series._AXIS_LEN - 1)) + big = construct(frame_or_series, shape, dtype="int8", value=1) tm.assert_equal(big.truncate(), big) tm.assert_equal(big.truncate(before=0, after=3e6), big) tm.assert_equal(big.truncate(before=-1, after=2e6), big) @@ -312,9 +307,9 @@ def test_truncate_out_of_bounds(self, constructor): [copy, deepcopy, lambda x: x.copy(deep=False), lambda x: x.copy(deep=True)], ) @pytest.mark.parametrize("shape", [0, 1, 2]) - def test_copy_and_deepcopy(self, constructor, shape, func): + def test_copy_and_deepcopy(self, frame_or_series, shape, func): # GH 15444 - obj = construct(constructor, shape) + obj = construct(frame_or_series, shape) obj_copy = func(obj) assert obj_copy is not obj tm.assert_equal(obj_copy, obj) From 22e98e1d6ddfd6ae4dc7f81b5f52b9c9e6e5beef Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Tue, 28 Dec 2021 10:40:20 -0800 Subject: [PATCH 3/3] Make tm.assert_metadata_equivalent --- pandas/_testing/__init__.py | 1 + pandas/_testing/asserters.py | 12 ++++++++++++ pandas/tests/generic/test_frame.py | 5 ++--- pandas/tests/generic/test_generic.py | 29 ++++++++++------------------ pandas/tests/generic/test_series.py | 9 ++++----- 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 2e9f92ebc7cb7..b3fcff21f0f1f 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -82,6 +82,7 @@ assert_interval_array_equal, assert_is_sorted, assert_is_valid_plot_return_object, + assert_metadata_equivalent, assert_numpy_array_equal, assert_period_array_equal, assert_series_equal, diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index d59ad72d74d73..77c477d3f9229 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -1470,3 +1470,15 @@ def assert_indexing_slices_equivalent(ser: Series, l_slc: slice, i_slc: slice): if not ser.index.is_integer(): # For integer indices, .loc and plain getitem are position-based. assert_series_equal(ser[l_slc], expected) + + +def assert_metadata_equivalent(left, right): + """ + Check that ._metadata attributes are equivalent. + """ + for attr in left._metadata: + val = getattr(left, attr, None) + if right is None: + assert val is None + else: + assert val == getattr(right, attr, None) diff --git a/pandas/tests/generic/test_frame.py b/pandas/tests/generic/test_frame.py index ec1eed3f13ae4..2b248afb42057 100644 --- a/pandas/tests/generic/test_frame.py +++ b/pandas/tests/generic/test_frame.py @@ -12,7 +12,6 @@ date_range, ) import pandas._testing as tm -from pandas.tests.generic.test_generic import check_metadata class TestDataFrame: @@ -73,7 +72,7 @@ def test_metadata_propagation_indiv_groupby(self): } ) result = df.groupby("A").sum() - check_metadata(df, result) + tm.assert_metadata_equivalent(df, result) def test_metadata_propagation_indiv_resample(self): # resample @@ -82,7 +81,7 @@ def test_metadata_propagation_indiv_resample(self): index=date_range("20130101", periods=1000, freq="s"), ) result = df.resample("1T") - check_metadata(df, result) + tm.assert_metadata_equivalent(df, result) def test_metadata_propagation_indiv(self, monkeypatch): # merging with override diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 803641e63e718..5c1eb28ed6099 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -18,15 +18,6 @@ # Generic types test cases -def check_metadata(x, y=None): - for m in x._metadata: - v = getattr(x, m, None) - if y is None: - assert v is None - else: - assert v == getattr(y, m, None) - - def construct(box, shape, value=None, dtype=None, **kwargs): """ construct an object for the given shape @@ -195,23 +186,23 @@ def test_metadata_propagation(self, frame_or_series): # simple ops with scalars for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: result = getattr(o, op)(1) - check_metadata(o, result) + tm.assert_metadata_equivalent(o, result) # ops with like for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: result = getattr(o, op)(o) - check_metadata(o, result) + tm.assert_metadata_equivalent(o, result) # simple boolean for op in ["__eq__", "__le__", "__ge__"]: v1 = getattr(o, op)(o) - check_metadata(o, v1) - check_metadata(o, v1 & v1) - check_metadata(o, v1 | v1) + tm.assert_metadata_equivalent(o, v1) + tm.assert_metadata_equivalent(o, v1 & v1) + tm.assert_metadata_equivalent(o, v1 | v1) # combine_first result = o.combine_first(o2) - check_metadata(o, result) + tm.assert_metadata_equivalent(o, result) # --------------------------- # non-preserving (by default) @@ -219,7 +210,7 @@ def test_metadata_propagation(self, frame_or_series): # add non-like result = o + o2 - check_metadata(result) + tm.assert_metadata_equivalent(result) # simple boolean for op in ["__eq__", "__le__", "__ge__"]: @@ -227,9 +218,9 @@ def test_metadata_propagation(self, frame_or_series): # this is a name matching op v1 = getattr(o, op)(o) v2 = getattr(o, op)(o2) - check_metadata(v2) - check_metadata(v1 & v2) - check_metadata(v1 | v2) + tm.assert_metadata_equivalent(v2) + tm.assert_metadata_equivalent(v1 & v2) + tm.assert_metadata_equivalent(v1 | v2) def test_size_compat(self, frame_or_series): # GH8846 diff --git a/pandas/tests/generic/test_series.py b/pandas/tests/generic/test_series.py index 2ad3e95708d17..dd2380e2647d3 100644 --- a/pandas/tests/generic/test_series.py +++ b/pandas/tests/generic/test_series.py @@ -10,7 +10,6 @@ date_range, ) import pandas._testing as tm -from pandas.tests.generic.test_generic import check_metadata class TestSeries: @@ -98,13 +97,13 @@ def test_metadata_propagation_indiv_resample(self): name="foo", ) result = ts.resample("1T").mean() - check_metadata(ts, result) + tm.assert_metadata_equivalent(ts, result) result = ts.resample("1T").min() - check_metadata(ts, result) + tm.assert_metadata_equivalent(ts, result) result = ts.resample("1T").apply(lambda x: x.sum()) - check_metadata(ts, result) + tm.assert_metadata_equivalent(ts, result) def test_metadata_propagation_indiv(self, monkeypatch): # check that the metadata matches up on the resulting ops @@ -115,7 +114,7 @@ def test_metadata_propagation_indiv(self, monkeypatch): ser2.name = "bar" result = ser.T - check_metadata(ser, result) + tm.assert_metadata_equivalent(ser, result) def finalize(self, other, method=None, **kwargs): for name in self._metadata: