diff --git a/pandas/tests/extension/base/base.py b/pandas/tests/extension/base/base.py index 144b0825b39a2..97d8e7c66dbdb 100644 --- a/pandas/tests/extension/base/base.py +++ b/pandas/tests/extension/base/base.py @@ -2,8 +2,20 @@ class BaseExtensionTests: + # classmethod and different signature is needed + # to make inheritance compliant with mypy + @classmethod + def assert_equal(cls, left, right, **kwargs): + return tm.assert_equal(left, right, **kwargs) - assert_equal = staticmethod(tm.assert_equal) - assert_series_equal = staticmethod(tm.assert_series_equal) - assert_frame_equal = staticmethod(tm.assert_frame_equal) - assert_extension_array_equal = staticmethod(tm.assert_extension_array_equal) + @classmethod + def assert_series_equal(cls, left, right, *args, **kwargs): + return tm.assert_series_equal(left, right, *args, **kwargs) + + @classmethod + def assert_frame_equal(cls, left, right, *args, **kwargs): + return tm.assert_frame_equal(left, right, *args, **kwargs) + + @classmethod + def assert_extension_array_equal(cls, left, right, *args, **kwargs): + return tm.assert_extension_array_equal(left, right, *args, **kwargs) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index de7c98ab96571..bd9b77a2bc419 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -66,7 +66,8 @@ def data_for_grouping(): class BaseDecimal: - def assert_series_equal(self, left, right, *args, **kwargs): + @classmethod + def assert_series_equal(cls, left, right, *args, **kwargs): def convert(x): # need to convert array([Decimal(NaN)], dtype='object') to np.NaN # because Series[object].isnan doesn't recognize decimal(NaN) as @@ -88,7 +89,8 @@ def convert(x): tm.assert_series_equal(left_na, right_na) return tm.assert_series_equal(left[~left_na], right[~right_na], *args, **kwargs) - def assert_frame_equal(self, left, right, *args, **kwargs): + @classmethod + def assert_frame_equal(cls, left, right, *args, **kwargs): # TODO(EA): select_dtypes tm.assert_index_equal( left.columns, @@ -103,7 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs): decimals = (left.dtypes == "decimal").index for col in decimals: - self.assert_series_equal(left[col], right[col], *args, **kwargs) + cls.assert_series_equal(left[col], right[col], *args, **kwargs) left = left.drop(columns=decimals) right = right.drop(columns=decimals) diff --git a/setup.cfg b/setup.cfg index c298aa652824c..9be09ae1076bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -135,9 +135,6 @@ ignore_errors=True [mypy-pandas.tests.arithmetic.test_datetime64] ignore_errors=True -[mypy-pandas.tests.extension.decimal.test_decimal] -ignore_errors=True - [mypy-pandas.tests.extension.json.test_json] ignore_errors=True