Skip to content

Commit bd5ce2a

Browse files
authored
REF: remove overriding assert_extension_array_equal (pandas-dev#54337)
1 parent 7cde20b commit bd5ce2a

16 files changed

+67
-66
lines changed

pandas/tests/extension/base/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,3 @@ def assert_series_equal(cls, left, right, *args, **kwargs):
1515
@classmethod
1616
def assert_frame_equal(cls, left, right, *args, **kwargs):
1717
return tm.assert_frame_equal(left, right, *args, **kwargs)
18-
19-
@classmethod
20-
def assert_extension_array_equal(cls, left, right, *args, **kwargs):
21-
return tm.assert_extension_array_equal(left, right, *args, **kwargs)

pandas/tests/extension/base/casting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas.util._test_decorators as td
55

66
import pandas as pd
7+
import pandas._testing as tm
78
from pandas.core.internals.blocks import NumpyBlock
89
from pandas.tests.extension.base.base import BaseExtensionTests
910

@@ -84,4 +85,4 @@ def test_astype_own_type(self, data, copy):
8485
# https://github.com/pandas-dev/pandas/issues/28488
8586
result = data.astype(data.dtype, copy=copy)
8687
assert (result is data) is (not copy)
87-
self.assert_extension_array_equal(result, data)
88+
tm.assert_extension_array_equal(result, data)

pandas/tests/extension/base/constructors.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
import pandas as pd
5+
import pandas._testing as tm
56
from pandas.api.extensions import ExtensionArray
67
from pandas.core.internals.blocks import EABackedBlock
78
from pandas.tests.extension.base.base import BaseExtensionTests
@@ -10,11 +11,11 @@
1011
class BaseConstructorsTests(BaseExtensionTests):
1112
def test_from_sequence_from_cls(self, data):
1213
result = type(data)._from_sequence(data, dtype=data.dtype)
13-
self.assert_extension_array_equal(result, data)
14+
tm.assert_extension_array_equal(result, data)
1415

1516
data = data[:0]
1617
result = type(data)._from_sequence(data, dtype=data.dtype)
17-
self.assert_extension_array_equal(result, data)
18+
tm.assert_extension_array_equal(result, data)
1819

1920
def test_array_from_scalars(self, data):
2021
scalars = [data[0], data[1], data[2]]
@@ -107,7 +108,7 @@ def test_from_dtype(self, data):
107108
def test_pandas_array(self, data):
108109
# pd.array(extension_array) should be idempotent...
109110
result = pd.array(data)
110-
self.assert_extension_array_equal(result, data)
111+
tm.assert_extension_array_equal(result, data)
111112

112113
def test_pandas_array_dtype(self, data):
113114
# ... but specifying dtype will override idempotency

pandas/tests/extension/base/dim2.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313

1414
import pandas as pd
15+
import pandas._testing as tm
1516
from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE
1617
from pandas.tests.extension.base.base import BaseExtensionTests
1718

@@ -39,28 +40,28 @@ def test_swapaxes(self, data):
3940

4041
result = arr2d.swapaxes(0, 1)
4142
expected = arr2d.T
42-
self.assert_extension_array_equal(result, expected)
43+
tm.assert_extension_array_equal(result, expected)
4344

4445
def test_delete_2d(self, data):
4546
arr2d = data.repeat(3).reshape(-1, 3)
4647

4748
# axis = 0
4849
result = arr2d.delete(1, axis=0)
4950
expected = data.delete(1).repeat(3).reshape(-1, 3)
50-
self.assert_extension_array_equal(result, expected)
51+
tm.assert_extension_array_equal(result, expected)
5152

5253
# axis = 1
5354
result = arr2d.delete(1, axis=1)
5455
expected = data.repeat(2).reshape(-1, 2)
55-
self.assert_extension_array_equal(result, expected)
56+
tm.assert_extension_array_equal(result, expected)
5657

5758
def test_take_2d(self, data):
5859
arr2d = data.reshape(-1, 1)
5960

6061
result = arr2d.take([0, 0, -1], axis=0)
6162

6263
expected = data.take([0, 0, -1]).reshape(-1, 1)
63-
self.assert_extension_array_equal(result, expected)
64+
tm.assert_extension_array_equal(result, expected)
6465

6566
def test_repr_2d(self, data):
6667
# this could fail in a corner case where an element contained the name
@@ -88,7 +89,7 @@ def test_getitem_2d(self, data):
8889
arr2d = data.reshape(1, -1)
8990

9091
result = arr2d[0]
91-
self.assert_extension_array_equal(result, data)
92+
tm.assert_extension_array_equal(result, data)
9293

9394
with pytest.raises(IndexError):
9495
arr2d[1]
@@ -97,18 +98,18 @@ def test_getitem_2d(self, data):
9798
arr2d[-2]
9899

99100
result = arr2d[:]
100-
self.assert_extension_array_equal(result, arr2d)
101+
tm.assert_extension_array_equal(result, arr2d)
101102

102103
result = arr2d[:, :]
103-
self.assert_extension_array_equal(result, arr2d)
104+
tm.assert_extension_array_equal(result, arr2d)
104105

105106
result = arr2d[:, 0]
106107
expected = data[[0]]
107-
self.assert_extension_array_equal(result, expected)
108+
tm.assert_extension_array_equal(result, expected)
108109

109110
# dimension-expanding getitem on 1D
110111
result = data[:, np.newaxis]
111-
self.assert_extension_array_equal(result, arr2d.T)
112+
tm.assert_extension_array_equal(result, arr2d.T)
112113

113114
def test_iter_2d(self, data):
114115
arr2d = data.reshape(1, -1)
@@ -140,13 +141,13 @@ def test_concat_2d(self, data):
140141
# axis=0
141142
result = left._concat_same_type([left, right], axis=0)
142143
expected = data._concat_same_type([data] * 4).reshape(-1, 2)
143-
self.assert_extension_array_equal(result, expected)
144+
tm.assert_extension_array_equal(result, expected)
144145

145146
# axis=1
146147
result = left._concat_same_type([left, right], axis=1)
147148
assert result.shape == (len(data), 4)
148-
self.assert_extension_array_equal(result[:, :2], left)
149-
self.assert_extension_array_equal(result[:, 2:], right)
149+
tm.assert_extension_array_equal(result[:, :2], left)
150+
tm.assert_extension_array_equal(result[:, 2:], right)
150151

151152
# axis > 1 -> invalid
152153
msg = "axis 2 is out of bounds for array of dimension 2"
@@ -163,7 +164,7 @@ def test_fillna_2d_method(self, data_missing, method):
163164
result = arr.pad_or_backfill(method=method, limit=None)
164165

165166
expected = data_missing.pad_or_backfill(method=method).repeat(2).reshape(2, 2)
166-
self.assert_extension_array_equal(result, expected)
167+
tm.assert_extension_array_equal(result, expected)
167168

168169
# Reverse so that backfill is not a no-op.
169170
arr2 = arr[::-1]
@@ -175,7 +176,7 @@ def test_fillna_2d_method(self, data_missing, method):
175176
expected2 = (
176177
data_missing[::-1].pad_or_backfill(method=method).repeat(2).reshape(2, 2)
177178
)
178-
self.assert_extension_array_equal(result2, expected2)
179+
tm.assert_extension_array_equal(result2, expected2)
179180

180181
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
181182
def test_reductions_2d_axis_none(self, data, method):
@@ -251,18 +252,18 @@ def get_reduction_result_dtype(dtype):
251252
fill_value = 1 if method == "prod" else 0
252253
expected = expected.fillna(fill_value)
253254

254-
self.assert_extension_array_equal(result, expected)
255+
tm.assert_extension_array_equal(result, expected)
255256
elif method == "median":
256257
# std and var are not dtype-preserving
257258
expected = data
258-
self.assert_extension_array_equal(result, expected)
259+
tm.assert_extension_array_equal(result, expected)
259260
elif method in ["mean", "std", "var"]:
260261
if is_integer_dtype(data) or is_bool_dtype(data):
261262
data = data.astype("Float64")
262263
if method == "mean":
263-
self.assert_extension_array_equal(result, data)
264+
tm.assert_extension_array_equal(result, data)
264265
else:
265-
self.assert_extension_array_equal(result, data - data)
266+
tm.assert_extension_array_equal(result, data - data)
266267

267268
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
268269
def test_reductions_2d_axis1(self, data, method):

pandas/tests/extension/base/getitem.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_getitem_empty(self, data):
160160
assert isinstance(result, type(data))
161161

162162
expected = data[np.array([], dtype="int64")]
163-
self.assert_extension_array_equal(result, expected)
163+
tm.assert_extension_array_equal(result, expected)
164164

165165
def test_getitem_mask(self, data):
166166
# Empty mask, raw array
@@ -209,7 +209,7 @@ def test_getitem_boolean_array_mask(self, data):
209209
mask[:5] = True
210210
expected = data.take([0, 1, 2, 3, 4])
211211
result = data[mask]
212-
self.assert_extension_array_equal(result, expected)
212+
tm.assert_extension_array_equal(result, expected)
213213

214214
expected = pd.Series(expected)
215215
result = pd.Series(data)[mask]
@@ -224,7 +224,7 @@ def test_getitem_boolean_na_treated_as_false(self, data):
224224
result = data[mask]
225225
expected = data[mask.fillna(False)]
226226

227-
self.assert_extension_array_equal(result, expected)
227+
tm.assert_extension_array_equal(result, expected)
228228

229229
s = pd.Series(data)
230230

@@ -243,7 +243,7 @@ def test_getitem_integer_array(self, data, idx):
243243
assert len(result) == 3
244244
assert isinstance(result, type(data))
245245
expected = data.take([0, 1, 2])
246-
self.assert_extension_array_equal(result, expected)
246+
tm.assert_extension_array_equal(result, expected)
247247

248248
expected = pd.Series(expected)
249249
result = pd.Series(data)[idx]
@@ -287,22 +287,22 @@ def test_getitem_slice(self, data):
287287
def test_getitem_ellipsis_and_slice(self, data):
288288
# GH#40353 this is called from slice_block_rows
289289
result = data[..., :]
290-
self.assert_extension_array_equal(result, data)
290+
tm.assert_extension_array_equal(result, data)
291291

292292
result = data[:, ...]
293-
self.assert_extension_array_equal(result, data)
293+
tm.assert_extension_array_equal(result, data)
294294

295295
result = data[..., :3]
296-
self.assert_extension_array_equal(result, data[:3])
296+
tm.assert_extension_array_equal(result, data[:3])
297297

298298
result = data[:3, ...]
299-
self.assert_extension_array_equal(result, data[:3])
299+
tm.assert_extension_array_equal(result, data[:3])
300300

301301
result = data[..., ::2]
302-
self.assert_extension_array_equal(result, data[::2])
302+
tm.assert_extension_array_equal(result, data[::2])
303303

304304
result = data[::2, ...]
305-
self.assert_extension_array_equal(result, data[::2])
305+
tm.assert_extension_array_equal(result, data[::2])
306306

307307
def test_get(self, data):
308308
# GH 20882
@@ -381,7 +381,7 @@ def test_take_negative(self, data):
381381
n = len(data)
382382
result = data.take([0, -n, n - 1, -1])
383383
expected = data.take([0, 0, n - 1, n - 1])
384-
self.assert_extension_array_equal(result, expected)
384+
tm.assert_extension_array_equal(result, expected)
385385

386386
def test_take_non_na_fill_value(self, data_missing):
387387
fill_value = data_missing[1] # valid
@@ -392,7 +392,7 @@ def test_take_non_na_fill_value(self, data_missing):
392392
)
393393
result = arr.take([-1, 1], fill_value=fill_value, allow_fill=True)
394394
expected = arr.take([1, 1])
395-
self.assert_extension_array_equal(result, expected)
395+
tm.assert_extension_array_equal(result, expected)
396396

