Skip to content

Commit faf595e

Browse files
TomAugspurgerharisbal
authored and
harisbal
committed
REF: Base class for all extension tests (pandas-dev#19863)
1 parent fbc8d72 commit faf595e

File tree

11 files changed

+72
-36
lines changed

11 files changed

+72
-36
lines changed

ci/lint.sh

+9
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ if [ "$LINT" ]; then
111111
RET=1
112112
fi
113113

114+
# Check for the following code in the extension array base tests
115+
# tm.assert_frame_equal
116+
# tm.assert_series_equal
117+
grep -r -E --include '*.py' --exclude base.py 'tm.assert_(series|frame)_equal' pandas/tests/extension/base
118+
119+
if [ $? = "0" ]; then
120+
RET=1
121+
fi
122+
114123
echo "Check for invalid testing DONE"
115124

116125
# Check for imports from pandas.core.common instead

pandas/tests/extension/base/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class TestMyDtype(BaseDtypeTests):
3131
Your class ``TestDtype`` will inherit all the tests defined on
3232
``BaseDtypeTests``. pytest's fixture discover will supply your ``dtype``
3333
wherever the test requires it. You're free to implement additional tests.
34+
35+
All the tests in these modules use ``self.assert_frame_equal`` or
36+
``self.assert_series_equal`` for dataframe or series comparisons. By default,
37+
they use the usual ``pandas.testing.assert_frame_equal`` and
38+
``pandas.testing.assert_series_equal``. You can override the checks used
39+
by defining the staticmethods ``assert_frame_equal`` and
40+
``assert_series_equal`` on your base test class.
41+
3442
"""
3543
from .casting import BaseCastingTests # noqa
3644
from .constructors import BaseConstructorsTests # noqa

pandas/tests/extension/base/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import pandas.util.testing as tm
2+
3+
4+
class BaseExtensionTests(object):
5+
assert_series_equal = staticmethod(tm.assert_series_equal)
6+
assert_frame_equal = staticmethod(tm.assert_frame_equal)

pandas/tests/extension/base/casting.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pandas as pd
22
from pandas.core.internals import ObjectBlock
33

4+
from .base import BaseExtensionTests
45

5-
class BaseCastingTests(object):
6+
7+
class BaseCastingTests(BaseExtensionTests):
68
"""Casting to and from ExtensionDtypes"""
79

810
def test_astype_object_series(self, all_data):

pandas/tests/extension/base/constructors.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import pandas.util.testing as tm
55
from pandas.core.internals import ExtensionBlock
66

7+
from .base import BaseExtensionTests
78

8-
class BaseConstructorsTests(object):
9+
10+
class BaseConstructorsTests(BaseExtensionTests):
911

1012
def test_series_constructor(self, data):
1113
result = pd.Series(data)

pandas/tests/extension/base/dtype.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
22
import pandas as pd
33

4+
from .base import BaseExtensionTests
45

5-
class BaseDtypeTests(object):
6+
7+
class BaseDtypeTests(BaseExtensionTests):
68
"""Base class for ExtensionDtype classes"""
79

810
def test_name(self, dtype):

pandas/tests/extension/base/getitem.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,74 @@
11
import numpy as np
22

33
import pandas as pd
4-
import pandas.util.testing as tm
54

5+
from .base import BaseExtensionTests
66

7-
class BaseGetitemTests(object):
7+
8+
class BaseGetitemTests(BaseExtensionTests):
89
"""Tests for ExtensionArray.__getitem__."""
910

1011
def test_iloc_series(self, data):
1112
ser = pd.Series(data)
1213
result = ser.iloc[:4]
1314
expected = pd.Series(data[:4])
14-
tm.assert_series_equal(result, expected)
15+
self.assert_series_equal(result, expected)
1516

1617
result = ser.iloc[[0, 1, 2, 3]]
17-
tm.assert_series_equal(result, expected)
18+
self.assert_series_equal(result, expected)
1819

1920
def test_iloc_frame(self, data):
2021
df = pd.DataFrame({"A": data, 'B': np.arange(len(data))})
2122
expected = pd.DataFrame({"A": data[:4]})
2223

2324
# slice -> frame
2425
result = df.iloc[:4, [0]]
25-
tm.assert_frame_equal(result, expected)
26+
self.assert_frame_equal(result, expected)
2627

2728
# sequence -> frame
2829
result = df.iloc[[0, 1, 2, 3], [0]]
29-
tm.assert_frame_equal(result, expected)
30+
self.assert_frame_equal(result, expected)
3031

3132
expected = pd.Series(data[:4], name='A')
3233

3334
# slice -> series
3435
result = df.iloc[:4, 0]
35-
tm.assert_series_equal(result, expected)
36+
self.assert_series_equal(result, expected)
3637

3738
# sequence -> series
3839
result = df.iloc[:4, 0]
39-
tm.assert_series_equal(result, expected)
40+
self.assert_series_equal(result, expected)
4041

4142
def test_loc_series(self, data):
4243
ser = pd.Series(data)
4344
result = ser.loc[:3]
4445
expected = pd.Series(data[:4])
45-
tm.assert_series_equal(result, expected)
46+
self.assert_series_equal(result, expected)
4647

4748
result = ser.loc[[0, 1, 2, 3]]
48-
tm.assert_series_equal(result, expected)
49+
self.assert_series_equal(result, expected)
4950

5051
def test_loc_frame(self, data):
5152
df = pd.DataFrame({"A": data, 'B': np.arange(len(data))})
5253
expected = pd.DataFrame({"A": data[:4]})
5354

5455
# slice -> frame
5556
result = df.loc[:3, ['A']]
56-
tm.assert_frame_equal(result, expected)
57+
self.assert_frame_equal(result, expected)
5758

5859
# sequence -> frame
5960
result = df.loc[[0, 1, 2, 3], ['A']]
60-
tm.assert_frame_equal(result, expected)
61+
self.assert_frame_equal(result, expected)
6162

6263
expected = pd.Series(data[:4], name='A')
6364

6465
# slice -> series
6566
result = df.loc[:3, 'A']
66-
tm.assert_series_equal(result, expected)
67+
self.assert_series_equal(result, expected)
6768

6869
# sequence -> series
6970
result = df.loc[:3, 'A']
70-
tm.assert_series_equal(result, expected)
71+
self.assert_series_equal(result, expected)
7172

7273
def test_getitem_scalar(self, data):
7374
result = data[0]

pandas/tests/extension/base/interface.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from pandas.core.dtypes.common import is_extension_array_dtype
66
from pandas.core.dtypes.dtypes import ExtensionDtype
77

8+
from .base import BaseExtensionTests
89

9-
class BaseInterfaceTests(object):
10+
11+
class BaseInterfaceTests(BaseExtensionTests):
1012
"""Tests that the basic interface is satisfied."""
1113
# ------------------------------------------------------------------------
1214
# Interface

pandas/tests/extension/base/methods.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import numpy as np
33

44
import pandas as pd
5-
import pandas.util.testing as tm
65

6+
from .base import BaseExtensionTests
77

8-
class BaseMethodsTests(object):
8+
9+
class BaseMethodsTests(BaseExtensionTests):
910
"""Various Series and DataFrame methods."""
1011

1112
@pytest.mark.parametrize('dropna', [True, False])
@@ -19,13 +20,13 @@ def test_value_counts(self, all_data, dropna):
1920
result = pd.Series(all_data).value_counts(dropna=dropna).sort_index()
2021
expected = pd.Series(other).value_counts(dropna=dropna).sort_index()
2122

22-
tm.assert_series_equal(result, expected)
23+
self.assert_series_equal(result, expected)
2324

2425
def test_count(self, data_missing):
2526
df = pd.DataFrame({"A": data_missing})
2627
result = df.count(axis='columns')
2728
expected = pd.Series([0, 1])
28-
tm.assert_series_equal(result, expected)
29+
self.assert_series_equal(result, expected)
2930

3031
def test_apply_simple_series(self, data):
3132
result = pd.Series(data).apply(id)

pandas/tests/extension/base/missing.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import pandas as pd
44
import pandas.util.testing as tm
55

6+
from .base import BaseExtensionTests
67

7-
class BaseMissingTests(object):
8+
9+
class BaseMissingTests(BaseExtensionTests):
810
def test_isna(self, data_missing):
911
if data_missing._can_hold_na:
1012
expected = np.array([True, False])
@@ -16,30 +18,30 @@ def test_isna(self, data_missing):
1618

1719
result = pd.Series(data_missing).isna()
1820
expected = pd.Series(expected)
19-
tm.assert_series_equal(result, expected)
21+
self.assert_series_equal(result, expected)
2022

2123
def test_dropna_series(self, data_missing):
2224
ser = pd.Series(data_missing)
2325
result = ser.dropna()
2426
expected = ser.iloc[[1]]
25-
tm.assert_series_equal(result, expected)
27+
self.assert_series_equal(result, expected)
2628

2729
def test_dropna_frame(self, data_missing):
2830
df = pd.DataFrame({"A": data_missing})
2931

3032
# defaults
3133
result = df.dropna()
3234
expected = df.iloc[[1]]
33-
tm.assert_frame_equal(result, expected)
35+
self.assert_frame_equal(result, expected)
3436

3537
# axis = 1
3638
result = df.dropna(axis='columns')
3739
expected = pd.DataFrame(index=[0, 1])
38-
tm.assert_frame_equal(result, expected)
40+
self.assert_frame_equal(result, expected)
3941

4042
# multiple
4143
df = pd.DataFrame({"A": data_missing,
4244
"B": [1, np.nan]})
4345
result = df.dropna()
4446
expected = df.iloc[:0]
45-
tm.assert_frame_equal(result, expected)
47+
self.assert_frame_equal(result, expected)

pandas/tests/extension/base/reshaping.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import pytest
22

33
import pandas as pd
4-
import pandas.util.testing as tm
54
from pandas.core.internals import ExtensionBlock
65

6+
from .base import BaseExtensionTests
77

8-
class BaseReshapingTests(object):
8+
9+
class BaseReshapingTests(BaseExtensionTests):
910
"""Tests for reshaping and concatenation."""
1011
@pytest.mark.parametrize('in_frame', [True, False])
1112
def test_concat(self, data, in_frame):
@@ -32,8 +33,8 @@ def test_align(self, data, na_value):
3233
# Assumes that the ctor can take a list of scalars of the type
3334
e1 = pd.Series(type(data)(list(a) + [na_value]))
3435
e2 = pd.Series(type(data)([na_value] + list(b)))
35-
tm.assert_series_equal(r1, e1)
36-
tm.assert_series_equal(r2, e2)
36+
self.assert_series_equal(r1, e1)
37+
self.assert_series_equal(r2, e2)
3738

3839
def test_align_frame(self, data, na_value):
3940
a = data[:3]
@@ -45,17 +46,17 @@ def test_align_frame(self, data, na_value):
4546
# Assumes that the ctor can take a list of scalars of the type
4647
e1 = pd.DataFrame({'A': type(data)(list(a) + [na_value])})
4748
e2 = pd.DataFrame({'A': type(data)([na_value] + list(b))})
48-
tm.assert_frame_equal(r1, e1)
49-
tm.assert_frame_equal(r2, e2)
49+
self.assert_frame_equal(r1, e1)
50+
self.assert_frame_equal(r2, e2)
5051

5152
def test_set_frame_expand_regular_with_extension(self, data):
5253
df = pd.DataFrame({"A": [1] * len(data)})
5354
df['B'] = data
5455
expected = pd.DataFrame({"A": [1] * len(data), "B": data})
55-
tm.assert_frame_equal(df, expected)
56+
self.assert_frame_equal(df, expected)
5657

5758
def test_set_frame_expand_extension_with_regular(self, data):
5859
df = pd.DataFrame({'A': data})
5960
df['B'] = [1] * len(data)
6061
expected = pd.DataFrame({"A": data, "B": [1] * len(data)})
61-
tm.assert_frame_equal(df, expected)
62+
self.assert_frame_equal(df, expected)

0 commit comments

Comments
 (0)