Skip to content

REF: remove overriding assert_extension_array_equal #54337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pandas/tests/extension/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,3 @@ def assert_series_equal(cls, left, right, *args, **kwargs):
@classmethod
def assert_frame_equal(cls, left, right, *args, **kwargs):
return tm.assert_frame_equal(left, right, *args, **kwargs)

@classmethod
def assert_extension_array_equal(cls, left, right, *args, **kwargs):
return tm.assert_extension_array_equal(left, right, *args, **kwargs)
3 changes: 2 additions & 1 deletion pandas/tests/extension/base/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas.util._test_decorators as td

import pandas as pd
import pandas._testing as tm
from pandas.core.internals.blocks import NumpyBlock
from pandas.tests.extension.base.base import BaseExtensionTests

Expand Down Expand Up @@ -84,4 +85,4 @@ def test_astype_own_type(self, data, copy):
# https://github.com/pandas-dev/pandas/issues/28488
result = data.astype(data.dtype, copy=copy)
assert (result is data) is (not copy)
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)
7 changes: 4 additions & 3 deletions pandas/tests/extension/base/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.api.extensions import ExtensionArray
from pandas.core.internals.blocks import EABackedBlock
from pandas.tests.extension.base.base import BaseExtensionTests
Expand All @@ -10,11 +11,11 @@
class BaseConstructorsTests(BaseExtensionTests):
def test_from_sequence_from_cls(self, data):
result = type(data)._from_sequence(data, dtype=data.dtype)
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

data = data[:0]
result = type(data)._from_sequence(data, dtype=data.dtype)
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

def test_array_from_scalars(self, data):
scalars = [data[0], data[1], data[2]]
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_from_dtype(self, data):
def test_pandas_array(self, data):
# pd.array(extension_array) should be idempotent...
result = pd.array(data)
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

def test_pandas_array_dtype(self, data):
# ... but specifying dtype will override idempotency
Expand Down
37 changes: 19 additions & 18 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE
from pandas.tests.extension.base.base import BaseExtensionTests

Expand Down Expand Up @@ -39,28 +40,28 @@ def test_swapaxes(self, data):

result = arr2d.swapaxes(0, 1)
expected = arr2d.T
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_delete_2d(self, data):
arr2d = data.repeat(3).reshape(-1, 3)

# axis = 0
result = arr2d.delete(1, axis=0)
expected = data.delete(1).repeat(3).reshape(-1, 3)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

# axis = 1
result = arr2d.delete(1, axis=1)
expected = data.repeat(2).reshape(-1, 2)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_take_2d(self, data):
arr2d = data.reshape(-1, 1)

result = arr2d.take([0, 0, -1], axis=0)

expected = data.take([0, 0, -1]).reshape(-1, 1)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_repr_2d(self, data):
# this could fail in a corner case where an element contained the name
Expand Down Expand Up @@ -88,7 +89,7 @@ def test_getitem_2d(self, data):
arr2d = data.reshape(1, -1)

result = arr2d[0]
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

with pytest.raises(IndexError):
arr2d[1]
Expand All @@ -97,18 +98,18 @@ def test_getitem_2d(self, data):
arr2d[-2]

result = arr2d[:]
self.assert_extension_array_equal(result, arr2d)
tm.assert_extension_array_equal(result, arr2d)

result = arr2d[:, :]
self.assert_extension_array_equal(result, arr2d)
tm.assert_extension_array_equal(result, arr2d)

result = arr2d[:, 0]
expected = data[[0]]
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

# dimension-expanding getitem on 1D
result = data[:, np.newaxis]
self.assert_extension_array_equal(result, arr2d.T)
tm.assert_extension_array_equal(result, arr2d.T)

def test_iter_2d(self, data):
arr2d = data.reshape(1, -1)
Expand Down Expand Up @@ -140,13 +141,13 @@ def test_concat_2d(self, data):
# axis=0
result = left._concat_same_type([left, right], axis=0)
expected = data._concat_same_type([data] * 4).reshape(-1, 2)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

