Skip to content

Commit 3e0c1da

Browse files
authored
CLN: Use type_mapper instead of manual conversion (#51766)
1 parent 9312dde commit 3e0c1da

File tree

6 files changed

+33
-56
lines changed

6 files changed

+33
-56
lines changed

pandas/io/feather_format.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
from pandas.compat._optional import import_optional_dependency
1414
from pandas.util._decorators import doc
1515

16-
from pandas import (
17-
arrays,
18-
get_option,
19-
)
16+
import pandas as pd
17+
from pandas import get_option
2018
from pandas.core.api import DataFrame
2119
from pandas.core.shared_docs import _shared_docs
2220

@@ -140,11 +138,4 @@ def read_feather(
140138
return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
141139

142140
elif dtype_backend == "pyarrow":
143-
return DataFrame(
144-
{
145-
col_name: arrays.ArrowExtensionArray(pa_col)
146-
for col_name, pa_col in zip(
147-
pa_table.column_names, pa_table.itercolumns()
148-
)
149-
}
150-
)
141+
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)

pandas/io/json/_json.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from pandas.core.dtypes.generic import ABCIndex
4343

4444
from pandas import (
45+
ArrowDtype,
4546
DataFrame,
4647
MultiIndex,
4748
Series,
@@ -963,16 +964,8 @@ def read(self) -> DataFrame | Series:
963964
pa_table = pyarrow_json.read_json(self.data)
964965
if self.use_nullable_dtypes:
965966
if get_option("mode.dtype_backend") == "pyarrow":
966-
from pandas.arrays import ArrowExtensionArray
967-
968-
return DataFrame(
969-
{
970-
col_name: ArrowExtensionArray(pa_col)
971-
for col_name, pa_col in zip(
972-
pa_table.column_names, pa_table.itercolumns()
973-
)
974-
}
975-
)
967+
return pa_table.to_pandas(types_mapper=ArrowDtype)
968+
976969
elif get_option("mode.dtype_backend") == "pandas":
977970
from pandas.io._util import _arrow_dtype_mapping
978971

pandas/io/orc.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
is_unsigned_integer_dtype,
2626
)
2727

28-
from pandas.core.arrays import ArrowExtensionArray
29-
from pandas.core.frame import DataFrame
28+
import pandas as pd
3029

3130
from pandas.io.common import (
3231
get_handle,
@@ -40,6 +39,8 @@
4039
WriteBuffer,
4140
)
4241

42+
from pandas.core.frame import DataFrame
43+
4344

4445
def read_orc(
4546
path: FilePath | ReadBuffer[bytes],
@@ -127,14 +128,7 @@ def read_orc(
127128
if use_nullable_dtypes:
128129
dtype_backend = get_option("mode.dtype_backend")
129130
if dtype_backend == "pyarrow":
130-
df = DataFrame(
131-
{
132-
col_name: ArrowExtensionArray(pa_col)
133-
for col_name, pa_col in zip(
134-
pa_table.column_names, pa_table.itercolumns()
135-
)
136-
}
137-
)
131+
df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
138132
else:
139133
from pandas.io._util import _arrow_dtype_mapping
140134

pandas/io/parquet.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from pandas.errors import AbstractMethodError
1818
from pandas.util._decorators import doc
1919

20+
import pandas as pd
2021
from pandas import (
2122
DataFrame,
2223
MultiIndex,
23-
arrays,
2424
get_option,
2525
)
2626
from pandas.core.shared_docs import _shared_docs
@@ -252,14 +252,11 @@ def read(
252252
if dtype_backend == "pandas":
253253
result = pa_table.to_pandas(**to_pandas_kwargs)
254254
elif dtype_backend == "pyarrow":
255-
result = DataFrame(
256-
{
257-
col_name: arrays.ArrowExtensionArray(pa_col)
258-
for col_name, pa_col in zip(
259-
pa_table.column_names, pa_table.itercolumns()
260-
)
261-
}
262-
)
255+
# Incompatible types in assignment (expression has type
256+
# "Type[ArrowDtype]", target has type overloaded function)
257+
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment] # noqa
258+
result = pa_table.to_pandas(**to_pandas_kwargs)
259+
263260
if manager == "array":
264261
result = result._as_manager("array", copy=False)
265262
return result

pandas/io/parsers/arrow_parser_wrapper.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from pandas.core.dtypes.inference import is_integer
88

9+
import pandas as pd
910
from pandas import (
1011
DataFrame,
11-
arrays,
1212
get_option,
1313
)
1414

@@ -156,12 +156,7 @@ def read(self) -> DataFrame:
156156
self.kwds["use_nullable_dtypes"]
157157
and get_option("mode.dtype_backend") == "pyarrow"
158158
):
159-
frame = DataFrame(
160-
{
161-
col_name: arrays.ArrowExtensionArray(pa_col)
162-
for col_name, pa_col in zip(table.column_names, table.itercolumns())
163-
}
164-
)
159+
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
165160
else:
166161
frame = table.to_pandas()
167162
return self._finalize_pandas_output(frame)

pandas/tests/io/test_parquet.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -1034,14 +1034,7 @@ def test_read_use_nullable_types_pyarrow_config(self, pa, df_full):
10341034
df["bool_with_none"] = [True, None, True]
10351035

10361036
pa_table = pyarrow.Table.from_pandas(df)
1037-
expected = pd.DataFrame(
1038-
{
1039-
col_name: pd.arrays.ArrowExtensionArray(pa_column)
1040-
for col_name, pa_column in zip(
1041-
pa_table.column_names, pa_table.itercolumns()
1042-
)
1043-
}
1044-
)
1037+
expected = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
10451038
# pyarrow infers datetimes as us instead of ns
10461039
expected["datetime"] = expected["datetime"].astype("timestamp[us][pyarrow]")
10471040
expected["datetime_with_nat"] = expected["datetime_with_nat"].astype(
@@ -1059,6 +1052,20 @@ def test_read_use_nullable_types_pyarrow_config(self, pa, df_full):
10591052
expected=expected,
10601053
)
10611054

1055+
def test_read_use_nullable_types_pyarrow_config_index(self, pa):
1056+
df = pd.DataFrame(
1057+
{"a": [1, 2]}, index=pd.Index([3, 4], name="test"), dtype="int64[pyarrow]"
1058+
)
1059+
expected = df.copy()
1060+
1061+
with pd.option_context("mode.dtype_backend", "pyarrow"):
1062+
check_round_trip(
1063+
df,
1064+
engine=pa,
1065+
read_kwargs={"use_nullable_dtypes": True},
1066+
expected=expected,
1067+
)
1068+
10621069

10631070
class TestParquetFastParquet(Base):
10641071
def test_basic(self, fp, df_full):

0 commit comments

Comments
 (0)