Skip to content

Commit 7eb2546

Browse files
authored
REF: de-duplicate extension tests (pandas-dev#54340)
* REF: unnecessary override in Period test, de-duplicate in arrow tests * REF: de-duplicate extension tests * typo fixup * De-duplicate * xfail for json * fix on pyarrow-less builds
1 parent bd5ce2a commit 7eb2546

File tree

6 files changed

+43
-105
lines changed

6 files changed

+43
-105
lines changed

pandas/tests/extension/base/methods.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ def test_value_counts_with_normalize(self, data):
6565
else:
6666
expected = pd.Series(0.0, index=result.index, name="proportion")
6767
expected[result > 0] = 1 / len(values)
68-
if na_value_for_dtype(data.dtype) is pd.NA:
68+
69+
if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
70+
data.dtype, pd.ArrowDtype
71+
):
72+
# TODO: avoid special-casing
73+
expected = expected.astype("double[pyarrow]")
74+
elif na_value_for_dtype(data.dtype) is pd.NA:
6975
# TODO(GH#44692): avoid special-casing
7076
expected = expected.astype("Float64")
7177

@@ -678,3 +684,7 @@ def test_equals(self, data, na_value, as_series, box):
678684
# other types
679685
assert data.equals(None) is False
680686
assert data[[0]].equals(data[0]) is False
687+
688+
def test_equals_same_data_different_object(self, data):
689+
# https://github.com/pandas-dev/pandas/issues/34660
690+
assert pd.Series(data).equals(pd.Series(data))

pandas/tests/extension/json/test_json.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,14 @@ def test_equals(self, data, na_value, as_series):
241241
def test_fillna_copy_frame(self, data_missing):
242242
super().test_fillna_copy_frame(data_missing)
243243

244+
def test_equals_same_data_different_object(
245+
self, data, using_copy_on_write, request
246+
):
247+
if using_copy_on_write:
248+
mark = pytest.mark.xfail(reason="Fails with CoW")
249+
request.node.add_marker(mark)
250+
super().test_equals_same_data_different_object(data)
251+
244252

245253
class TestCasting(BaseJSON, base.BaseCastingTests):
246254
@pytest.mark.xfail(reason="failing on np.array(self, dtype=str)")

pandas/tests/extension/test_arrow.py

Lines changed: 22 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@
6565
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
6666

6767

68+
def _require_timezone_database(request):
69+
if is_platform_windows() and is_ci_environment():
70+
mark = pytest.mark.xfail(
71+
raises=pa.ArrowInvalid,
72+
reason=(
73+
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
74+
"on CI to path to the tzdata for pyarrow."
75+
),
76+
)
77+
request.node.add_marker(mark)
78+
79+
6880
@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
6981
def dtype(request):
7082
return ArrowDtype(pyarrow_dtype=request.param)
@@ -314,16 +326,8 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
314326
)
315327
)
316328
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
317-
if is_platform_windows() and is_ci_environment():
318-
request.node.add_marker(
319-
pytest.mark.xfail(
320-
raises=pa.ArrowInvalid,
321-
reason=(
322-
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
323-
"on CI to path to the tzdata for pyarrow."
324-
),
325-
)
326-
)
329+
_require_timezone_database(request)
330+
327331
pa_array = data._pa_array.cast(pa.string())
328332
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
329333
tm.assert_extension_array_equal(result, data)
@@ -795,20 +799,6 @@ def test_value_counts_returns_pyarrow_int64(self, data):
795799
result = data.value_counts()
796800
assert result.dtype == ArrowDtype(pa.int64())
797801

798-
def test_value_counts_with_normalize(self, data, request):
799-
data = data[:10].unique()
800-
values = np.array(data[~data.isna()])
801-
ser = pd.Series(data, dtype=data.dtype)
802-
803-
result = ser.value_counts(normalize=True).sort_index()
804-
805-
expected = pd.Series(
806-
[1 / len(values)] * len(values), index=result.index, name="proportion"
807-
)
808-
expected = expected.astype("double[pyarrow]")
809-
810-
self.assert_series_equal(result, expected)
811-
812802
def test_argmin_argmax(
813803
self, data_for_sorting, data_missing_for_sorting, na_value, request
814804
):
@@ -865,10 +855,6 @@ def test_combine_add(self, data_repeated, request):
865855
else:
866856
super().test_combine_add(data_repeated)
867857

868-
def test_basic_equals(self, data):
869-
# https://github.com/pandas-dev/pandas/issues/34660
870-
assert pd.Series(data).equals(pd.Series(data))
871-
872858

873859
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
874860
divmod_exc = NotImplementedError
@@ -2563,33 +2549,17 @@ def test_dt_isocalendar():
25632549
)
25642550
def test_dt_day_month_name(method, exp, request):
25652551
# GH 52388
2566-
if is_platform_windows() and is_ci_environment():
2567-
request.node.add_marker(
2568-
pytest.mark.xfail(
2569-
raises=pa.ArrowInvalid,
2570-
reason=(
2571-
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2572-
"on CI to path to the tzdata for pyarrow."
2573-
),
2574-
)
2575-
)
2552+
_require_timezone_database(request)
2553+
25762554
ser = pd.Series([datetime(2023, 1, 1), None], dtype=ArrowDtype(pa.timestamp("ms")))
25772555
result = getattr(ser.dt, method)()
25782556
expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string()))
25792557
tm.assert_series_equal(result, expected)
25802558

