
Commit 700ef33

ENH/TST: Add BaseMethodsTests tests for ArrowExtensionArray (#47552)
* ENH/TST: Add BaseMethodsTests tests for ArrowExtensionArray (#47552)
* Passing test now
* Add xfails for arraymanager
* Fix typo
* Trigger CI
* Add xfails for min version and data manager
* Adjust more tests
1 parent 4d17588 commit 700ef33

File tree: 4 files changed, +317 -10 lines

* pandas/core/base.py (+6 -4)
* pandas/tests/extension/base/methods.py (+2 -2)
* pandas/tests/extension/test_arrow.py (+308 -1)
* pandas/tests/extension/test_string.py (+1 -3)

pandas/core/base.py

+6 -4

@@ -983,10 +983,12 @@ def unique(self):

         if not isinstance(values, np.ndarray):
             result: ArrayLike = values.unique()
-            if self.dtype.kind in ["m", "M"] and isinstance(self, ABCSeries):
-                # GH#31182 Series._values returns EA, unpack for backward-compat
-                if getattr(self.dtype, "tz", None) is None:
-                    result = np.asarray(result)
+            if (
+                isinstance(self.dtype, np.dtype) and self.dtype.kind in ["m", "M"]
+            ) and isinstance(self, ABCSeries):
+                # GH#31182 Series._values returns EA
+                # unpack numpy datetime for backward-compat
+                result = np.asarray(result)
         else:
             result = unique1d(values)
         return result
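
The practical effect of this change is that the `np.asarray` unwrap in `unique()` now applies only to numpy datetime/timedelta dtypes, so ExtensionArray-backed series (notably ArrowExtensionArray) keep returning their extension array. A minimal sketch of the intended behaviour, assuming a pandas build containing this commit with pyarrow installed; the dtype string and exact result types are illustrative:

```python
import pandas as pd

# numpy datetime64 Series: unique() still unwraps to an ndarray (backward compat)
s_np = pd.Series(pd.to_datetime(["2020-01-01", "2020-01-01", "2020-01-02"]))
print(type(s_np.unique()))  # numpy.ndarray

# Arrow-backed timestamp Series: values.unique() is no longer coerced with
# np.asarray, so the ArrowExtensionArray comes back unchanged
s_pa = pd.Series(
    pd.array(pd.to_datetime(["2020-01-01", "2020-01-01"]), dtype="timestamp[ns][pyarrow]")
)
print(type(s_pa.unique()))  # expected: ArrowExtensionArray, not ndarray
```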

pandas/tests/extension/base/methods.py

+2 -2

@@ -5,6 +5,7 @@
 import pytest
 
 from pandas.core.dtypes.common import is_bool_dtype
+from pandas.core.dtypes.missing import na_value_for_dtype
 
 import pandas as pd
 import pandas._testing as tm

@@ -49,8 +50,7 @@ def test_value_counts_with_normalize(self, data):
         else:
             expected = pd.Series(0.0, index=result.index)
             expected[result > 0] = 1 / len(values)
-
-        if isinstance(data.dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
+        if na_value_for_dtype(data.dtype) is pd.NA:
             # TODO(GH#44692): avoid special-casing
             expected = expected.astype("Float64")
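
The motivation for swapping the explicit `BaseMaskedDtype` check for `na_value_for_dtype` is that any dtype whose missing-value scalar is `pd.NA` (masked dtypes, `StringDtype`, and `ArrowDtype` alike) now takes the `Float64` branch, which the new Arrow tests rely on. A quick sketch of the distinction, assuming a recent pandas:

```python
import numpy as np
import pandas as pd
from pandas.core.dtypes.missing import na_value_for_dtype

# dtypes whose NA scalar is np.nan keep the original float64 expectation
print(na_value_for_dtype(np.dtype("float64")))  # nan

# dtypes whose NA scalar is pd.NA get the expectation cast to "Float64";
# this covers masked dtypes as before, plus StringDtype and ArrowDtype
print(na_value_for_dtype(pd.Int64Dtype()))   # <NA>
print(na_value_for_dtype(pd.StringDtype()))  # <NA>  (hence the test_string.py change below)
```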

pandas/tests/extension/test_arrow.py

+308 -1

@@ -153,6 +153,32 @@ def data_for_grouping(dtype):
     return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
 
 
+@pytest.fixture
+def data_for_sorting(data_for_grouping):
+    """
+    Length-3 array with a known sort order.
+
+    This should be three items [B, C, A] with
+    A < B < C
+    """
+    return type(data_for_grouping)._from_sequence(
+        [data_for_grouping[0], data_for_grouping[7], data_for_grouping[4]]
+    )
+
+
+@pytest.fixture
+def data_missing_for_sorting(data_for_grouping):
+    """
+    Length-3 array with a known sort order.
+
+    This should be three items [B, NA, A] with
+    A < B and NA missing.
+    """
+    return type(data_for_grouping)._from_sequence(
+        [data_for_grouping[0], data_for_grouping[2], data_for_grouping[4]]
+    )
+
+
 @pytest.fixture
 def na_value():
     """The scalar missing value for this type. Default 'None'"""
@@ -654,7 +680,7 @@ def test_setitem_loc_scalar_single(self, data, using_array_manager, request):
         if pa_version_under2p0 and tz not in (None, "UTC"):
             request.node.add_marker(
                 pytest.mark.xfail(
-                    reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
+                    reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
                 )
             )
         elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
@@ -988,6 +1014,287 @@ def test_EA_types(self, engine, data, request):
         super().test_EA_types(engine, data)
 
 
+class TestBaseMethods(base.BaseMethodsTests):
+    @pytest.mark.parametrize("dropna", [True, False])
+    def test_value_counts(self, all_data, dropna, request):
+        pa_dtype = all_data.dtype.pyarrow_dtype
+        if pa.types.is_date(pa_dtype) or (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
+        ):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=AttributeError,
+                    reason="GH 34986",
+                )
+            )
+        elif pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"value_count has no kernel for {pa_dtype}",
+                )
+            )
+        super().test_value_counts(all_data, dropna)
+
+    def test_value_counts_with_normalize(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if pa.types.is_date(pa_dtype) or (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
+        ):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=AttributeError,
+                    reason="GH 34986",
+                )
+            )
+        elif pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"value_count has no pyarrow kernel for {pa_dtype}",
+                )
+            )
+        super().test_value_counts_with_normalize(data)
+
+    def test_argmin_argmax(
+        self, data_for_sorting, data_missing_for_sorting, na_value, request
+    ):
+        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
+        if pa.types.is_boolean(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"{pa_dtype} only has 2 unique possible values",
+                )
+            )
+        super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
+
+    @pytest.mark.parametrize("ascending", [True, False])
+    def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
+        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype) and not ascending and not pa_version_under2p0:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=(
+                        f"unique has no pyarrow kernel "
+                        f"for {pa_dtype} when ascending={ascending}"
+                    ),
+                )
+            )
+        super().test_sort_values(data_for_sorting, ascending, sort_by_key)
+
+    @pytest.mark.parametrize("ascending", [True, False])
+    def test_sort_values_frame(self, data_for_sorting, ascending, request):
+        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=(
+                        f"dictionary_encode has no pyarrow kernel "
+                        f"for {pa_dtype} when ascending={ascending}"
+                    ),
+                )
+            )
+        super().test_sort_values_frame(data_for_sorting, ascending)
+
+    @pytest.mark.parametrize("box", [pd.Series, lambda x: x])
+    @pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique])
+    def test_unique(self, data, box, method, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype) and not pa_version_under2p0:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"unique has no pyarrow kernel for {pa_dtype}.",
+                )
+            )
+        super().test_unique(data, box, method)
+
+    @pytest.mark.parametrize("na_sentinel", [-1, -2])
+    def test_factorize(self, data_for_grouping, na_sentinel, request):
+        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
+                )
+            )
+        elif pa.types.is_boolean(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"{pa_dtype} only has 2 unique possible values",
+                )
+            )
+        super().test_factorize(data_for_grouping, na_sentinel)
+
+    @pytest.mark.parametrize("na_sentinel", [-1, -2])
+    def test_factorize_equivalence(self, data_for_grouping, na_sentinel, request):
+        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
+                )
+            )
+        super().test_factorize_equivalence(data_for_grouping, na_sentinel)
+
+    def test_factorize_empty(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
+                )
+            )
+        super().test_factorize_empty(data)
+
+    def test_fillna_copy_frame(self, data_missing, request, using_array_manager):
+        pa_dtype = data_missing.dtype.pyarrow_dtype
+        if using_array_manager and pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Checking ndim when using arraymanager with {pa_dtype}"
+                )
+            )
+        super().test_fillna_copy_frame(data_missing)
+
+    def test_fillna_copy_series(self, data_missing, request, using_array_manager):
+        pa_dtype = data_missing.dtype.pyarrow_dtype
+        if using_array_manager and pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Checking ndim when using arraymanager with {pa_dtype}"
+                )
+            )
+        super().test_fillna_copy_series(data_missing)
+
+    def test_shift_fill_value(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        tz = getattr(pa_dtype, "tz", None)
+        if pa_version_under2p0 and tz not in (None, "UTC"):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
+                )
+            )
+        super().test_shift_fill_value(data)
+
+    @pytest.mark.parametrize("repeats", [0, 1, 2, [1, 2, 3]])
+    def test_repeat(self, data, repeats, as_series, use_numpy, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        tz = getattr(pa_dtype, "tz", None)
+        if pa_version_under2p0 and tz not in (None, "UTC") and repeats != 0:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=(
+                        f"Not supported by pyarrow < 2.0 with "
+                        f"timestamp type {tz} when repeats={repeats}"
+                    )
+                )
+            )
+        super().test_repeat(data, repeats, as_series, use_numpy)
+
+    def test_insert(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        tz = getattr(pa_dtype, "tz", None)
+        if pa_version_under2p0 and tz not in (None, "UTC"):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
+                )
+            )
+        super().test_insert(data)
+
+    def test_combine_first(self, data, request, using_array_manager):
+        pa_dtype = data.dtype.pyarrow_dtype
+        tz = getattr(pa_dtype, "tz", None)
+        if using_array_manager and pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Checking ndim when using arraymanager with {pa_dtype}"
+                )
+            )
+        elif pa_version_under2p0 and tz not in (None, "UTC"):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
+                )
+            )
+        super().test_combine_first(data)
+
+    @pytest.mark.parametrize("frame", [True, False])
+    @pytest.mark.parametrize(
+        "periods, indices",
+        [(-2, [2, 3, 4, -1, -1]), (0, [0, 1, 2, 3, 4]), (2, [-1, -1, 0, 1, 2])],
+    )
+    def test_container_shift(
+        self, data, frame, periods, indices, request, using_array_manager
+    ):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if (
+            using_array_manager
+            and pa.types.is_duration(pa_dtype)
+            and periods in (-2, 2)
+        ):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=(
+                        f"Checking ndim when using arraymanager with "
+                        f"{pa_dtype} and periods={periods}"
+                    )
+                )
+            )
+        super().test_container_shift(data, frame, periods, indices)
+
+    @pytest.mark.xfail(
+        reason="result dtype pyarrow[bool] better than expected dtype object"
+    )
+    def test_combine_le(self, data_repeated):
+        super().test_combine_le(data_repeated)
+
+    def test_combine_add(self, data_repeated, request):
+        pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype
+        if pa.types.is_temporal(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=TypeError,
+                    reason=f"{pa_dtype} cannot be added to {pa_dtype}",
+                )
+            )
+        super().test_combine_add(data_repeated)
+
+    def test_searchsorted(self, data_for_sorting, as_series, request):
+        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
+        if pa.types.is_boolean(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"{pa_dtype} only has 2 unique possible values",
+                )
+            )
+        super().test_searchsorted(data_for_sorting, as_series)
+
+    def test_where_series(self, data, na_value, as_frame, request, using_array_manager):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if using_array_manager and pa.types.is_duration(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"Checking ndim when using arraymanager with {pa_dtype}"
+                )
+            )
+        elif pa.types.is_temporal(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=pa.ArrowNotImplementedError,
+                    reason=f"Unsupported cast from double to {pa_dtype}",
+                )
+            )
+        super().test_where_series(data, na_value, as_frame)
+
+
 def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
     with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
         ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
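
A recurring pattern in `TestBaseMethods` above is marking the current test as an expected failure at run time, only for the pyarrow types that lack a compute kernel, and then delegating to the base-class implementation. A minimal standalone sketch of that pattern; the test body and `data` fixture below are illustrative, not part of this commit:

```python
import pyarrow as pa
import pytest


def test_example(data, request):  # ``data`` would be an ArrowExtensionArray fixture
    pa_dtype = data.dtype.pyarrow_dtype
    if pa.types.is_duration(pa_dtype):
        # Attach the marker dynamically: for duration types the assertions
        # below may raise ArrowNotImplementedError without failing the suite.
        request.node.add_marker(
            pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"no pyarrow kernel for {pa_dtype}",
            )
        )
    assert len(data.unique()) <= len(data)
```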

pandas/tests/extension/test_string.py

+1 -3

@@ -167,9 +167,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
 
 
 class TestMethods(base.BaseMethodsTests):
-    @pytest.mark.xfail(reason="returns nullable: GH 44692")
-    def test_value_counts_with_normalize(self, data):
-        super().test_value_counts_with_normalize(data)
+    pass
 
 
 class TestCasting(base.BaseCastingTests):