397397
def test_take_pandas_style_negative_raises(self, data, na_value):
398398
with pytest.raises(ValueError, match=""):

pandas/tests/extension/base/methods.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,14 @@ def test_factorize(self, data_for_grouping):
263263
expected_uniques = data_for_grouping.take([0, 4, 7])
264264

265265
tm.assert_numpy_array_equal(codes, expected_codes)
266-
self.assert_extension_array_equal(uniques, expected_uniques)
266+
tm.assert_extension_array_equal(uniques, expected_uniques)
267267

268268
def test_factorize_equivalence(self, data_for_grouping):
269269
codes_1, uniques_1 = pd.factorize(data_for_grouping, use_na_sentinel=True)
270270
codes_2, uniques_2 = data_for_grouping.factorize(use_na_sentinel=True)
271271

272272
tm.assert_numpy_array_equal(codes_1, codes_2)
273-
self.assert_extension_array_equal(uniques_1, uniques_2)
273+
tm.assert_extension_array_equal(uniques_1, uniques_2)
274274
assert len(uniques_1) == len(pd.unique(uniques_1))
275275
assert uniques_1.dtype == data_for_grouping.dtype
276276

@@ -280,7 +280,7 @@ def test_factorize_empty(self, data):
280280
expected_uniques = type(data)._from_sequence([], dtype=data[:0].dtype)
281281

