diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index e281e250d608e..0a2b5ed3b0789 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -42,6 +42,7 @@ Other enhancements
 - :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
 - :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
 - Added new argument ``use_nullable_dtypes`` to :func:`read_csv` to enable automatic conversion to nullable dtypes (:issue:`36712`)
+- Added new global configuration ``io.nullable_backend``; when set to ``"pyarrow"``, ``use_nullable_dtypes=True`` in :func:`read_parquet` returns pyarrow-backed dtypes (:issue:`48957`)
 - Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
 - Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
 - :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
diff --git a/pandas/core/config_init.py b/pandas/core/config_init.py
index 4bd6fb4b59caf..378f9dc80bb6d 100644
--- a/pandas/core/config_init.py
+++ b/pandas/core/config_init.py
@@ -730,6 +730,20 @@ def use_inf_as_na_cb(key) -> None:
     validator=is_one_of_factory(["auto", "sqlalchemy"]),
 )
 
+io_nullable_backend_doc = """
+: string
+    The nullable dtype implementation to return when ``use_nullable_dtypes=True``.
+    Available options: 'pandas' and 'pyarrow'; the default is 'pandas'.
+""" + +with cf.config_prefix("io.nullable_backend"): + cf.register_option( + "io_nullable_backend", + "pandas", + io_nullable_backend_doc, + validator=is_one_of_factory(["pandas", "pyarrow"]), + ) + # -------- # Plotting # --------- diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index 6b7a10b7fad63..df02a6fbca295 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -22,6 +22,7 @@ from pandas import ( DataFrame, MultiIndex, + arrays, get_option, ) from pandas.core.shared_docs import _shared_docs @@ -221,25 +222,27 @@ def read( ) -> DataFrame: kwargs["use_pandas_metadata"] = True + nullable_backend = get_option("io.nullable_backend") to_pandas_kwargs = {} if use_nullable_dtypes: import pandas as pd - mapping = { - self.api.int8(): pd.Int8Dtype(), - self.api.int16(): pd.Int16Dtype(), - self.api.int32(): pd.Int32Dtype(), - self.api.int64(): pd.Int64Dtype(), - self.api.uint8(): pd.UInt8Dtype(), - self.api.uint16(): pd.UInt16Dtype(), - self.api.uint32(): pd.UInt32Dtype(), - self.api.uint64(): pd.UInt64Dtype(), - self.api.bool_(): pd.BooleanDtype(), - self.api.string(): pd.StringDtype(), - self.api.float32(): pd.Float32Dtype(), - self.api.float64(): pd.Float64Dtype(), - } - to_pandas_kwargs["types_mapper"] = mapping.get + if nullable_backend == "pandas": + mapping = { + self.api.int8(): pd.Int8Dtype(), + self.api.int16(): pd.Int16Dtype(), + self.api.int32(): pd.Int32Dtype(), + self.api.int64(): pd.Int64Dtype(), + self.api.uint8(): pd.UInt8Dtype(), + self.api.uint16(): pd.UInt16Dtype(), + self.api.uint32(): pd.UInt32Dtype(), + self.api.uint64(): pd.UInt64Dtype(), + self.api.bool_(): pd.BooleanDtype(), + self.api.string(): pd.StringDtype(), + self.api.float32(): pd.Float32Dtype(), + self.api.float64(): pd.Float64Dtype(), + } + to_pandas_kwargs["types_mapper"] = mapping.get manager = get_option("mode.data_manager") if manager == "array": to_pandas_kwargs["split_blocks"] = True # type: ignore[assignment] @@ -251,9 +254,20 @@ def read( mode="rb", ) try: - result = self.api.parquet.read_table( + pa_table = self.api.parquet.read_table( path_or_handle, columns=columns, **kwargs - ).to_pandas(**to_pandas_kwargs) + ) + if nullable_backend == "pandas": + result = pa_table.to_pandas(**to_pandas_kwargs) + elif nullable_backend == "pyarrow": + result = DataFrame( + { + col_name: arrays.ArrowExtensionArray(pa_col) + for col_name, pa_col in zip( + pa_table.column_names, pa_table.itercolumns() + ) + } + ) if manager == "array": result = result._as_manager("array", copy=False) return result @@ -494,6 +508,13 @@ def read_parquet( .. versionadded:: 1.2.0 + The nullable dtype implementation can be configured by setting the global + ``io.nullable_backend`` configuration option to ``"pandas"`` to use + numpy-backed nullable dtypes or ``"pyarrow"`` to use pyarrow-backed + nullable dtypes (using ``pd.ArrowDtype``). + + .. versionadded:: 2.0.0 + **kwargs Any additional kwargs are passed to the engine. 
diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py
index 45b19badf48f3..9c85ab4ba4a57 100644
--- a/pandas/tests/io/test_parquet.py
+++ b/pandas/tests/io/test_parquet.py
@@ -1014,6 +1014,43 @@ def test_read_parquet_manager(self, pa, using_array_manager):
         else:
             assert isinstance(result._mgr, pd.core.internals.BlockManager)
 
+    def test_read_use_nullable_types_pyarrow_config(self, pa, df_full):
+        import pyarrow
+
+        df = df_full
+
+        # additional supported types for pyarrow
+        dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels")
+        dti = dti._with_freq(None)  # freq doesn't round-trip
+        df["datetime_tz"] = dti
+        df["bool_with_none"] = [True, None, True]
+
+        pa_table = pyarrow.Table.from_pandas(df)
+        expected = pd.DataFrame(
+            {
+                col_name: pd.arrays.ArrowExtensionArray(pa_column)
+                for col_name, pa_column in zip(
+                    pa_table.column_names, pa_table.itercolumns()
+                )
+            }
+        )
+        # pyarrow infers datetimes as us instead of ns
+        expected["datetime"] = expected["datetime"].astype("timestamp[us][pyarrow]")
+        expected["datetime_with_nat"] = expected["datetime_with_nat"].astype(
+            "timestamp[us][pyarrow]"
+        )
+        expected["datetime_tz"] = expected["datetime_tz"].astype(
+            pd.ArrowDtype(pyarrow.timestamp(unit="us", tz="Europe/Brussels"))
+        )
+
+        with pd.option_context("io.nullable_backend", "pyarrow"):
+            check_round_trip(
+                df,
+                engine=pa,
+                read_kwargs={"use_nullable_dtypes": True},
+                expected=expected,
+            )
+
 
 class TestParquetFastParquet(Base):
     def test_basic(self, fp, df_full):
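For reference, the heart of the new ``"pyarrow"`` branch (and of the ``expected`` frame in the test above) is a plain column-wise wrap of the ``pyarrow.Table`` in ``ArrowExtensionArray``. A self-contained sketch of that conversion, assuming pyarrow is installed:

    import pandas as pd
    import pyarrow

    pa_table = pyarrow.table({"a": [1, None, 3]})
    result = pd.DataFrame(
        {
            col_name: pd.arrays.ArrowExtensionArray(pa_col)
            for col_name, pa_col in zip(pa_table.column_names, pa_table.itercolumns())
        }
    )
    print(result.dtypes)  # a    int64[pyarrow]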