Skip to content

Commit 9db9baa

Browse files
authored
REF/TST: remove overriding tm.assert_foo pattern (pandas-dev#54355)
* REF: remove overriding assert_extension_array_equal * REF/TST: remove overriding tm.assert_foo pattern * remove the methods * update usage
1 parent c93e803 commit 9db9baa

25 files changed

+191
-217
lines changed

.pre-commit-config.yaml

-9
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,6 @@ repos:
256256
|default_rng\(\)
257257
files: ^pandas/tests/
258258
types_or: [python, cython, rst]
259-
- id: unwanted-patterns-in-ea-tests
260-
name: Unwanted patterns in EA tests
261-
language: pygrep
262-
entry: |
263-
(?x)
264-
tm.assert_(series|frame)_equal
265-
files: ^pandas/tests/extension/base/
266-
exclude: ^pandas/tests/extension/base/base\.py$
267-
types_or: [python, cython, rst]
268259
- id: unwanted-patterns-in-cython
269260
name: Unwanted patterns in Cython code
270261
language: pygrep

pandas/tests/arrays/integer/test_comparison.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import pandas as pd
4+
import pandas._testing as tm
45
from pandas.tests.arrays.masked_shared import (
56
ComparisonOps,
67
NumericOps,
@@ -25,7 +26,7 @@ def test_compare_to_int(self, dtype, comparison_op):
2526
expected = method(2).astype("boolean")
2627
expected[s2.isna()] = pd.NA
2728

28-
self.assert_series_equal(result, expected)
29+
tm.assert_series_equal(result, expected)
2930

3031

3132
def test_equals():

pandas/tests/arrays/masked_shared.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_compare_to_string(self, dtype):
110110
result = ser == "a"
111111
expected = pd.Series([False, pd.NA], dtype="boolean")
112112

113-
self.assert_series_equal(result, expected)
113+
tm.assert_series_equal(result, expected)
114114

115115
def test_ufunc_with_out(self, dtype):
116116
arr = pd.array([1, 2, 3], dtype=dtype)

pandas/tests/extension/base/__init__.py

-7
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,6 @@ class TestMyDtype(BaseDtypeTests):
3333
``BaseDtypeTests``. pytest's fixture discover will supply your ``dtype``
3434
wherever the test requires it. You're free to implement additional tests.
3535
36-
All the tests in these modules use ``self.assert_frame_equal`` or
37-
``self.assert_series_equal`` for dataframe or series comparisons. By default,
38-
they use the usual ``pandas.testing.assert_frame_equal`` and
39-
``pandas.testing.assert_series_equal``. You can override the checks used
40-
by defining the staticmethods ``assert_frame_equal`` and
41-
``assert_series_equal`` on your base test class.
42-
4336
"""
4437
from pandas.tests.extension.base.accumulate import BaseAccumulateTests # noqa: F401
4538
from pandas.tests.extension.base.casting import BaseCastingTests # noqa: F401

pandas/tests/extension/base/accumulate.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import pandas as pd
4+
import pandas._testing as tm
45
from pandas.tests.extension.base.base import BaseExtensionTests
56

67

@@ -20,7 +21,7 @@ def check_accumulate(self, s, op_name, skipna):
2021
)
2122

2223
expected = getattr(s.astype("float64"), op_name)(skipna=skipna)
23-
self.assert_series_equal(result, expected, check_dtype=False)
24+
tm.assert_series_equal(result, expected, check_dtype=False)
2425

2526
@pytest.mark.parametrize("skipna", [True, False])
2627
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):

pandas/tests/extension/base/base.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,2 @@
1-
import pandas._testing as tm
2-
3-
41
class BaseExtensionTests:
5-
# classmethod and different signature is needed
6-
# to make inheritance compliant with mypy
7-
@classmethod
8-
def assert_equal(cls, left, right, **kwargs):
9-
return tm.assert_equal(left, right, **kwargs)
10-
11-
@classmethod
12-
def assert_series_equal(cls, left, right, *args, **kwargs):
13-
return tm.assert_series_equal(left, right, *args, **kwargs)
14-
15-
@classmethod
16-
def assert_frame_equal(cls, left, right, *args, **kwargs):
17-
return tm.assert_frame_equal(left, right, *args, **kwargs)
2+
pass

pandas/tests/extension/base/casting.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_tolist(self, data):
4646
def test_astype_str(self, data):
4747
result = pd.Series(data[:5]).astype(str)
4848
expected = pd.Series([str(x) for x in data[:5]], dtype=str)
49-
self.assert_series_equal(result, expected)
49+
tm.assert_series_equal(result, expected)
5050

5151
@pytest.mark.parametrize(
5252
"nullable_string_dtype",
@@ -62,22 +62,22 @@ def test_astype_string(self, data, nullable_string_dtype):
6262
[str(x) if not isinstance(x, bytes) else x.decode() for x in data[:5]],
6363
dtype=nullable_string_dtype,
6464
)
65-
self.assert_series_equal(result, expected)
65+
tm.assert_series_equal(result, expected)
6666

6767
def test_to_numpy(self, data):
6868
expected = np.asarray(data)
6969

7070
result = data.to_numpy()
71-
self.assert_equal(result, expected)
71+
tm.assert_equal(result, expected)
7272

7373
result = pd.Series(data).to_numpy()
74-
self.assert_equal(result, expected)
74+
tm.assert_equal(result, expected)
7575

7676
def test_astype_empty_dataframe(self, dtype):
7777
# https://github.com/pandas-dev/pandas/issues/33113
7878
df = pd.DataFrame()
7979
result = df.astype(dtype)
80-
self.assert_frame_equal(result, df)
80+
tm.assert_frame_equal(result, df)
8181

8282
@pytest.mark.parametrize("copy", [True, False])
8383
def test_astype_own_type(self, data, copy):

pandas/tests/extension/base/constructors.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,27 @@ def test_series_constructor(self, data):
3939
def test_series_constructor_no_data_with_index(self, dtype, na_value):
4040
result = pd.Series(index=[1, 2, 3], dtype=dtype)
4141
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
42-
self.assert_series_equal(result, expected)
42+
tm.assert_series_equal(result, expected)
4343

4444
# GH 33559 - empty index
4545
result = pd.Series(index=[], dtype=dtype)
4646
expected = pd.Series([], index=pd.Index([], dtype="object"), dtype=dtype)
47-
self.assert_series_equal(result, expected)
47+
tm.assert_series_equal(result, expected)
4848

4949
def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
5050
result = pd.Series(na_value, index=[1, 2, 3], dtype=dtype)
5151
expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype)
52-
self.assert_series_equal(result, expected)
52+
tm.assert_series_equal(result, expected)
5353

5454
def test_series_constructor_scalar_with_index(self, data, dtype):
5555
scalar = data[0]
5656
result = pd.Series(scalar, index=[1, 2, 3], dtype=dtype)
5757
expected = pd.Series([scalar] * 3, index=[1, 2, 3], dtype=dtype)
58-
self.assert_series_equal(result, expected)
58+
tm.assert_series_equal(result, expected)
5959

6060
result = pd.Series(scalar, index=["foo"], dtype=dtype)
6161
expected = pd.Series([scalar], index=["foo"], dtype=dtype)
62-
self.assert_series_equal(result, expected)
62+
tm.assert_series_equal(result, expected)
6363

6464
@pytest.mark.parametrize("from_series", [True, False])
6565
def test_dataframe_constructor_from_dict(self, data, from_series):
@@ -91,19 +91,19 @@ def test_from_dtype(self, data):
9191

9292
expected = pd.Series(data)
9393
result = pd.Series(list(data), dtype=dtype)
94-
self.assert_series_equal(result, expected)
94+
tm.assert_series_equal(result, expected)
9595

9696
result = pd.Series(list(data), dtype=str(dtype))
97-
self.assert_series_equal(result, expected)
97+
tm.assert_series_equal(result, expected)
9898

9999
# gh-30280
100100

101101
expected = pd.DataFrame(data).astype(dtype)
102102
result = pd.DataFrame(list(data), dtype=dtype)
103-
self.assert_frame_equal(result, expected)
103+
tm.assert_frame_equal(result, expected)
104104

105105
result = pd.DataFrame(list(data), dtype=str(dtype))
106-
self.assert_frame_equal(result, expected)
106+
tm.assert_frame_equal(result, expected)
107107

108108
def test_pandas_array(self, data):
109109
# pd.array(extension_array) should be idempotent...
@@ -114,15 +114,15 @@ def test_pandas_array_dtype(self, data):
114114
# ... but specifying dtype will override idempotency
115115
result = pd.array(data, dtype=np.dtype(object))
116116
expected = pd.arrays.NumpyExtensionArray(np.asarray(data, dtype=object))
117-
self.assert_equal(result, expected)
117+
tm.assert_equal(result, expected)
118118

119119
def test_construct_empty_dataframe(self, dtype):
120120
# GH 33623
121121
result = pd.DataFrame(columns=["a"], dtype=dtype)
122122
expected = pd.DataFrame(
123123
{"a": pd.array([], dtype=dtype)}, index=pd.RangeIndex(0)
124124
)
125-
self.assert_frame_equal(result, expected)
125+
tm.assert_frame_equal(result, expected)
126126

127127
def test_empty(self, dtype):
128128
cls = dtype.construct_array_type()

pandas/tests/extension/base/dim2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_frame_from_2d_array(self, data):
3333

3434
df = pd.DataFrame(arr2d)
3535
expected = pd.DataFrame({0: arr2d[:, 0], 1: arr2d[:, 1]})
36-
self.assert_frame_equal(df, expected)
36+
tm.assert_frame_equal(df, expected)
3737

3838
def test_swapaxes(self, data):
3939
arr2d = data.repeat(2).reshape(-1, 2)

pandas/tests/extension/base/dtype.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
import pandas as pd
5+
import pandas._testing as tm
56
from pandas.api.types import (
67
infer_dtype,
78
is_object_dtype,
@@ -66,11 +67,11 @@ def test_check_dtype(self, data):
6667

6768
expected = pd.Series([True, True, False, False], index=list("ABCD"))
6869

69-
self.assert_series_equal(result, expected)
70+
tm.assert_series_equal(result, expected)
7071

7172
expected = pd.Series([True, True, False, False], index=list("ABCD"))
7273
result = df.dtypes.apply(str) == str(dtype)
73-
self.assert_series_equal(result, expected)
74+
tm.assert_series_equal(result, expected)
7475

7576
def test_hashable(self, dtype):
7677
hash(dtype) # no error

0 commit comments

Comments
 (0)