282282
tm.assert_numpy_array_equal(codes, expected_codes)
283-
self.assert_extension_array_equal(uniques, expected_uniques)
283+
tm.assert_extension_array_equal(uniques, expected_uniques)
284284

285285
def test_fillna_copy_frame(self, data_missing):
286286
arr = data_missing.take([1, 1])
@@ -428,15 +428,15 @@ def test_shift_non_empty_array(self, data, periods, indices):
428428
subset = data[:2]
429429
result = subset.shift(periods)
430430
expected = subset.take(indices, allow_fill=True)
431-
self.assert_extension_array_equal(result, expected)
431+
tm.assert_extension_array_equal(result, expected)
432432

433433
@pytest.mark.parametrize("periods", [-4, -1, 0, 1, 4])
434434
def test_shift_empty_array(self, data, periods):
435435
# https://github.com/pandas-dev/pandas/issues/23911
436436
empty = data[:0]
437437
result = empty.shift(periods)
438438
expected = empty
439-
self.assert_extension_array_equal(result, expected)
439+
tm.assert_extension_array_equal(result, expected)
440440

441441
def test_shift_zero_copies(self, data):
442442
# GH#31502
@@ -451,11 +451,11 @@ def test_shift_fill_value(self, data):
451451
fill_value = data[0]
452452
result = arr.shift(1, fill_value=fill_value)
453453
expected = data.take([0, 0, 1, 2])
454-
self.assert_extension_array_equal(result, expected)
454+
tm.assert_extension_array_equal(result, expected)
455455

