Skip to content

Commit a31b478

Browse files
ieavesproost
authored andcommitted
ENH: pass through schema keyword in to_parquet for pyarrow (pandas-dev#30270)
1 parent bfd98d4 commit a31b478

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

doc/source/whatsnew/v1.0.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ Other enhancements
205205
(:meth:`~DataFrame.to_parquet` / :func:`read_parquet`) using the `'pyarrow'` engine
206206
now preserve those data types with pyarrow >= 1.0.0 (:issue:`20612`).
207207
- The ``partition_cols`` argument in :meth:`DataFrame.to_parquet` now accepts a string (:issue:`27117`)
208+
- :func:`to_parquet` now appropriately handles the ``schema`` argument for user defined schemas in the pyarrow engine. (:issue: `30270`)
209+
208210

209211
Build Changes
210212
^^^^^^^^^^^^^
@@ -801,7 +803,6 @@ I/O
801803
- Bug in :func:`read_json` where default encoding was not set to ``utf-8`` (:issue:`29565`)
802804
- Bug in :class:`PythonParser` where str and bytes were being mixed when dealing with the decimal field (:issue:`29650`)
803805
- :meth:`read_gbq` now accepts ``progress_bar_type`` to display progress bar while the data downloads. (:issue:`29857`)
804-
-
805806

806807
Plotting
807808
^^^^^^^^

pandas/io/parquet.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,10 @@ def write(
9191
self.validate_dataframe(df)
9292
path, _, _, _ = get_filepath_or_buffer(path, mode="wb")
9393

94-
from_pandas_kwargs: Dict[str, Any]
95-
if index is None:
96-
from_pandas_kwargs = {}
97-
else:
98-
from_pandas_kwargs = {"preserve_index": index}
94+
from_pandas_kwargs: Dict[str, Any] = {"schema": kwargs.pop("schema", None)}
95+
if index is not None:
96+
from_pandas_kwargs["preserve_index"] = index
97+
9998
table = self.api.Table.from_pandas(df, **from_pandas_kwargs)
10099
if partition_cols is not None:
101100
self.api.parquet.write_to_dataset(

pandas/tests/io/test_parquet.py

+8
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,14 @@ def test_empty_dataframe(self, pa):
517517
df = pd.DataFrame()
518518
check_round_trip(df, pa)
519519

520+
def test_write_with_schema(self, pa):
521+
import pyarrow
522+
523+
df = pd.DataFrame({"x": [0, 1]})
524+
schema = pyarrow.schema([pyarrow.field("x", type=pyarrow.bool_())])
525+
out_df = df.astype(bool)
526+
check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df)
527+
520528
@pytest.mark.skip(reason="broken test")
521529
@td.skip_if_no("pyarrow", min_version="0.15.0")
522530
def test_additional_extension_arrays(self, pa):

0 commit comments

Comments
 (0)