Skip to content

Commit 8399b87

Browse files
committed
Add compression for to_csv() and to_json()
1 parent 5ba33da commit 8399b87

File tree

7 files changed

+298
-107
lines changed

7 files changed

+298
-107
lines changed

awswrangler/s3/_fs.py

Lines changed: 64 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import math
88
import socket
99
from contextlib import contextmanager
10+
from errno import ESPIPE
1011
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
1112

1213
import boto3
@@ -178,9 +179,11 @@ def close(self) -> List[Dict[str, Union[str, int]]]:
178179
if self.closed is True:
179180
return []
180181
if self._exec is not None:
181-
for future in concurrent.futures.as_completed(self._futures):
182-
self._results.append(future.result())
183-
self._exec.shutdown(wait=True)
182+
try:
183+
for future in concurrent.futures.as_completed(self._futures):
184+
self._results.append(future.result())
185+
finally:
186+
self._exec.shutdown(wait=True)
184187
self.closed = True
185188
return self._sort_by_part_number(parts=self._results)
186189

@@ -198,7 +201,11 @@ def __init__(
198201
boto3_session: Optional[boto3.Session],
199202
newline: Optional[str],
200203
encoding: Optional[str],
204+
raw_buffer: bool,
201205
) -> None:
206+
if raw_buffer is True and "w" not in mode:
207+
raise exceptions.InvalidArgumentValue("raw_buffer=True is only acceptable on write mode.")
208+
self._raw_buffer: bool = raw_buffer
202209
self.closed: bool = False
203210
self._use_threads = use_threads
204211
self._newline: str = "\n" if newline is None else newline
@@ -242,7 +249,7 @@ def __init__(
242249
else:
243250
raise RuntimeError(f"Invalid mode: {self._mode}")
244251

245-
def __enter__(self) -> Union["_S3ObjectBase", io.TextIOWrapper]:
252+
def __enter__(self) -> Union["_S3ObjectBase"]:
246253
return self
247254

248255
def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
@@ -256,6 +263,19 @@ def __del__(self) -> None:
256263
"""Delete object tear down."""
257264
self.close()
258265

266+
def __next__(self) -> bytes:
267+
"""Next line."""
268+
out: Union[bytes, None] = self.readline()
269+
if not out:
270+
raise StopIteration
271+
return out
272+
273+
next = __next__
274+
275+
def __iter__(self) -> "_S3ObjectBase":
276+
"""Iterate over lines."""
277+
return self
278+
259279
@staticmethod
260280
def _merge_range(ranges: List[Tuple[int, bytes]]) -> bytes:
261281
return b"".join(data for start, data in sorted(ranges, key=lambda r: r[0]))
@@ -372,7 +392,7 @@ def tell(self) -> int:
372392
def seek(self, loc: int, whence: int = 0) -> int:
373393
"""Set current file location."""
374394
if self.readable() is False:
375-
raise ValueError("Seek only available in read mode")
395+
raise OSError(ESPIPE, "Seek only available in read mode")
376396
if whence == 0:
377397
loc_tmp: int = loc
378398
elif whence == 1:
@@ -425,6 +445,9 @@ def flush(self, force: bool = False) -> None:
425445
function_name="upload_part", s3_additional_kwargs=self._s3_additional_kwargs
426446
),
427447
)
448+
self._buffer.seek(0)
449+
self._buffer.truncate(0)
450+
self._buffer.close()
428451
self._buffer = io.BytesIO()
429452
return None
430453