456456
result = arr.shift(-2, fill_value=fill_value)
457457
expected = data.take([2, 3, 0, 0])
458-
self.assert_extension_array_equal(result, expected)
458+
tm.assert_extension_array_equal(result, expected)
459459

460460
def test_not_hashable(self, data):
461461
# We are in general mutable, so not hashable
@@ -602,19 +602,19 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
602602
def test_delete(self, data):
603603
result = data.delete(0)
604604
expected = data[1:]
605-
self.assert_extension_array_equal(result, expected)
605+
tm.assert_extension_array_equal(result, expected)
606606

607607
result = data.delete([1, 3])
608608
expected = data._concat_same_type([data[[0]], data[[2]], data[4:]])
609-
self.assert_extension_array_equal(result, expected)
609+
tm.assert_extension_array_equal(result, expected)
610610

611611
def test_insert(self, data):
612612
# insert at the beginning
613613
result = data[1:].insert(0, data[0])
614-
self.assert_extension_array_equal(result, data)
614+
tm.assert_extension_array_equal(result, data)
615615

616616
result = data[1:].insert(-len(data[1:]), data[0])
617-
self.assert_extension_array_equal(result, data)
617+
tm.assert_extension_array_equal(result, data)
618618

619619
# insert at the middle
620620
result = data[:-1].insert(4, data[-1])
@@ -623,7 +623,7 @@ def test_insert(self, data):
623623
taker[5:] = taker[4:-1]
624624
taker[4] = len(data) - 1
625625
expected = data.take(taker)
626-
self.assert_extension_array_equal(result, expected)
626+
tm.assert_extension_array_equal(result, expected)
627627

628628
def test_insert_invalid(self, data, invalid_scalar):
629629
item = invalid_scalar

pandas/tests/extension/base/missing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_isna_returns_copy(self, data_missing, na_func):
3636
def test_dropna_array(self, data_missing):
3737
result = data_missing.dropna()
3838
expected = data_missing[[1]]
39-
self.assert_extension_array_equal(result, expected)
39+
tm.assert_extension_array_equal(result, expected)
4040

4141
def test_dropna_series(self, data_missing):
4242
ser = pd.Series(data_missing)
@@ -67,7 +67,7 @@ def test_fillna_scalar(self, data_missing):
6767
valid = data_missing[1]
6868
result = data_missing.fillna(valid)
6969
expected = data_missing.fillna(valid)
70-
self.assert_extension_array_equal(result, expected)
70+
tm.assert_extension_array_equal(result, expected)
7171

7272
@pytest.mark.filterwarnings(
7373
"ignore:Series.fillna with 'method' is deprecated:FutureWarning"
@@ -93,11 +93,11 @@ def test_fillna_no_op_returns_copy(self, data):
9393
valid = data[0]
9494
result = data.fillna(valid)
9595
assert result is not data
96-
self.assert_extension_array_equal(result, data)
96+
tm.assert_extension_array_equal(result, data)
9797

9898
result = data.pad_or_backfill(method="backfill")
9999
assert result is not data
100-
self.assert_extension_array_equal(result, data)
100+
tm.assert_extension_array_equal(result, data)
101101

102102
def test_fillna_series(self, data_missing):
103103
fill_value = data_missing[1]

pandas/tests/extension/base/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,4 @@ def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
213213
ufunc(data)
214214
else:
215215
alt = ufunc(data)
216-
self.assert_extension_array_equal(result, alt)
216+
tm.assert_extension_array_equal(result, alt)

0 commit comments

Comments
 (0)