Commit 65af4ef

BUG: fix test_arrow.py tests (#48489)
* BUG: fix test_arrow.py tests
* fix ArrayManager test
* remove unnecessary path
1 parent aea824f commit 65af4ef

File tree

6 files changed: +29, -148 lines

pandas/core/dtypes/common.py  (+4, -4)

@@ -1188,10 +1188,10 @@ def needs_i8_conversion(arr_or_dtype) -> bool:
     """
     if arr_or_dtype is None:
         return False
-    if isinstance(arr_or_dtype, (np.dtype, ExtensionDtype)):
-        # fastpath
-        dtype = arr_or_dtype
-        return dtype.kind in ["m", "M"] or dtype.type is Period
+    if isinstance(arr_or_dtype, np.dtype):
+        return arr_or_dtype.kind in ["m", "M"]
+    elif isinstance(arr_or_dtype, ExtensionDtype):
+        return isinstance(arr_or_dtype, (PeriodDtype, DatetimeTZDtype))

     try:
         dtype = get_dtype(arr_or_dtype)
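
Note (not part of the commit): a minimal sketch of what the tightened check is expected to return for a few concrete dtypes. The pyarrow-backed line assumes pyarrow is installed.

    import numpy as np
    import pandas as pd
    import pyarrow as pa  # assumption: pyarrow is installed

    from pandas.core.dtypes.common import needs_i8_conversion

    # numpy datetime64/timedelta64 dtypes still need i8 conversion
    print(needs_i8_conversion(np.dtype("datetime64[ns]")))        # True
    # pandas' own tz-aware and period dtypes do as well
    print(needs_i8_conversion(pd.DatetimeTZDtype(tz="UTC")))      # True
    print(needs_i8_conversion(pd.PeriodDtype("D")))               # True
    # other extension dtypes, including pyarrow-backed timestamps, do not
    print(needs_i8_conversion(pd.ArrowDtype(pa.timestamp("ns"))))  # False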

pandas/core/dtypes/concat.py  (+9, -4)

@@ -27,7 +27,10 @@
     is_dtype_equal,
     is_sparse,
 )
-from pandas.core.dtypes.dtypes import ExtensionDtype
+from pandas.core.dtypes.dtypes import (
+    DatetimeTZDtype,
+    ExtensionDtype,
+)
 from pandas.core.dtypes.generic import (
     ABCCategoricalIndex,
     ABCExtensionArray,
@@ -103,10 +106,12 @@ def is_nonempty(x) -> bool:
         # ea_compat_axis see GH#39574
         to_concat = non_empties

+    dtypes = {obj.dtype for obj in to_concat}
     kinds = {obj.dtype.kind for obj in to_concat}
-    contains_datetime = any(kind in ["m", "M"] for kind in kinds) or any(
-        isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat
-    )
+    contains_datetime = any(
+        isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in ["m", "M"]
+        for dtype in dtypes
+    ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)

     all_empty = not len(non_empties)
     single_dtype = len({x.dtype for x in to_concat}) == 1
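
Note (not part of the commit): the new contains_datetime condition can be read in isolation as a small predicate over dtypes; a sketch, where the helper name is_concat_datetime is ours, not pandas':

    import numpy as np
    import pandas as pd

    from pandas.core.dtypes.dtypes import DatetimeTZDtype


    def is_concat_datetime(dtype) -> bool:
        # mirrors the condition added above: only real numpy dtypes and
        # pandas' tz-aware dtype with datetime/timedelta kind qualify
        return isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in ["m", "M"]


    print(is_concat_datetime(np.dtype("timedelta64[ns]")))      # True
    print(is_concat_datetime(DatetimeTZDtype(tz="US/Eastern")))  # True
    print(is_concat_datetime(pd.CategoricalDtype()))             # False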

pandas/core/internals/concat.py  (+6, -13)

@@ -29,7 +29,6 @@
 )
 from pandas.core.dtypes.common import (
     is_1d_only_ea_dtype,
-    is_datetime64tz_dtype,
     is_dtype_equal,
     is_scalar,
     needs_i8_conversion,
@@ -38,7 +37,10 @@
     cast_to_common_type,
     concat_compat,
 )
-from pandas.core.dtypes.dtypes import ExtensionDtype
+from pandas.core.dtypes.dtypes import (
+    DatetimeTZDtype,
+    ExtensionDtype,
+)
 from pandas.core.dtypes.missing import (
     is_valid_na_for_dtype,
     isna,
@@ -147,16 +149,6 @@ def concat_arrays(to_concat: list) -> ArrayLike:
     else:
         target_dtype = find_common_type([arr.dtype for arr in to_concat_no_proxy])

-    if target_dtype.kind in ["m", "M"]:
-        # for datetimelike use DatetimeArray/TimedeltaArray concatenation
-        # don't use arr.astype(target_dtype, copy=False), because that doesn't
-        # work for DatetimeArray/TimedeltaArray (returns ndarray)
-        to_concat = [
-            arr.to_array(target_dtype) if isinstance(arr, NullArrayProxy) else arr
-            for arr in to_concat
-        ]
-        return type(to_concat_no_proxy[0])._concat_same_type(to_concat, axis=0)
-
     to_concat = [
         arr.to_array(target_dtype)
         if isinstance(arr, NullArrayProxy)
@@ -471,7 +463,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
                     if len(values) and values[0] is None:
                         fill_value = None

-                if is_datetime64tz_dtype(empty_dtype):
+                if isinstance(empty_dtype, DatetimeTZDtype):
+                    # NB: exclude e.g. pyarrow[dt64tz] dtypes
                     i8values = np.full(self.shape, fill_value.value)
                     return DatetimeArray(i8values, dtype=empty_dtype)

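
Note (not part of the commit): a hedged illustration of the get_reindexed_values path touched here. Concatenating frames with non-overlapping columns produces an all-NA block for the missing column, and the DatetimeTZDtype branch keeps the tz-aware dtype when filling it:

    import pandas as pd

    # "a" is missing from the second frame, so concat fills it from an all-NA
    # block; the tz-aware dtype should survive via the DatetimeTZDtype branch.
    df1 = pd.DataFrame({"a": pd.to_datetime(["2022-01-01"], utc=True)})
    df2 = pd.DataFrame({"b": [1.5]})
    result = pd.concat([df1, df2])
    print(result.dtypes["a"])  # datetime64[ns, UTC]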

pandas/core/reshape/merge.py  (+7, -6)

@@ -50,7 +50,6 @@
     is_bool,
     is_bool_dtype,
     is_categorical_dtype,
-    is_datetime64tz_dtype,
     is_dtype_equal,
     is_extension_array_dtype,
     is_float_dtype,
@@ -62,6 +61,7 @@
     is_object_dtype,
     needs_i8_conversion,
 )
+from pandas.core.dtypes.dtypes import DatetimeTZDtype
 from pandas.core.dtypes.generic import (
     ABCDataFrame,
     ABCSeries,
@@ -1352,12 +1352,12 @@ def _maybe_coerce_merge_keys(self) -> None:
                 raise ValueError(msg)
             elif not needs_i8_conversion(lk.dtype) and needs_i8_conversion(rk.dtype):
                 raise ValueError(msg)
-            elif is_datetime64tz_dtype(lk.dtype) and not is_datetime64tz_dtype(
-                rk.dtype
+            elif isinstance(lk.dtype, DatetimeTZDtype) and not isinstance(
+                rk.dtype, DatetimeTZDtype
             ):
                 raise ValueError(msg)
-            elif not is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(
-                rk.dtype
+            elif not isinstance(lk.dtype, DatetimeTZDtype) and isinstance(
+                rk.dtype, DatetimeTZDtype
             ):
                 raise ValueError(msg)

@@ -2283,9 +2283,10 @@ def _factorize_keys(
     rk = extract_array(rk, extract_numpy=True, extract_range=True)
     # TODO: if either is a RangeIndex, we can likely factorize more efficiently?

-    if is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(rk.dtype):
+    if isinstance(lk.dtype, DatetimeTZDtype) and isinstance(rk.dtype, DatetimeTZDtype):
         # Extract the ndarray (UTC-localized) values
         # Note: we dont need the dtypes to match, as these can still be compared
+        # TODO(non-nano): need to make sure resolutions match
         lk = cast("DatetimeArray", lk)._ndarray
         rk = cast("DatetimeArray", rk)._ndarray
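
Note (not part of the commit): a minimal sketch of the key validation these isinstance checks implement; merging a tz-aware datetime key against a tz-naive one is rejected with a ValueError:

    import pandas as pd

    left = pd.DataFrame({"key": pd.to_datetime(["2022-01-01"], utc=True), "x": [1]})
    right = pd.DataFrame({"key": pd.to_datetime(["2022-01-01"]), "y": [2]})

    try:
        left.merge(right, on="key")
    except ValueError as err:
        # the error message points out the incompatible merge-key dtypes
        print(err)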

pandas/io/formats/format.py  (+2, -2)

@@ -68,7 +68,6 @@
     is_categorical_dtype,
     is_complex_dtype,
     is_datetime64_dtype,
-    is_datetime64tz_dtype,
     is_extension_array_dtype,
     is_float,
     is_float_dtype,
@@ -79,6 +78,7 @@
     is_scalar,
     is_timedelta64_dtype,
 )
+from pandas.core.dtypes.dtypes import DatetimeTZDtype
 from pandas.core.dtypes.missing import (
     isna,
     notna,
@@ -1290,7 +1290,7 @@ def format_array(
     fmt_klass: type[GenericArrayFormatter]
     if is_datetime64_dtype(values.dtype):
         fmt_klass = Datetime64Formatter
-    elif is_datetime64tz_dtype(values.dtype):
+    elif isinstance(values.dtype, DatetimeTZDtype):
         fmt_klass = Datetime64TZFormatter
     elif is_timedelta64_dtype(values.dtype):
         fmt_klass = Timedelta64Formatter
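
Note (not part of the commit): a quick sanity check of the branch above; a tz-aware Series' dtype is a DatetimeTZDtype instance, so its repr is routed to Datetime64TZFormatter:

    import pandas as pd
    from pandas.core.dtypes.dtypes import DatetimeTZDtype

    ser = pd.Series(pd.to_datetime(["2022-01-01 12:00"], utc=True))
    print(isinstance(ser.dtype, DatetimeTZDtype))  # True
    print(ser)  # values rendered with their offset, e.g. 2022-01-01 12:00:00+00:00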

pandas/tests/extension/test_arrow.py  (+1, -119)

@@ -542,25 +542,13 @@ def test_groupby_extension_apply(
         self, data_for_grouping, groupby_apply_op, request
     ):
         pa_dtype = data_for_grouping.dtype.pyarrow_dtype
-        # TODO: Is there a better way to get the "object" ID for groupby_apply_op?
-        is_object = "object" in request.node.nodeid
         if pa.types.is_duration(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=pa.ArrowNotImplementedError,
                     reason=f"pyarrow doesn't support factorizing {pa_dtype}",
                 )
             )
-        elif pa.types.is_date(pa_dtype) or (
-            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
-        ):
-            if is_object:
-                request.node.add_marker(
-                    pytest.mark.xfail(
-                        raises=TypeError,
-                        reason="GH 47514: _concat_datetime expects axis arg.",
-                    )
-                )
         with tm.maybe_produces_warning(
             PerformanceWarning, pa_version_under7p0, check_stacklevel=False
         ):
@@ -691,70 +679,10 @@ def test_dropna_array(self, data_missing):


 class TestBasePrinting(base.BasePrintingTests):
-    def test_series_repr(self, data, request):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if (
-            pa.types.is_date(pa_dtype)
-            or pa.types.is_duration(pa_dtype)
-            or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=TypeError,
-                    reason="GH 47514: _concat_datetime expects axis arg.",
-                )
-            )
-        super().test_series_repr(data)
-
-    def test_dataframe_repr(self, data, request):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if (
-            pa.types.is_date(pa_dtype)
-            or pa.types.is_duration(pa_dtype)
-            or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=TypeError,
-                    reason="GH 47514: _concat_datetime expects axis arg.",
-                )
-            )
-        super().test_dataframe_repr(data)
+    pass


 class TestBaseReshaping(base.BaseReshapingTests):
-    @pytest.mark.parametrize("in_frame", [True, False])
-    def test_concat(self, data, in_frame, request):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if (
-            pa.types.is_date(pa_dtype)
-            or pa.types.is_duration(pa_dtype)
-            or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=TypeError,
-                    reason="GH 47514: _concat_datetime expects axis arg.",
-                )
-            )
-        super().test_concat(data, in_frame)
-
-    @pytest.mark.parametrize("in_frame", [True, False])
-    def test_concat_all_na_block(self, data_missing, in_frame, request):
-        pa_dtype = data_missing.dtype.pyarrow_dtype
-        if (
-            pa.types.is_date(pa_dtype)
-            or pa.types.is_duration(pa_dtype)
-            or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=TypeError,
-                    reason="GH 47514: _concat_datetime expects axis arg.",
-                )
-            )
-        super().test_concat_all_na_block(data_missing, in_frame)
-
     def test_concat_columns(self, data, na_value, request):
         tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
         if pa_version_under2p0 and tz not in (None, "UTC"):
@@ -775,26 +703,6 @@ def test_concat_extension_arrays_copy_false(self, data, na_value, request):
             )
         super().test_concat_extension_arrays_copy_false(data, na_value)

-    def test_concat_with_reindex(self, data, request, using_array_manager):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if pa.types.is_duration(pa_dtype):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=TypeError,
-                    reason="GH 47514: _concat_datetime expects axis arg.",
-                )
-            )
-        elif pa.types.is_date(pa_dtype) or (
-            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=AttributeError if not using_array_manager else TypeError,
-                    reason="GH 34986",
-                )
-            )
-        super().test_concat_with_reindex(data)
-
     def test_align(self, data, na_value, request):
         tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
         if pa_version_under2p0 and tz not in (None, "UTC"):
@@ -835,32 +743,6 @@ def test_merge(self, data, na_value, request):
             )
         super().test_merge(data, na_value)

-    def test_merge_on_extension_array(self, data, request):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if pa.types.is_date(pa_dtype) or (
-            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=AttributeError,
-                    reason="GH 34986",
-                )
-            )
-        super().test_merge_on_extension_array(data)
-
-    def test_merge_on_extension_array_duplicates(self, data, request):
-        pa_dtype = data.dtype.pyarrow_dtype
-        if pa.types.is_date(pa_dtype) or (
-            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
-        ):
-            request.node.add_marker(
-                pytest.mark.xfail(
-                    raises=AttributeError,
-                    reason="GH 34986",
-                )
-            )
-        super().test_merge_on_extension_array_duplicates(data)
-
     def test_ravel(self, data, request):
         tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
         if pa_version_under2p0 and tz not in (None, "UTC"):
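
Note (not part of the commit): the removed xfails covered operations that previously hit the datetime-specific concat path for pyarrow-backed date/duration/naive-timestamp data. A hedged end-to-end check of the kind those tests exercise (assumes pyarrow is installed; the dtype string in the comment may vary slightly by version):

    from datetime import date

    import pandas as pd
    import pyarrow as pa  # assumption: pyarrow is installed

    # pyarrow-backed date data: concat and repr should go through the ordinary
    # extension-array path rather than the DatetimeArray-specific machinery
    s = pd.Series([date(2022, 1, 1), date(2022, 1, 2)], dtype=pd.ArrowDtype(pa.date32()))
    out = pd.concat([s, s], ignore_index=True)
    print(out.dtype)  # date32[day][pyarrow]
    print(out)        # plain Series repr, no TypeError about a missing axis arg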
