Skip to content

Commit feda987

Browse files
pitrouwesm
authored andcommitted
ARROW-9333: [Python] Expose more IPC options
Also make some optional arguments keyword-only. Closes #7730 from pitrou/ARROW-9333-py-ipc-options Authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Wes McKinney <[email protected]>
1 parent 9d2079c commit feda987

File tree

9 files changed

+174
-57
lines changed

9 files changed

+174
-57
lines changed

cpp/src/arrow/ipc/options.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,9 @@ struct ARROW_EXPORT IpcWriteOptions {
5656
/// \brief The memory pool to use for allocations made during IPC writing
5757
MemoryPool* memory_pool = default_memory_pool();
5858

59-
/// \brief EXPERIMENTAL: Codec to use for compressing and decompressing
60-
/// record batch body buffers. This is not part of the Arrow IPC protocol and
61-
/// only for internal use (e.g. Feather files). May only be LZ4_FRAME and
62-
/// ZSTD
59+
/// \brief Compression codec to use for record batch body buffers
60+
///
61+
/// May only be UNCOMPRESSED, LZ4_FRAME and ZSTD.
6362
Compression::type compression = Compression::UNCOMPRESSED;
6463
int compression_level = Compression::kUseDefaultCompressionLevel;
6564

python/pyarrow/_flight.pyx

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,8 @@ def _munge_grpc_python_error(message):
9797

9898

9999
cdef IpcWriteOptions _get_options(options):
100-
cdef IpcWriteOptions write_options = \
101-
<IpcWriteOptions> _get_legacy_format_default(
102-
use_legacy_format=None, options=options)
103-
return write_options
100+
return <IpcWriteOptions> _get_legacy_format_default(
101+
use_legacy_format=None, options=options)
104102

105103

106104
cdef class FlightCallOptions:

python/pyarrow/includes/libarrow.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
13291329
c_bool write_legacy_ipc_format
13301330
CMemoryPool* memory_pool
13311331
CMetadataVersion metadata_version
1332+
CCompressionType compression
1333+
c_bool use_threads
13321334

13331335
@staticmethod
13341336
CIpcWriteOptions Defaults()

python/pyarrow/io.pxi

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,24 +1539,43 @@ def _detect_compression(path):
15391539

15401540
cdef CCompressionType _ensure_compression(str name) except *:
15411541
uppercase = name.upper()
1542-
if uppercase == 'GZIP':
1543-
return CCompressionType_GZIP
1544-
elif uppercase == 'BZ2':
1542+
if uppercase == 'BZ2':
15451543
return CCompressionType_BZ2
1544+
elif uppercase == 'GZIP':
1545+
return CCompressionType_GZIP
15461546
elif uppercase == 'BROTLI':
15471547
return CCompressionType_BROTLI
15481548
elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME':
15491549
return CCompressionType_LZ4_FRAME
15501550
elif uppercase == 'LZ4_RAW':
15511551
return CCompressionType_LZ4
1552-
elif uppercase == 'ZSTD':
1553-
return CCompressionType_ZSTD
15541552
elif uppercase == 'SNAPPY':
15551553
return CCompressionType_SNAPPY
1554+
elif uppercase == 'ZSTD':
1555+
return CCompressionType_ZSTD
15561556
else:
15571557
raise ValueError('Invalid value for compression: {!r}'.format(name))
15581558

15591559

1560+
cdef str _compression_name(CCompressionType ctype):
1561+
if ctype == CCompressionType_GZIP:
1562+
return 'gzip'
1563+
elif ctype == CCompressionType_BROTLI:
1564+
return 'brotli'
1565+
elif ctype == CCompressionType_BZ2:
1566+
return 'bz2'
1567+
elif ctype == CCompressionType_LZ4_FRAME:
1568+
return 'lz4'
1569+
elif ctype == CCompressionType_LZ4:
1570+
return 'lz4_raw'
1571+
elif ctype == CCompressionType_SNAPPY:
1572+
return 'snappy'
1573+
elif ctype == CCompressionType_ZSTD:
1574+
return 'zstd'
1575+
else:
1576+
raise RuntimeError('Unexpected CCompressionType value')
1577+
1578+
15601579
cdef class Codec:
15611580
"""
15621581
Compression codec.

python/pyarrow/ipc.pxi

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,70 @@ cdef class IpcWriteOptions:
5050
5151
Parameters
5252
----------
53+
metadata_version : MetadataVersion, default MetadataVersion.V5
54+
The metadata version to write. V5 is the current and latest,
55+
V4 is the pre-1.0 metadata version (with incompatible Union layout).
5356
use_legacy_format : bool, default False
5457
Whether to use the pre-Arrow 0.15 IPC format.
55-
metadata_version : MetadataVersion, default MetadataVersion.V5
56-
The metadata version to write.
58+
compression: str or None
59+
If not None, compression codec to use for record batch buffers.
60+
May only be "lz4", "zstd" or None.
61+
use_threads: bool
62+
Whether to use the global CPU thread pool to parallelize any
63+
computational tasks like compression.
5764
"""
65+
__slots__ = ()
5866

5967
# cdef block is in lib.pxd
6068

61-
def __init__(self, use_legacy_format=False,
62-
metadata_version=MetadataVersion.V5):
69+
def __init__(self, *, metadata_version=MetadataVersion.V5,
70+
use_legacy_format=False, compression=None,
71+
bint use_threads=True):
6372
self.c_options = CIpcWriteOptions.Defaults()
64-
self.c_options.write_legacy_ipc_format = use_legacy_format
65-
self.c_options.metadata_version = \
66-
_unwrap_metadata_version(metadata_version)
73+
self.use_legacy_format = use_legacy_format
74+
self.metadata_version = metadata_version
75+
if compression is not None:
76+
self.compression = compression
77+
self.use_threads = use_threads
6778

6879
@property
6980
def use_legacy_format(self):
7081
return self.c_options.write_legacy_ipc_format
7182

83+
@use_legacy_format.setter
84+
def use_legacy_format(self, bint value):
85+
self.c_options.write_legacy_ipc_format = value
86+
7287
@property
7388
def metadata_version(self):
7489
return _wrap_metadata_version(self.c_options.metadata_version)
7590

91+
@metadata_version.setter
92+
def metadata_version(self, value):
93+
self.c_options.metadata_version = _unwrap_metadata_version(value)
94+
95+
@property
96+
def compression(self):
97+
if self.c_options.compression == CCompressionType_UNCOMPRESSED:
98+
return None
99+
else:
100+
return _compression_name(self.c_options.compression)
101+
102+
@compression.setter
103+
def compression(self, value):
104+
if value is None:
105+
self.c_options.compression = CCompressionType_UNCOMPRESSED
106+
else:
107+
self.c_options.compression = _ensure_compression(value)
108+
109+
@property
110+
def use_threads(self):
111+
return self.c_options.use_threads
112+
113+
@use_threads.setter
114+
def use_threads(self, bint value):
115+
self.c_options.use_threads = value
116+
76117

77118
cdef class Message:
78119
"""

python/pyarrow/ipc.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class RecordBatchStreamWriter(lib._RecordBatchStreamWriter):
9292
9393
{}""".format(_ipc_writer_class_doc)
9494

95-
def __init__(self, sink, schema, use_legacy_format=None, options=None):
95+
def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
9696
options = _get_legacy_format_default(use_legacy_format, options)
9797
self._open(sink, schema, options=options)
9898

@@ -120,7 +120,7 @@ class RecordBatchFileWriter(lib._RecordBatchFileWriter):
120120
121121
{}""".format(_ipc_writer_class_doc)
122122

123-
def __init__(self, sink, schema, use_legacy_format=None, options=None):
123+
def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
124124
options = _get_legacy_format_default(use_legacy_format, options)
125125
self._open(sink, schema, options=options)
126126

@@ -130,6 +130,9 @@ def _get_legacy_format_default(use_legacy_format, options):
130130
raise ValueError(
131131
"Can provide at most one of options and use_legacy_format")
132132
elif options:
133+
if not isinstance(options, IpcWriteOptions):
134+
raise TypeError("expected IpcWriteOptions, got {}"
135+
.format(type(options)))
133136
return options
134137

135138
metadata_version = MetadataVersion.V5
@@ -142,7 +145,7 @@ def _get_legacy_format_default(use_legacy_format, options):
142145
metadata_version=metadata_version)
143146

144147

145-
def new_stream(sink, schema, use_legacy_format=None, options=None):
148+
def new_stream(sink, schema, *, use_legacy_format=None, options=None):
146149
return RecordBatchStreamWriter(sink, schema,
147150
use_legacy_format=use_legacy_format,
148151
options=options)
@@ -170,7 +173,7 @@ def open_stream(source):
170173
return RecordBatchStreamReader(source)
171174

172175

173-
def new_file(sink, schema, use_legacy_format=None, options=None):
176+
def new_file(sink, schema, *, use_legacy_format=None, options=None):
174177
return RecordBatchFileWriter(sink, schema,
175178
use_legacy_format=use_legacy_format,
176179
options=options)
@@ -201,7 +204,7 @@ def open_file(source, footer_offset=None):
201204
return RecordBatchFileReader(source, footer_offset=footer_offset)
202205

203206

204-
def serialize_pandas(df, nthreads=None, preserve_index=None):
207+
def serialize_pandas(df, *, nthreads=None, preserve_index=None):
205208
"""
206209
Serialize a pandas DataFrame into a buffer protocol compatible object.
207210
@@ -229,7 +232,7 @@ def serialize_pandas(df, nthreads=None, preserve_index=None):
229232
return sink.getvalue()
230233

231234

232-
def deserialize_pandas(buf, use_threads=True):
235+
def deserialize_pandas(buf, *, use_threads=True):
233236
"""Deserialize a buffer protocol compatible object into a pandas DataFrame.
234237
235238
Parameters

python/pyarrow/tests/test_flight.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,12 @@ def test_flight_do_get_ints():
706706
data = client.do_get(flight.Ticket(b'ints')).read_all()
707707
assert data.equals(table)
708708

709+
with pytest.raises(flight.FlightServerError,
710+
match="expected IpcWriteOptions, got <class 'int'>"):
711+
with ConstantFlightServer(options=42) as server:
712+
client = flight.connect(('localhost', server.port))
713+
data = client.do_get(flight.Ticket(b'ints')).read_all()
714+
709715

710716
@pytest.mark.pandas
711717
def test_do_get_ints_pandas():

python/pyarrow/tests/test_ipc.py

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525

2626
import pyarrow as pa
27+
from pyarrow.tests.util import changed_environ
2728

2829

2930
try:
@@ -315,7 +316,45 @@ def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
315316
reader.read_next_batch()
316317

317318

318-
def test_options_legacy_exclusive(stream_fixture):
319+
def test_write_options():
320+
options = pa.ipc.IpcWriteOptions()
321+
assert options.use_legacy_format is False
322+
assert options.metadata_version == pa.ipc.MetadataVersion.V5
323+
324+
options.use_legacy_format = True
325+
assert options.use_legacy_format is True
326+
327+
options.metadata_version = pa.ipc.MetadataVersion.V4
328+
assert options.metadata_version == pa.ipc.MetadataVersion.V4
329+
for value in ('V5', 42):
330+
with pytest.raises((TypeError, ValueError)):
331+
options.metadata_version = value
332+
333+
assert options.compression is None
334+
for value in ['lz4', 'zstd']:
335+
options.compression = value
336+
assert options.compression == value
337+
options.compression = value.upper()
338+
assert options.compression == value
339+
options.compression = None
340+
assert options.compression is None
341+
342+
assert options.use_threads is True
343+
options.use_threads = False
344+
assert options.use_threads is False
345+
346+
options = pa.ipc.IpcWriteOptions(
347+
metadata_version=pa.ipc.MetadataVersion.V4,
348+
use_legacy_format=True,
349+
compression='lz4',
350+
use_threads=False)
351+
assert options.metadata_version == pa.ipc.MetadataVersion.V4
352+
assert options.use_legacy_format is True
353+
assert options.compression == 'lz4'
354+
assert options.use_threads is False
355+
356+
357+
def test_write_options_legacy_exclusive(stream_fixture):
319358
with pytest.raises(
320359
ValueError,
321360
match="provide at most one of options and use_legacy_format"):
@@ -365,36 +404,30 @@ def test_envvar_set_legacy_ipc_format():
365404
assert not writer._use_legacy_format
366405
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
367406

368-
import os
369-
370-
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
371-
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
372-
assert writer._use_legacy_format
373-
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
374-
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
375-
assert writer._use_legacy_format
376-
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
377-
del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
378-
379-
os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
380-
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
381-
assert not writer._use_legacy_format
382-
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
383-
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
384-
assert not writer._use_legacy_format
385-
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
386-
del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
387-
388-
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
389-
os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
390-
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
391-
assert writer._use_legacy_format
392-
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
393-
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
394-
assert writer._use_legacy_format
395-
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
396-
del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
397-
del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
407+
with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
408+
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
409+
assert writer._use_legacy_format
410+
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
411+
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
412+
assert writer._use_legacy_format
413+
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
414+
415+
with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
416+
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
417+
assert not writer._use_legacy_format
418+
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
419+
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
420+
assert not writer._use_legacy_format
421+
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
422+
423+
with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
424+
with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
425+
writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
426+
assert writer._use_legacy_format
427+
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
428+
writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
429+
assert writer._use_legacy_format
430+
assert writer._metadata_version == pa.ipc.MetadataVersion.V4
398431

399432

400433
def test_stream_read_all(stream_fixture):

python/pyarrow/tests/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,19 @@ def invoke_script(script_name, *args):
194194
cmd.extend(args)
195195

196196
subprocess.check_call(cmd, env=subprocess_env)
197+
198+
199+
@contextlib.contextmanager
200+
def changed_environ(name, value):
201+
"""
202+
Temporarily set environment variable *name* to *value*.
203+
"""
204+
orig_value = os.environ.get(name)
205+
os.environ[name] = value
206+
try:
207+
yield
208+
finally:
209+
if orig_value is None:
210+
del os.environ[name]
211+
else:
212+
os.environ[name] = orig_value

0 commit comments

Comments
 (0)