REF: de-duplicate extension tests #54340

Merged 7 commits on Aug 1, 2023

12 changes: 11 additions & 1 deletion pandas/tests/extension/base/methods.py
@@ -65,7 +65,13 @@ def test_value_counts_with_normalize(self, data):
else:
expected = pd.Series(0.0, index=result.index, name="proportion")
expected[result > 0] = 1 / len(values)
if na_value_for_dtype(data.dtype) is pd.NA:

if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
data.dtype, pd.ArrowDtype
):
# TODO: avoid special-casing
expected = expected.astype("double[pyarrow]")
elif na_value_for_dtype(data.dtype) is pd.NA:
# TODO(GH#44692): avoid special-casing
expected = expected.astype("Float64")
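
The new branch above special-cases pyarrow-backed dtypes before falling back to the existing masked-dtype (pd.NA) branch. As a minimal sketch of which dtypes each check catches, assuming pandas 2.x with pyarrow installed (the variable names below are illustrative, not part of the PR):

import pandas as pd
import pyarrow as pa

arrow_dtype = pd.ArrowDtype(pa.int64())              # ArrowDtype -> caught by the isinstance check
string_pyarrow = pd.StringDtype(storage="pyarrow")   # pyarrow-backed StringDtype -> caught by the storage check
masked_dtype = pd.Int64Dtype()                       # masked dtype -> no "storage", falls through to the pd.NA branch

print(isinstance(arrow_dtype, pd.ArrowDtype))               # True
print(getattr(string_pyarrow, "storage", "") == "pyarrow")  # True
print(getattr(masked_dtype, "storage", "") == "pyarrow")    # False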

@@ -678,3 +684,7 @@ def test_equals(self, data, na_value, as_series, box):
# other types
assert data.equals(None) is False
assert data[[0]].equals(data[0]) is False

def test_equals_same_data_different_object(self, data):
# https://github.com/pandas-dev/pandas/issues/34660
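# two distinct Series objects constructed from the same data should compare equal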
assert pd.Series(data).equals(pd.Series(data))
8 changes: 8 additions & 0 deletions pandas/tests/extension/json/test_json.py
@@ -301,6 +301,14 @@ def test_equals(self, data, na_value, as_series):
def test_fillna_copy_frame(self, data_missing):
super().test_fillna_copy_frame(data_missing)

def test_equals_same_data_different_object(
self, data, using_copy_on_write, request
):
if using_copy_on_write:
mark = pytest.mark.xfail(reason="Fails with CoW")
request.node.add_marker(mark)
super().test_equals_same_data_different_object(data)


class TestCasting(BaseJSON, base.BaseCastingTests):
@pytest.mark.xfail(reason="failing on np.array(self, dtype=str)")
90 changes: 22 additions & 68 deletions pandas/tests/extension/test_arrow.py
@@ -65,6 +65,18 @@
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType


def _require_timezone_database(request):
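# xfail the requesting test on Windows CI, where the pyarrow timezone database is not configured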
if is_platform_windows() and is_ci_environment():
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
request.node.add_marker(mark)


@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
def dtype(request):
return ArrowDtype(pyarrow_dtype=request.param)
@@ -314,16 +326,8 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
)
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
_require_timezone_database(request)

pa_array = data._pa_array.cast(pa.string())
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
tm.assert_extension_array_equal(result, data)
@@ -795,20 +799,6 @@ def test_value_counts_returns_pyarrow_int64(self, data):
result = data.value_counts()
assert result.dtype == ArrowDtype(pa.int64())

def test_value_counts_with_normalize(self, data, request):
data = data[:10].unique()
values = np.array(data[~data.isna()])
ser = pd.Series(data, dtype=data.dtype)

result = ser.value_counts(normalize=True).sort_index()

expected = pd.Series(
[1 / len(values)] * len(values), index=result.index, name="proportion"
)
expected = expected.astype("double[pyarrow]")

self.assert_series_equal(result, expected)

def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
@@ -865,10 +855,6 @@ def test_combine_add(self, data_repeated, request):
else:
super().test_combine_add(data_repeated)

def test_basic_equals(self, data):
# https://github.com/pandas-dev/pandas/issues/34660
assert pd.Series(data).equals(pd.Series(data))


class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
divmod_exc = NotImplementedError
@@ -2552,33 +2538,17 @@ def test_dt_isocalendar():
)
def test_dt_day_month_name(method, exp, request):
# GH 52388
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
_require_timezone_database(request)

ser = pd.Series([datetime(2023, 1, 1), None], dtype=ArrowDtype(pa.timestamp("ms")))
result = getattr(ser.dt, method)()
expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string()))
tm.assert_series_equal(result, expected)


def test_dt_strftime(request):
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
_require_timezone_database(request)

ser = pd.Series(
[datetime(year=2023, month=1, day=2, hour=3), None],
dtype=ArrowDtype(pa.timestamp("ns")),
@@ -2689,16 +2659,8 @@ def test_dt_tz_localize_none():

@pytest.mark.parametrize("unit", ["us", "ns"])
def test_dt_tz_localize(unit, request):
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
_require_timezone_database(request)

ser = pd.Series(
[datetime(year=2023, month=1, day=2, hour=3), None],
dtype=ArrowDtype(pa.timestamp(unit)),
@@ -2720,16 +2682,8 @@ def test_dt_tz_localize(unit, request):
],
)
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
_require_timezone_database(request)

ser = pd.Series(
[datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
dtype=ArrowDtype(pa.timestamp("ns")),
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_period.py
@@ -198,9 +198,7 @@ class TestPrinting(BasePeriodTests, base.BasePrintingTests):


class TestParsing(BasePeriodTests, base.BaseParsingTests):
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data):
super().test_EA_types(engine, data)
pass


class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):
3 changes: 0 additions & 3 deletions pandas/tests/extension/test_sparse.py
@@ -318,9 +318,6 @@ def test_where_series(self, data, na_value):
expected = pd.Series(cls._from_sequence([a, b, b, b], dtype=data.dtype))
self.assert_series_equal(result, expected)

def test_combine_first(self, data, request):
super().test_combine_first(data)

def test_searchsorted(self, data_for_sorting, as_series):
with tm.assert_produces_warning(PerformanceWarning, check_stacklevel=False):
super().test_searchsorted(data_for_sorting, as_series)
31 changes: 1 addition & 30 deletions pandas/tests/extension/test_string.py
@@ -181,22 +181,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):


class TestMethods(base.BaseMethodsTests):
def test_value_counts_with_normalize(self, data):
data = data[:10].unique()
values = np.array(data[~data.isna()])
ser = pd.Series(data, dtype=data.dtype)

result = ser.value_counts(normalize=True).sort_index()

expected = pd.Series(
[1 / len(values)] * len(values), index=result.index, name="proportion"
)
if getattr(data.dtype, "storage", "") == "pyarrow":
expected = expected.astype("double[pyarrow]")
else:
expected = expected.astype("Float64")

self.assert_series_equal(result, expected)
pass


class TestCasting(base.BaseCastingTests):
@@ -225,20 +210,6 @@ class TestPrinting(base.BasePrintingTests):


class TestGroupBy(base.BaseGroupbyTests):
@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
result = df.groupby("B", as_index=as_index).A.mean()
_, uniques = pd.factorize(data_for_grouping, sort=True)

if as_index:
index = pd.Index(uniques, name="B")
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
self.assert_series_equal(result, expected)
else:
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
self.assert_frame_equal(result, expected)

@pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning")
def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)