# axis=1
result = left._concat_same_type([left, right], axis=1)
assert result.shape == (len(data), 4)
self.assert_extension_array_equal(result[:, :2], left)
self.assert_extension_array_equal(result[:, 2:], right)
tm.assert_extension_array_equal(result[:, :2], left)
tm.assert_extension_array_equal(result[:, 2:], right)

# axis > 1 -> invalid
msg = "axis 2 is out of bounds for array of dimension 2"
Expand All @@ -163,7 +164,7 @@ def test_fillna_2d_method(self, data_missing, method):
result = arr.pad_or_backfill(method=method, limit=None)

expected = data_missing.pad_or_backfill(method=method).repeat(2).reshape(2, 2)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

# Reverse so that backfill is not a no-op.
arr2 = arr[::-1]
Expand All @@ -175,7 +176,7 @@ def test_fillna_2d_method(self, data_missing, method):
expected2 = (
data_missing[::-1].pad_or_backfill(method=method).repeat(2).reshape(2, 2)
)
self.assert_extension_array_equal(result2, expected2)
tm.assert_extension_array_equal(result2, expected2)

@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
def test_reductions_2d_axis_none(self, data, method):
Expand Down Expand Up @@ -251,18 +252,18 @@ def get_reduction_result_dtype(dtype):
fill_value = 1 if method == "prod" else 0
expected = expected.fillna(fill_value)

self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)
elif method == "median":
# std and var are not dtype-preserving
expected = data
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)
elif method in ["mean", "std", "var"]:
if is_integer_dtype(data) or is_bool_dtype(data):
data = data.astype("Float64")
if method == "mean":
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)
else:
self.assert_extension_array_equal(result, data - data)
tm.assert_extension_array_equal(result, data - data)

@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
def test_reductions_2d_axis1(self, data, method):
Expand Down
24 changes: 12 additions & 12 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_getitem_empty(self, data):
assert isinstance(result, type(data))

expected = data[np.array([], dtype="int64")]
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_getitem_mask(self, data):
# Empty mask, raw array
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_getitem_boolean_array_mask(self, data):
mask[:5] = True
expected = data.take([0, 1, 2, 3, 4])
result = data[mask]
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

expected = pd.Series(expected)
result = pd.Series(data)[mask]
Expand All @@ -224,7 +224,7 @@ def test_getitem_boolean_na_treated_as_false(self, data):
result = data[mask]
expected = data[mask.fillna(False)]

self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

s = pd.Series(data)

Expand All @@ -243,7 +243,7 @@ def test_getitem_integer_array(self, data, idx):
assert len(result) == 3
assert isinstance(result, type(data))
expected = data.take([0, 1, 2])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

expected = pd.Series(expected)
result = pd.Series(data)[idx]
Expand Down Expand Up @@ -287,22 +287,22 @@ def test_getitem_slice(self, data):
def test_getitem_ellipsis_and_slice(self, data):
# GH#40353 this is called from slice_block_rows
result = data[..., :]
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

result = data[:, ...]
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

result = data[..., :3]
self.assert_extension_array_equal(result, data[:3])
tm.assert_extension_array_equal(result, data[:3])

result = data[:3, ...]
self.assert_extension_array_equal(result, data[:3])
tm.assert_extension_array_equal(result, data[:3])

result = data[..., ::2]
self.assert_extension_array_equal(result, data[::2])
tm.assert_extension_array_equal(result, data[::2])

result = data[::2, ...]
self.assert_extension_array_equal(result, data[::2])
tm.assert_extension_array_equal(result, data[::2])

def test_get(self, data):
# GH 20882
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_take_negative(self, data):
n = len(data)
result = data.take([0, -n, n - 1, -1])
expected = data.take([0, 0, n - 1, n - 1])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_take_non_na_fill_value(self, data_missing):
fill_value = data_missing[1] # valid
Expand All @@ -392,7 +392,7 @@ def test_take_non_na_fill_value(self, data_missing):
)
result = arr.take([-1, 1], fill_value=fill_value, allow_fill=True)
expected = arr.take([1, 1])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_take_pandas_style_negative_raises(self, data, na_value):
with pytest.raises(ValueError, match=""):
Expand Down
24 changes: 12 additions & 12 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,14 @@ def test_factorize(self, data_for_grouping):
expected_uniques = data_for_grouping.take([0, 4, 7])