25812559

25822560
def test_dt_strftime(request):
2583-
if is_platform_windows() and is_ci_environment():
2584-
request.node.add_marker(
2585-
pytest.mark.xfail(
2586-
raises=pa.ArrowInvalid,
2587-
reason=(
2588-
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2589-
"on CI to path to the tzdata for pyarrow."
2590-
),
2591-
)
2592-
)
2561+
_require_timezone_database(request)
2562+
25932563
ser = pd.Series(
25942564
[datetime(year=2023, month=1, day=2, hour=3), None],
25952565
dtype=ArrowDtype(pa.timestamp("ns")),
@@ -2700,16 +2670,8 @@ def test_dt_tz_localize_none():
27002670

27012671
@pytest.mark.parametrize("unit", ["us", "ns"])
27022672
def test_dt_tz_localize(unit, request):
2703-
if is_platform_windows() and is_ci_environment():
2704-
request.node.add_marker(
2705-
pytest.mark.xfail(
2706-
raises=pa.ArrowInvalid,
2707-
reason=(
2708-
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2709-
"on CI to path to the tzdata for pyarrow."
2710-
),
2711-
)
2712-
)
2673+
_require_timezone_database(request)
2674+
27132675
ser = pd.Series(
27142676
[datetime(year=2023, month=1, day=2, hour=3), None],
27152677
dtype=ArrowDtype(pa.timestamp(unit)),
@@ -2731,16 +2693,8 @@ def test_dt_tz_localize(unit, request):
27312693
],
27322694
)
27332695
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
2734-
if is_platform_windows() and is_ci_environment():
2735-
request.node.add_marker(
2736-
pytest.mark.xfail(
2737-
raises=pa.ArrowInvalid,
2738-
reason=(
2739-
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2740-
"on CI to path to the tzdata for pyarrow."
2741-
),
2742-
)
2743-
)
2696+
_require_timezone_database(request)
2697+
27442698
ser = pd.Series(
27452699
[datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
27462700
dtype=ArrowDtype(pa.timestamp("ns")),

pandas/tests/extension/test_period.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,7 @@ class TestPrinting(BasePeriodTests, base.BasePrintingTests):
198198

199199

200200
class TestParsing(BasePeriodTests, base.BaseParsingTests):
201-
@pytest.mark.parametrize("engine", ["c", "python"])
202-
def test_EA_types(self, engine, data):
203-
super().test_EA_types(engine, data)
201+
pass
204202

205203

206204
class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):

pandas/tests/extension/test_sparse.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,6 @@ def test_where_series(self, data, na_value):
319319
expected = pd.Series(cls._from_sequence([a, b, b, b], dtype=data.dtype))
320320
self.assert_series_equal(result, expected)
321321

322-
def test_combine_first(self, data, request):
323-
super().test_combine_first(data)
324-
325322
def test_searchsorted(self, data_for_sorting, as_series):
326323
with tm.assert_produces_warning(PerformanceWarning, check_stacklevel=False):
327324
super().test_searchsorted(data_for_sorting, as_series)

pandas/tests/extension/test_string.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
182182

183183

184184
class TestMethods(base.BaseMethodsTests):
185-
def test_value_counts_with_normalize(self, data):
186-
data = data[:10].unique()
187-
values = np.array(data[~data.isna()])
188-
ser = pd.Series(data, dtype=data.dtype)
189-
190-
result = ser.value_counts(normalize=True).sort_index()
191-
192-
expected = pd.Series(
193-
[1 / len(values)] * len(values), index=result.index, name="proportion"
194-
)
195-
if getattr(data.dtype, "storage", "") == "pyarrow":
196-
expected = expected.astype("double[pyarrow]")
197-
else:
198-
expected = expected.astype("Float64")
199-
200-
self.assert_series_equal(result, expected)
185+
pass
201186

202187

203188
class TestCasting(base.BaseCastingTests):
@@ -226,20 +211,6 @@ class TestPrinting(base.BasePrintingTests):
226211

227212

228213
class TestGroupBy(base.BaseGroupbyTests):
229-
@pytest.mark.parametrize("as_index", [True, False])
230-
def test_groupby_extension_agg(self, as_index, data_for_grouping):
231-
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
232-
result = df.groupby("B", as_index=as_index).A.mean()
233-
_, uniques = pd.factorize(data_for_grouping, sort=True)
234-
235-
if as_index:
236-
index = pd.Index(uniques, name="B")
237-
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
238-
self.assert_series_equal(result, expected)
239-
else:
240-
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
241-
self.assert_frame_equal(result, expected)
242-
243214
@pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning")
244215
def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
245216
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)

0 commit comments

Comments
 (0)