Skip to content

Commit 4763b91

Browse files
authored
BUG: Return numpy types from ArrowExtensionArray.to_numpy for temporal types when possible (#56459)
1 parent 94a6d6a commit 4763b91

File tree

7 files changed

+99
-38
lines changed

7 files changed

+99
-38
lines changed

doc/source/whatsnew/v2.2.0.rst

+8-4
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ documentation.
107107

108108
.. _whatsnew_220.enhancements.to_numpy_ea:
109109

110-
ExtensionArray.to_numpy converts to suitable NumPy dtype
111-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
110+
``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype
111+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
112112

113-
:meth:`ExtensionArray.to_numpy` will now convert to a suitable NumPy dtype instead
114-
of ``object`` dtype for nullable extension dtypes.
113+
``to_numpy`` for NumPy nullable and Arrow types will now convert to a
114+
suitable NumPy dtype instead of ``object`` dtype for nullable extension dtypes.
115115

116116
*Old behavior:*
117117

@@ -128,13 +128,17 @@ of ``object`` dtype for nullable extension dtypes.
128128
ser = pd.Series([1, 2, 3], dtype="Int64")
129129
ser.to_numpy()
130130
131+
ser = pd.Series([1, 2, 3], dtype="timestamp[ns][pyarrow]")
132+
ser.to_numpy()
133+
131134
The default NumPy dtype (without any arguments) is determined as follows:
132135

133136
- float dtypes are cast to NumPy floats
134137
- integer dtypes without missing values are cast to NumPy integer dtypes
135138
- integer dtypes with missing values are cast to NumPy float dtypes and ``NaN`` is used as missing value indicator
136139
- boolean dtypes without missing values are cast to NumPy bool dtype
137140
- boolean dtypes with missing values keep object dtype
141+
- datetime and timedelta types are cast to Numpy datetime64 and timedelta64 types respectively and ``NaT`` is used as missing value indicator
138142

139143
.. _whatsnew_220.enhancements.struct_accessor:
140144

pandas/core/arrays/arrow/array.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pandas.core import (
4949
algorithms as algos,
5050
missing,
51+
ops,
5152
roperator,
5253
)
5354
from pandas.core.arraylike import OpsMixin
@@ -676,7 +677,11 @@ def _cmp_method(self, other, op):
676677
mask = isna(self) | isna(other)
677678
valid = ~mask
678679
result = np.zeros(len(self), dtype="bool")
679-
result[valid] = op(np.array(self)[valid], other)
680+
np_array = np.array(self)
681+
try:
682+
result[valid] = op(np_array[valid], other)
683+
except TypeError:
684+
result = ops.invalid_comparison(np_array, other, op)
680685
result = pa.array(result, type=pa.bool_())
681686
result = pc.if_else(valid, result, None)
682687
else:
@@ -1151,7 +1156,16 @@ def searchsorted(
11511156
if isinstance(value, ExtensionArray):
11521157
value = value.astype(object)
11531158
# Base class searchsorted would cast to object, which is *much* slower.
1154-
return self.to_numpy().searchsorted(value, side=side, sorter=sorter)
1159+
dtype = None
1160+
if isinstance(self.dtype, ArrowDtype):
1161+
pa_dtype = self.dtype.pyarrow_dtype
1162+
if (
1163+
pa.types.is_timestamp(pa_dtype) or pa.types.is_duration(pa_dtype)
1164+
) and pa_dtype.unit == "ns":
1165+
# np.array[datetime/timedelta].searchsorted(datetime/timedelta)
1166+
# erroneously fails when numpy type resolution is nanoseconds
1167+
dtype = object
1168+
return self.to_numpy(dtype=dtype).searchsorted(value, side=side, sorter=sorter)
11551169

11561170
def take(
11571171
self,
@@ -1302,10 +1316,15 @@ def to_numpy(
13021316

13031317
if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
13041318
result = data._maybe_convert_datelike_array()
1305-
if dtype is None or dtype.kind == "O":
1306-
result = result.to_numpy(dtype=object, na_value=na_value)
1319+
if (pa.types.is_timestamp(pa_type) and pa_type.tz is not None) or (
1320+
dtype is not None and dtype.kind == "O"
1321+
):
1322+
dtype = object
13071323
else:
1308-
result = result.to_numpy(dtype=dtype)
1324+
# GH 55997
1325+
dtype = None
1326+
na_value = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
1327+
result = result.to_numpy(dtype=dtype, na_value=na_value)
13091328
elif pa.types.is_time(pa_type) or pa.types.is_date(pa_type):
13101329
# convert to list of python datetime.time objects before
13111330
# wrapping in ndarray

pandas/core/arrays/sparse/accessor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ def from_spmatrix(cls, data, index=None, columns=None) -> DataFrame:
273273
>>> mat = scipy.sparse.eye(3, dtype=float)
274274
>>> pd.DataFrame.sparse.from_spmatrix(mat)
275275
0 1 2
276-
0 1.0 0.0 0.0
277-
1 0.0 1.0 0.0
278-
2 0.0 0.0 1.0
276+
0 1.0 0 0
277+
1 0 1.0 0
278+
2 0 0 1.0
279279
"""
280280
from pandas._libs.sparse import IntIndex
281281

pandas/io/formats/format.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
)
6262

6363
from pandas.core.arrays import (
64-
BaseMaskedArray,
6564
Categorical,
6665
DatetimeArray,
6766
ExtensionArray,
@@ -1528,10 +1527,8 @@ def _format_strings(self) -> list[str]:
15281527
if isinstance(values, Categorical):
15291528
# Categorical is special for now, so that we can preserve tzinfo
15301529
array = values._internal_get_values()
1531-
elif isinstance(values, BaseMaskedArray):
1532-
array = values.to_numpy(dtype=object)
15331530
else:
1534-
array = np.asarray(values)
1531+
array = np.asarray(values, dtype=object)
15351532

15361533
fmt_values = format_array(
15371534
array,

pandas/tests/arrays/string_/test_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
263263
other = 42
264264

265265
if op_name not in ["__eq__", "__ne__"]:
266-
with pytest.raises(TypeError, match="not supported between"):
266+
with pytest.raises(TypeError, match="Invalid comparison|not supported between"):
267267
getattr(a, op_name)(other)
268268

269269
return

pandas/tests/extension/test_arrow.py

+61-20
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
pa_version_under13p0,
4242
pa_version_under14p0,
4343
)
44+
import pandas.util._test_decorators as td
4445

4546
from pandas.core.dtypes.dtypes import (
4647
ArrowDtype,
@@ -266,6 +267,19 @@ def data_for_twos(data):
266267

267268

268269
class TestArrowArray(base.ExtensionTests):
270+
def test_compare_scalar(self, data, comparison_op):
271+
ser = pd.Series(data)
272+
self._compare_other(ser, data, comparison_op, data[0])
273+
274+
@pytest.mark.parametrize("na_action", [None, "ignore"])
275+
def test_map(self, data_missing, na_action):
276+
if data_missing.dtype.kind in "mM":
277+
result = data_missing.map(lambda x: x, na_action=na_action)
278+
expected = data_missing.to_numpy(dtype=object)
279+
tm.assert_numpy_array_equal(result, expected)
280+
else:
281+
super().test_map(data_missing, na_action)
282+
269283
def test_astype_str(self, data, request):
270284
pa_dtype = data.dtype.pyarrow_dtype
271285
if pa.types.is_binary(pa_dtype):
@@ -274,8 +288,35 @@ def test_astype_str(self, data, request):
274288
reason=f"For {pa_dtype} .astype(str) decodes.",
275289
)
276290
)
291+
elif (
292+
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
293+
) or pa.types.is_duration(pa_dtype):
294+
request.applymarker(
295+
pytest.mark.xfail(
296+
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
297+
)
298+
)
277299
super().test_astype_str(data)
278300

301+
@pytest.mark.parametrize(
302+
"nullable_string_dtype",
303+
[
304+
"string[python]",
305+
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
306+
],
307+
)
308+
def test_astype_string(self, data, nullable_string_dtype, request):
309+
pa_dtype = data.dtype.pyarrow_dtype
310+
if (
311+
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
312+
) or pa.types.is_duration(pa_dtype):
313+
request.applymarker(
314+
pytest.mark.xfail(
315+
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
316+
)
317+
)
318+
super().test_astype_string(data, nullable_string_dtype)
319+
279320
def test_from_dtype(self, data, request):
280321
pa_dtype = data.dtype.pyarrow_dtype
281322
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
@@ -1511,11 +1552,9 @@ def test_to_numpy_with_defaults(data):
15111552
result = data.to_numpy()
15121553

15131554
pa_type = data._pa_array.type
1514-
if (
1515-
pa.types.is_duration(pa_type)
1516-
or pa.types.is_timestamp(pa_type)
1517-
or pa.types.is_date(pa_type)
1518-
):
1555+
if pa.types.is_duration(pa_type) or pa.types.is_timestamp(pa_type):
1556+
pytest.skip("Tested in test_to_numpy_temporal")
1557+
elif pa.types.is_date(pa_type):
15191558
expected = np.array(list(data))
15201559
else:
15211560
expected = np.array(data._pa_array)
@@ -2937,26 +2976,28 @@ def test_groupby_series_size_returns_pa_int(data):
29372976

29382977

29392978
@pytest.mark.parametrize(
2940-
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES
2979+
"pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES, ids=repr
29412980
)
2942-
def test_to_numpy_temporal(pa_type):
2981+
@pytest.mark.parametrize("dtype", [None, object])
2982+
def test_to_numpy_temporal(pa_type, dtype):
29432983
# GH 53326
2984+
# GH 55997: Return datetime64/timedelta64 types with NaT if possible
29442985
arr = ArrowExtensionArray(pa.array([1, None], type=pa_type))
2945-
result = arr.to_numpy()
2986+
result = arr.to_numpy(dtype=dtype)
29462987
if pa.types.is_duration(pa_type):
2947-
expected = [
2948-
pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit),
2949-
pd.NA,
2950-
]
2951-
assert isinstance(result[0], pd.Timedelta)
2988+
value = pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit)
29522989
else:
2953-
expected = [
2954-
pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit),
2955-
pd.NA,
2956-
]
2957-
assert isinstance(result[0], pd.Timestamp)
2958-
expected = np.array(expected, dtype=object)
2959-
assert result[0].unit == expected[0].unit
2990+
value = pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit)
2991+
2992+
if dtype == object or (pa.types.is_timestamp(pa_type) and pa_type.tz is not None):
2993+
na = pd.NA
2994+
expected = np.array([value, na], dtype=object)
2995+
assert result[0].unit == value.unit
2996+
else:
2997+
na = pa_type.to_pandas_dtype().type("nat", pa_type.unit)
2998+
value = value.to_numpy()
2999+
expected = np.array([value, na])
3000+
assert np.datetime_data(result[0])[0] == pa_type.unit
29603001
tm.assert_numpy_array_equal(result, expected)
29613002

29623003

pandas/tests/io/formats/test_format.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1921,7 +1921,7 @@ def dtype(self):
19211921
series = Series(ExtTypeStub(), copy=False)
19221922
res = repr(series) # This line crashed before #33770 was fixed.
19231923
expected = "\n".join(
1924-
["0 [False True]", "1 [ True False]", "dtype: DtypeStub"]
1924+
["0 [False True]", "1 [True False]", "dtype: DtypeStub"]
19251925
)
19261926
assert res == expected
19271927

0 commit comments

Comments
 (0)