Skip to content

Commit ecc8ab4

Browse files
Backport PR #48489 on branch 1.5.x (BUG: fix test_arrow.py tests) (#48532)
Backport PR #48489: BUG: fix test_arrow.py tests Co-authored-by: jbrockmendel <[email protected]>
1 parent 5817209 commit ecc8ab4

File tree

6 files changed

+29
-148
lines changed

6 files changed

+29
-148
lines changed

pandas/core/dtypes/common.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1188,10 +1188,10 @@ def needs_i8_conversion(arr_or_dtype) -> bool:
11881188
"""
11891189
if arr_or_dtype is None:
11901190
return False
1191-
if isinstance(arr_or_dtype, (np.dtype, ExtensionDtype)):
1192-
# fastpath
1193-
dtype = arr_or_dtype
1194-
return dtype.kind in ["m", "M"] or dtype.type is Period
1191+
if isinstance(arr_or_dtype, np.dtype):
1192+
return arr_or_dtype.kind in ["m", "M"]
1193+
elif isinstance(arr_or_dtype, ExtensionDtype):
1194+
return isinstance(arr_or_dtype, (PeriodDtype, DatetimeTZDtype))
11951195

11961196
try:
11971197
dtype = get_dtype(arr_or_dtype)

pandas/core/dtypes/concat.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
is_dtype_equal,
2828
is_sparse,
2929
)
30-
from pandas.core.dtypes.dtypes import ExtensionDtype
30+
from pandas.core.dtypes.dtypes import (
31+
DatetimeTZDtype,
32+
ExtensionDtype,
33+
)
3134
from pandas.core.dtypes.generic import (
3235
ABCCategoricalIndex,
3336
ABCExtensionArray,
@@ -103,10 +106,12 @@ def is_nonempty(x) -> bool:
103106
# ea_compat_axis see GH#39574
104107
to_concat = non_empties
105108

109+
dtypes = {obj.dtype for obj in to_concat}
106110
kinds = {obj.dtype.kind for obj in to_concat}
107-
contains_datetime = any(kind in ["m", "M"] for kind in kinds) or any(
108-
isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat
109-
)
111+
contains_datetime = any(
112+
isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in ["m", "M"]
113+
for dtype in dtypes
114+
) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)
110115

111116
all_empty = not len(non_empties)
112117
single_dtype = len({x.dtype for x in to_concat}) == 1

pandas/core/internals/concat.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from pandas.core.dtypes.common import (
3131
is_1d_only_ea_dtype,
32-
is_datetime64tz_dtype,
3332
is_dtype_equal,
3433
is_scalar,
3534
needs_i8_conversion,
@@ -38,7 +37,10 @@
3837
cast_to_common_type,
3938
concat_compat,
4039
)
41-
from pandas.core.dtypes.dtypes import ExtensionDtype
40+
from pandas.core.dtypes.dtypes import (
41+
DatetimeTZDtype,
42+
ExtensionDtype,
43+
)
4244
from pandas.core.dtypes.missing import (
4345
is_valid_na_for_dtype,
4446
isna,
@@ -147,16 +149,6 @@ def concat_arrays(to_concat: list) -> ArrayLike:
147149
else:
148150
target_dtype = find_common_type([arr.dtype for arr in to_concat_no_proxy])
149151

150-
if target_dtype.kind in ["m", "M"]:
151-
# for datetimelike use DatetimeArray/TimedeltaArray concatenation
152-
# don't use arr.astype(target_dtype, copy=False), because that doesn't
153-
# work for DatetimeArray/TimedeltaArray (returns ndarray)
154-
to_concat = [
155-
arr.to_array(target_dtype) if isinstance(arr, NullArrayProxy) else arr
156-
for arr in to_concat
157-
]
158-
return type(to_concat_no_proxy[0])._concat_same_type(to_concat, axis=0)
159-
160152
to_concat = [
161153
arr.to_array(target_dtype)
162154
if isinstance(arr, NullArrayProxy)
@@ -471,7 +463,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
471463
if len(values) and values[0] is None:
472464
fill_value = None
473465

474-
if is_datetime64tz_dtype(empty_dtype):
466+
if isinstance(empty_dtype, DatetimeTZDtype):
467+
# NB: exclude e.g. pyarrow[dt64tz] dtypes
475468
i8values = np.full(self.shape, fill_value.value)
476469
return DatetimeArray(i8values, dtype=empty_dtype)
477470

pandas/core/reshape/merge.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
is_bool,
5151
is_bool_dtype,
5252
is_categorical_dtype,
53-
is_datetime64tz_dtype,
5453
is_dtype_equal,
5554
is_extension_array_dtype,
5655
is_float_dtype,
@@ -62,6 +61,7 @@
6261
is_object_dtype,
6362
needs_i8_conversion,
6463
)
64+
from pandas.core.dtypes.dtypes import DatetimeTZDtype
6565
from pandas.core.dtypes.generic import (
6666
ABCDataFrame,
6767
ABCSeries,
@@ -1349,12 +1349,12 @@ def _maybe_coerce_merge_keys(self) -> None:
13491349
raise ValueError(msg)
13501350
elif not needs_i8_conversion(lk.dtype) and needs_i8_conversion(rk.dtype):
13511351
raise ValueError(msg)
1352-
elif is_datetime64tz_dtype(lk.dtype) and not is_datetime64tz_dtype(
1353-
rk.dtype
1352+
elif isinstance(lk.dtype, DatetimeTZDtype) and not isinstance(
1353+
rk.dtype, DatetimeTZDtype
13541354
):
13551355
raise ValueError(msg)
1356-
elif not is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(
1357-
rk.dtype
1356+
elif not isinstance(lk.dtype, DatetimeTZDtype) and isinstance(
1357+
rk.dtype, DatetimeTZDtype
13581358
):
13591359
raise ValueError(msg)
13601360

@@ -2280,9 +2280,10 @@ def _factorize_keys(
22802280
rk = extract_array(rk, extract_numpy=True, extract_range=True)
22812281
# TODO: if either is a RangeIndex, we can likely factorize more efficiently?
22822282

2283-
if is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(rk.dtype):
2283+
if isinstance(lk.dtype, DatetimeTZDtype) and isinstance(rk.dtype, DatetimeTZDtype):
22842284
# Extract the ndarray (UTC-localized) values
22852285
# Note: we dont need the dtypes to match, as these can still be compared
2286+
# TODO(non-nano): need to make sure resolutions match
22862287
lk = cast("DatetimeArray", lk)._ndarray
22872288
rk = cast("DatetimeArray", rk)._ndarray
22882289

pandas/io/formats/format.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
is_categorical_dtype,
6969
is_complex_dtype,
7070
is_datetime64_dtype,
71-
is_datetime64tz_dtype,
7271
is_extension_array_dtype,
7372
is_float,
7473
is_float_dtype,
@@ -79,6 +78,7 @@
7978
is_scalar,
8079
is_timedelta64_dtype,
8180
)
81+
from pandas.core.dtypes.dtypes import DatetimeTZDtype
8282
from pandas.core.dtypes.missing import (
8383
isna,
8484
notna,
@@ -1290,7 +1290,7 @@ def format_array(
12901290
fmt_klass: type[GenericArrayFormatter]
12911291
if is_datetime64_dtype(values.dtype):
12921292
fmt_klass = Datetime64Formatter
1293-
elif is_datetime64tz_dtype(values.dtype):
1293+
elif isinstance(values.dtype, DatetimeTZDtype):
12941294
fmt_klass = Datetime64TZFormatter
12951295
elif is_timedelta64_dtype(values.dtype):
12961296
fmt_klass = Timedelta64Formatter

pandas/tests/extension/test_arrow.py

+1-119
Original file line numberDiff line numberDiff line change
@@ -539,25 +539,13 @@ def test_groupby_extension_apply(
539539
self, data_for_grouping, groupby_apply_op, request
540540
):
541541
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
542-
# TODO: Is there a better way to get the "object" ID for groupby_apply_op?
543-
is_object = "object" in request.node.nodeid
544542
if pa.types.is_duration(pa_dtype):
545543
request.node.add_marker(
546544
pytest.mark.xfail(
547545
raises=pa.ArrowNotImplementedError,
548546
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
549547
)
550548
)
551-
elif pa.types.is_date(pa_dtype) or (
552-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
553-
):
554-
if is_object:
555-
request.node.add_marker(
556-
pytest.mark.xfail(
557-
raises=TypeError,
558-
reason="GH 47514: _concat_datetime expects axis arg.",
559-
)
560-
)
561549
with tm.maybe_produces_warning(
562550
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
563551
):
@@ -688,70 +676,10 @@ def test_dropna_array(self, data_missing):
688676

689677

690678
class TestBasePrinting(base.BasePrintingTests):
691-
def test_series_repr(self, data, request):
692-
pa_dtype = data.dtype.pyarrow_dtype
693-
if (
694-
pa.types.is_date(pa_dtype)
695-
or pa.types.is_duration(pa_dtype)
696-
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
697-
):
698-
request.node.add_marker(
699-
pytest.mark.xfail(
700-
raises=TypeError,
701-
reason="GH 47514: _concat_datetime expects axis arg.",
702-
)
703-
)
704-
super().test_series_repr(data)
705-
706-
def test_dataframe_repr(self, data, request):
707-
pa_dtype = data.dtype.pyarrow_dtype
708-
if (
709-
pa.types.is_date(pa_dtype)
710-
or pa.types.is_duration(pa_dtype)
711-
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
712-
):
713-
request.node.add_marker(
714-
pytest.mark.xfail(
715-
raises=TypeError,
716-
reason="GH 47514: _concat_datetime expects axis arg.",
717-
)
718-
)
719-
super().test_dataframe_repr(data)
679+
pass
720680

721681

722682
class TestBaseReshaping(base.BaseReshapingTests):
723-
@pytest.mark.parametrize("in_frame", [True, False])
724-
def test_concat(self, data, in_frame, request):
725-
pa_dtype = data.dtype.pyarrow_dtype
726-
if (
727-
pa.types.is_date(pa_dtype)
728-
or pa.types.is_duration(pa_dtype)
729-
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
730-
):
731-
request.node.add_marker(
732-
pytest.mark.xfail(
733-
raises=TypeError,
734-
reason="GH 47514: _concat_datetime expects axis arg.",
735-
)
736-
)
737-
super().test_concat(data, in_frame)
738-
739-
@pytest.mark.parametrize("in_frame", [True, False])
740-
def test_concat_all_na_block(self, data_missing, in_frame, request):
741-
pa_dtype = data_missing.dtype.pyarrow_dtype
742-
if (
743-
pa.types.is_date(pa_dtype)
744-
or pa.types.is_duration(pa_dtype)
745-
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
746-
):
747-
request.node.add_marker(
748-
pytest.mark.xfail(
749-
raises=TypeError,
750-
reason="GH 47514: _concat_datetime expects axis arg.",
751-
)
752-
)
753-
super().test_concat_all_na_block(data_missing, in_frame)
754-
755683
def test_concat_columns(self, data, na_value, request):
756684
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
757685
if pa_version_under2p0 and tz not in (None, "UTC"):
@@ -772,26 +700,6 @@ def test_concat_extension_arrays_copy_false(self, data, na_value, request):
772700
)
773701
super().test_concat_extension_arrays_copy_false(data, na_value)
774702

775-
def test_concat_with_reindex(self, data, request, using_array_manager):
776-
pa_dtype = data.dtype.pyarrow_dtype
777-
if pa.types.is_duration(pa_dtype):
778-
request.node.add_marker(
779-
pytest.mark.xfail(
780-
raises=TypeError,
781-
reason="GH 47514: _concat_datetime expects axis arg.",
782-
)
783-
)
784-
elif pa.types.is_date(pa_dtype) or (
785-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
786-
):
787-
request.node.add_marker(
788-
pytest.mark.xfail(
789-
raises=AttributeError if not using_array_manager else TypeError,
790-
reason="GH 34986",
791-
)
792-
)
793-
super().test_concat_with_reindex(data)
794-
795703
def test_align(self, data, na_value, request):
796704
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
797705
if pa_version_under2p0 and tz not in (None, "UTC"):
@@ -832,32 +740,6 @@ def test_merge(self, data, na_value, request):
832740
)
833741
super().test_merge(data, na_value)
834742

835-
def test_merge_on_extension_array(self, data, request):
836-
pa_dtype = data.dtype.pyarrow_dtype
837-
if pa.types.is_date(pa_dtype) or (
838-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
839-
):
840-
request.node.add_marker(
841-
pytest.mark.xfail(
842-
raises=AttributeError,
843-
reason="GH 34986",
844-
)
845-
)
846-
super().test_merge_on_extension_array(data)
847-
848-
def test_merge_on_extension_array_duplicates(self, data, request):
849-
pa_dtype = data.dtype.pyarrow_dtype
850-
if pa.types.is_date(pa_dtype) or (
851-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
852-
):
853-
request.node.add_marker(
854-
pytest.mark.xfail(
855-
raises=AttributeError,
856-
reason="GH 34986",
857-
)
858-
)
859-
super().test_merge_on_extension_array_duplicates(data)
860-
861743
def test_ravel(self, data, request):
862744
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
863745
if pa_version_under2p0 and tz not in (None, "UTC"):

0 commit comments

Comments
 (0)