Skip to content

Commit 0f06e4c

Browse files
Backport PR #53651 on branch 2.0.x (BUG: convert_dtype(dtype_backend=nullable_numpy) with ArrowDtype) (#53749)
* Backport PR #53651: BUG: convert_dtype(dtype_backend=nullable_numpy) with ArrowDtype * Move import --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 0e28257 commit 0f06e4c

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

doc/source/whatsnew/v2.0.3.rst

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Fixed regressions
2121

2222
Bug fixes
2323
~~~~~~~~~
24+
- Bug in :func:`DataFrame.convert_dtypes` and :func:`Series.convert_dtypes` when trying to convert :class:`ArrowDtype` with ``dtype_backend="numpy_nullable"`` (:issue:`53648`)
2425
- Bug in :func:`RangeIndex.union` when using ``sort=True`` with another :class:`RangeIndex` (:issue:`53490`)
2526
- Bug in :func:`Series.reindex` when expanding a non-nanosecond datetime or timedelta :class:`Series` would not fill with ``NaT`` correctly (:issue:`53497`)
2627
- Bug in :func:`read_csv` when defining ``dtype`` with ``bool[pyarrow]`` for the ``"c"`` and ``"python"`` engines (:issue:`53390`)

pandas/core/dtypes/cast.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@
9696
notna,
9797
)
9898

99+
from pandas.io._util import _arrow_dtype_mapping
100+
99101
if TYPE_CHECKING:
100102
from pandas import Index
101103
from pandas.core.arrays import (
@@ -1046,6 +1048,8 @@ def convert_dtypes(
10461048
"""
10471049
inferred_dtype: str | DtypeObj
10481050

1051+
from pandas.core.arrays.arrow.dtype import ArrowDtype
1052+
10491053
if (
10501054
convert_string or convert_integer or convert_boolean or convert_floating
10511055
) and isinstance(input_array, np.ndarray):
@@ -1128,7 +1132,6 @@ def convert_dtypes(
11281132

11291133
if dtype_backend == "pyarrow":
11301134
from pandas.core.arrays.arrow.array import to_pyarrow_type
1131-
from pandas.core.arrays.arrow.dtype import ArrowDtype
11321135
from pandas.core.arrays.string_ import StringDtype
11331136

11341137
assert not isinstance(inferred_dtype, str)
@@ -1156,6 +1159,9 @@ def convert_dtypes(
11561159
pa_type = to_pyarrow_type(base_dtype)
11571160
if pa_type is not None:
11581161
inferred_dtype = ArrowDtype(pa_type)
1162+
elif dtype_backend == "numpy_nullable" and isinstance(inferred_dtype, ArrowDtype):
1163+
# GH 53648
1164+
inferred_dtype = _arrow_dtype_mapping()[inferred_dtype.pyarrow_dtype]
11591165

11601166
# error: Incompatible return value type (got "Union[str, Union[dtype[Any],
11611167
# ExtensionDtype]]", expected "Union[dtype[Any], ExtensionDtype]")

pandas/tests/frame/methods/test_convert_dtypes.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_pyarrow_engine_lines_false(self):
146146
with pytest.raises(ValueError, match=msg):
147147
df.convert_dtypes(dtype_backend="numpy")
148148

149-
def test_pyarrow_backend_no_convesion(self):
149+
def test_pyarrow_backend_no_conversion(self):
150150
# GH#52872
151151
pytest.importorskip("pyarrow")
152152
df = pd.DataFrame({"a": [1, 2], "b": 1.5, "c": True, "d": "x"})
@@ -159,3 +159,11 @@ def test_pyarrow_backend_no_convesion(self):
159159
dtype_backend="pyarrow",
160160
)
161161
tm.assert_frame_equal(result, expected)
162+
163+
def test_convert_dtypes_pyarrow_to_np_nullable(self):
164+
# GH 53648
165+
pytest.importorskip("pyarrow")
166+
ser = pd.DataFrame(range(2), dtype="int32[pyarrow]")
167+
result = ser.convert_dtypes(dtype_backend="numpy_nullable")
168+
expected = pd.DataFrame(range(2), dtype="Int32")
169+
tm.assert_frame_equal(result, expected)

pandas/tests/series/methods/test_convert_dtypes.py

+8
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,11 @@ def test_convert_dtype_object_with_na_float(self, infer_objects, dtype):
240240
result = ser.convert_dtypes(infer_objects=infer_objects)
241241
expected = pd.Series([1.5, pd.NA], dtype=dtype)
242242
tm.assert_series_equal(result, expected)
243+
244+
def test_convert_dtypes_pyarrow_to_np_nullable(self):
245+
# GH 53648
246+
pytest.importorskip("pyarrow")
247+
ser = pd.Series(range(2), dtype="int32[pyarrow]")
248+
result = ser.convert_dtypes(dtype_backend="numpy_nullable")
249+
expected = pd.Series(range(2), dtype="Int32")
250+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)