@@ -448,9 +471,9 @@ def close(self) -> None:
448471
_logger.debug("Closing: %s parts", self._parts_count)
449472
if self._parts_count > 0:
450473
self.flush(force=True)
451-
pasts: List[Dict[str, Union[str, int]]] = self._upload_proxy.close()
452-
part_info: Dict[str, List[Dict[str, Any]]] = {"Parts": pasts}
453-
_logger.debug("complete_multipart_upload")
474+
parts: List[Dict[str, Union[str, int]]] = self._upload_proxy.close()
475+
part_info: Dict[str, List[Dict[str, Any]]] = {"Parts": parts}
476+
_logger.debug("Running complete_multipart_upload...")
454477
_utils.try_it(
455478
f=self._client.complete_multipart_upload,
456479
ex=_S3_RETRYABLE_ERRORS,
@@ -464,7 +487,8 @@ def close(self) -> None:
464487
function_name="complete_multipart_upload", s3_additional_kwargs=self._s3_additional_kwargs
465488
),
466489
)
467-
elif self._buffer.tell() > 0:
490+
_logger.debug("complete_multipart_upload done!")
491+
elif self._buffer.tell() > 0 or self._raw_buffer is True:
468492
_logger.debug("put_object")
469493
_utils.try_it(
470494
f=self._client.put_object,
@@ -482,43 +506,21 @@ def close(self) -> None:
482506
self._buffer.seek(0)
483507
self._buffer.truncate(0)
484508
self._upload_proxy.close()
509+
self._buffer.close()
485510
elif self.readable():
486511
self._cache = b""
487512
else:
488513
raise RuntimeError(f"Invalid mode: {self._mode}")
489514
self.closed = True
490515
return None
491516

517+
def get_raw_buffer(self) -> io.BytesIO:
518+
"""Return the Raw Buffer if it is possible."""
519+
if self._raw_buffer is False:
520+
raise exceptions.InvalidArgumentValue("Trying to get raw buffer with raw_buffer=False.")
521+
return self._buffer
492522

493-
class _S3ObjectWriter(_S3ObjectBase):
494-
def write(self, data: bytes) -> int:
495-
"""Write data to buffer and only upload on close() or if buffer is greater than or equal to _MIN_WRITE_BLOCK."""
496-
if self.writable() is False:
497-
raise RuntimeError("File not in write mode.")
498-
if self.closed:
499-
raise RuntimeError("I/O operation on closed file.")
500-
n: int = self._buffer.write(data)
501-
self._loc += n
502-
if self._buffer.tell() >= _MIN_WRITE_BLOCK:
503-
self.flush()
504-
return n
505-
506-
507-
class _S3ObjectReader(_S3ObjectBase):
508-
def __next__(self) -> Union[bytes, str]:
509-
"""Next line."""
510-
out: Union[bytes, str, None] = self.readline()
511-
if not out:
512-
raise StopIteration
513-
return out
514-
515-
next = __next__
516-
517-
def __iter__(self) -> "_S3ObjectReader":
518-
"""Iterate over lines."""
519-
return self
520-
521-
def read(self, length: int = -1) -> Union[bytes, str]:
523+
def read(self, length: int = -1) -> bytes:
522524
"""Return cached data and fetch on demand chunks."""
523525
if self.readable() is False:
524526
raise ValueError("File not in read mode.")
@@ -532,7 +534,7 @@ def read(self, length: int = -1) -> Union[bytes, str]:
532534
self._loc += len(out)
533535
return out
534536

535-
def readline(self, length: int = -1) -> Union[bytes, str]:
537+
def readline(self, length: int = -1) -> bytes:
536538
"""Read until the next line terminator."""
537539
end: int = self._loc + self._s3_block_size
538540
end = self._size if end > self._size else end
@@ -551,11 +553,25 @@ def readline(self, length: int = -1) -> Union[bytes, str]:
551553
end = self._size if end > self._size else end
552554
self._fetch(self._loc, end)
553555

554-
def readlines(self) -> List[Union[bytes, str]]:
556+
def readlines(self) -> List[bytes]:
555557
"""Return all lines as list."""
556558
return list(self)
557559

558560

561+
class _S3ObjectWriter(_S3ObjectBase):
562+
def write(self, data: bytes) -> int:
563+
"""Write data to buffer and only upload on close() or if buffer is greater than or equal to _MIN_WRITE_BLOCK."""
564+
if self.writable() is False:
565+
raise RuntimeError("File not in write mode.")
566+
if self.closed:
567+
raise RuntimeError("I/O operation on closed file.")
568+
n: int = self._buffer.write(data)
569+
self._loc += n
570+
if self._buffer.tell() >= _MIN_WRITE_BLOCK:
571+
self.flush()
572+
return n
573+
574+
559575
@contextmanager
560576
@apply_configs
561577
def open_s3_object(
@@ -567,11 +583,12 @@ def open_s3_object(
567583
boto3_session: Optional[boto3.Session] = None,
568584
newline: Optional[str] = "\n",
569585
encoding: Optional[str] = "utf-8",
570-
) -> Iterator[Union[_S3ObjectReader, _S3ObjectWriter, io.TextIOWrapper]]:
586+
raw_buffer: bool = False,
587+
) -> Iterator[Union[_S3ObjectBase, _S3ObjectWriter, io.TextIOWrapper, io.BytesIO]]:
571588
"""Return a _S3Object or TextIOWrapper based in the received mode."""
572-
s3obj: Optional[Union[_S3ObjectReader, _S3ObjectWriter]] = None
589+
s3obj: Optional[Union[_S3ObjectBase, _S3ObjectWriter]] = None
573590
text_s3obj: Optional[io.TextIOWrapper] = None
574-
s3_class: Union[Type[_S3ObjectReader], Type[_S3ObjectWriter]] = _S3ObjectWriter if "w" in mode else _S3ObjectReader
591+
s3_class: Union[Type[_S3ObjectBase], Type[_S3ObjectWriter]] = _S3ObjectWriter if "w" in mode else _S3ObjectBase
575592
try:
576593
s3obj = s3_class(
577594
path=path,
@@ -582,8 +599,11 @@ def open_s3_object(
582599
boto3_session=boto3_session,
583600
encoding=encoding,
584601
newline=newline,
602+
raw_buffer=raw_buffer,
585603
)
586-
if "b" in mode: # binary
604+
if raw_buffer is True: # Only useful for plain io.BytesIO write
605+
yield s3obj.get_raw_buffer()
606+
elif "b" in mode: # binary
587607
yield s3obj
588608
else: # text
589609
text_s3obj = io.TextIOWrapper(

awswrangler/s3/_write.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99

1010
_logger: logging.Logger = logging.getLogger(__name__)
1111

12-
_COMPRESSION_2_EXT: Dict[Optional[str], str] = {None: "", "gzip": ".gz", "snappy": ".snappy"}
12+
_COMPRESSION_2_EXT: Dict[Optional[str], str] = {
13+
None: "",
14+
"gzip": ".gz",
15+
"snappy": ".snappy",
16+
"bz2": ".bz2",
17+
"xz": ".xz",
18+
"zip": ".zip",
19+
}
1320

1421

1522
def _extract_dtypes_from_table_input(table_input: Dict[str, Any]) -> Dict[str, str]:

0 commit comments

Comments
 (0)