tm.assert_numpy_array_equal(codes, expected_codes)
self.assert_extension_array_equal(uniques, expected_uniques)
tm.assert_extension_array_equal(uniques, expected_uniques)

def test_factorize_equivalence(self, data_for_grouping):
codes_1, uniques_1 = pd.factorize(data_for_grouping, use_na_sentinel=True)
codes_2, uniques_2 = data_for_grouping.factorize(use_na_sentinel=True)

tm.assert_numpy_array_equal(codes_1, codes_2)
self.assert_extension_array_equal(uniques_1, uniques_2)
tm.assert_extension_array_equal(uniques_1, uniques_2)
assert len(uniques_1) == len(pd.unique(uniques_1))
assert uniques_1.dtype == data_for_grouping.dtype

Expand All @@ -280,7 +280,7 @@ def test_factorize_empty(self, data):
expected_uniques = type(data)._from_sequence([], dtype=data[:0].dtype)

tm.assert_numpy_array_equal(codes, expected_codes)
self.assert_extension_array_equal(uniques, expected_uniques)
tm.assert_extension_array_equal(uniques, expected_uniques)

def test_fillna_copy_frame(self, data_missing):
arr = data_missing.take([1, 1])
Expand Down Expand Up @@ -428,15 +428,15 @@ def test_shift_non_empty_array(self, data, periods, indices):
subset = data[:2]
result = subset.shift(periods)
expected = subset.take(indices, allow_fill=True)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize("periods", [-4, -1, 0, 1, 4])
def test_shift_empty_array(self, data, periods):
# https://github.com/pandas-dev/pandas/issues/23911
empty = data[:0]
result = empty.shift(periods)
expected = empty
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_shift_zero_copies(self, data):
# GH#31502
Expand All @@ -451,11 +451,11 @@ def test_shift_fill_value(self, data):
fill_value = data[0]
result = arr.shift(1, fill_value=fill_value)
expected = data.take([0, 0, 1, 2])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

result = arr.shift(-2, fill_value=fill_value)
expected = data.take([2, 3, 0, 0])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_not_hashable(self, data):
# We are in general mutable, so not hashable
Expand Down Expand Up @@ -602,19 +602,19 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
def test_delete(self, data):
result = data.delete(0)
expected = data[1:]
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

result = data.delete([1, 3])
expected = data._concat_same_type([data[[0]], data[[2]], data[4:]])
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_insert(self, data):
# insert at the beginning
result = data[1:].insert(0, data[0])
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

result = data[1:].insert(-len(data[1:]), data[0])
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

# insert at the middle
result = data[:-1].insert(4, data[-1])
Expand All @@ -623,7 +623,7 @@ def test_insert(self, data):
taker[5:] = taker[4:-1]
taker[4] = len(data) - 1
expected = data.take(taker)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_insert_invalid(self, data, invalid_scalar):
item = invalid_scalar
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/extension/base/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_isna_returns_copy(self, data_missing, na_func):
def test_dropna_array(self, data_missing):
result = data_missing.dropna()
expected = data_missing[[1]]
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

def test_dropna_series(self, data_missing):
ser = pd.Series(data_missing)
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_fillna_scalar(self, data_missing):
valid = data_missing[1]
result = data_missing.fillna(valid)
expected = data_missing.fillna(valid)
self.assert_extension_array_equal(result, expected)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.filterwarnings(
"ignore:Series.fillna with 'method' is deprecated:FutureWarning"
Expand All @@ -93,11 +93,11 @@ def test_fillna_no_op_returns_copy(self, data):
valid = data[0]
result = data.fillna(valid)
assert result is not data
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

result = data.pad_or_backfill(method="backfill")
assert result is not data
self.assert_extension_array_equal(result, data)
tm.assert_extension_array_equal(result, data)

def test_fillna_series(self, data_missing):
fill_value = data_missing[1]
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,4 @@ def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
ufunc(data)
else:
alt = ufunc(data)
self.assert_extension_array_equal(result, alt)
tm.assert_extension_array_equal(result, alt)
Loading