7
7
import math
8
8
import socket
9
9
from contextlib import contextmanager
10
+ from errno import ESPIPE
10
11
from typing import Any , BinaryIO , Dict , Iterator , List , Optional , Set , Tuple , Type , Union , cast
11
12
12
13
import boto3
@@ -178,9 +179,11 @@ def close(self) -> List[Dict[str, Union[str, int]]]:
178
179
if self .closed is True :
179
180
return []
180
181
if self ._exec is not None :
181
- for future in concurrent .futures .as_completed (self ._futures ):
182
- self ._results .append (future .result ())
183
- self ._exec .shutdown (wait = True )
182
+ try :
183
+ for future in concurrent .futures .as_completed (self ._futures ):
184
+ self ._results .append (future .result ())
185
+ finally :
186
+ self ._exec .shutdown (wait = True )
184
187
self .closed = True
185
188
return self ._sort_by_part_number (parts = self ._results )
186
189
@@ -198,7 +201,11 @@ def __init__(
198
201
boto3_session : Optional [boto3 .Session ],
199
202
newline : Optional [str ],
200
203
encoding : Optional [str ],
204
+ raw_buffer : bool ,
201
205
) -> None :
206
+ if raw_buffer is True and "w" not in mode :
207
+ raise exceptions .InvalidArgumentValue ("raw_buffer=True is only acceptable on write mode." )
208
+ self ._raw_buffer : bool = raw_buffer
202
209
self .closed : bool = False
203
210
self ._use_threads = use_threads
204
211
self ._newline : str = "\n " if newline is None else newline
@@ -242,7 +249,7 @@ def __init__(
242
249
else :
243
250
raise RuntimeError (f"Invalid mode: { self ._mode } " )
244
251
245
def __enter__(self) -> "_S3ObjectBase":
    """Enter the context manager and return this object unchanged."""
    return self
247
254
248
255
def __exit__ (self , exc_type : Any , exc_value : Any , exc_traceback : Any ) -> None :
@@ -256,6 +263,19 @@ def __del__(self) -> None:
256
263
"""Delete object tear down."""
257
264
self .close ()
258
265
266
def __next__(self) -> bytes:
    """Return the next line.

    Raises
    ------
    StopIteration
        When ``readline()`` returns an empty (falsy) value.
    """
    line = self.readline()
    if line:
        return line
    # An empty result from readline() is treated as the end of the iteration.
    raise StopIteration
272
+
273
+ next = __next__
274
+
275
def __iter__(self) -> "_S3ObjectBase":
    """Iterate over lines; the object acts as its own iterator."""
    return self
278
+
259
279
@staticmethod
def _merge_range(ranges: List[Tuple[int, bytes]]) -> bytes:
    """Concatenate byte chunks in ascending order of their start offset.

    Parameters
    ----------
    ranges
        ``(start, data)`` pairs, in any order.

    Returns
    -------
    bytes
        Every ``data`` payload joined together, ordered by ``start``.
        An empty input yields ``b""``.
    """
    # Sort on the offset alone: payloads are never compared, and pairs that
    # share an offset keep their input order because sorted() is stable.
    return b"".join(chunk for _, chunk in sorted(ranges, key=lambda r: r[0]))
@@ -372,7 +392,7 @@ def tell(self) -> int:
372
392
def seek (self , loc : int , whence : int = 0 ) -> int :
373
393
"""Set current file location."""
374
394
if self .readable () is False :
375
- raise ValueError ( "Seek only available in read mode" )
395
+ raise OSError ( ESPIPE , "Seek only available in read mode" )
376
396
if whence == 0 :
377
397
loc_tmp : int = loc
378
398
elif whence == 1 :
@@ -425,6 +445,9 @@ def flush(self, force: bool = False) -> None:
425
445
function_name = "upload_part" , s3_additional_kwargs = self ._s3_additional_kwargs
426
446
),
427
447
)
448
+ self ._buffer .seek (0 )
449
+ self ._buffer .truncate (0 )
450
+ self ._buffer .close ()
428
451
self ._buffer = io .BytesIO ()
429
452
return None
430
453
@@ -448,9 +471,9 @@ def close(self) -> None:
448
471
_logger .debug ("Closing: %s parts" , self ._parts_count )
449
472
if self ._parts_count > 0 :
450
473
self .flush (force = True )
451
- pasts : List [Dict [str , Union [str , int ]]] = self ._upload_proxy .close ()
452
- part_info : Dict [str , List [Dict [str , Any ]]] = {"Parts" : pasts }
453
- _logger .debug ("complete_multipart_upload" )
474
+ parts : List [Dict [str , Union [str , int ]]] = self ._upload_proxy .close ()
475
+ part_info : Dict [str , List [Dict [str , Any ]]] = {"Parts" : parts }
476
+ _logger .debug ("Running complete_multipart_upload... " )
454
477
_utils .try_it (
455
478
f = self ._client .complete_multipart_upload ,
456
479
ex = _S3_RETRYABLE_ERRORS ,
@@ -464,7 +487,8 @@ def close(self) -> None:
464
487
function_name = "complete_multipart_upload" , s3_additional_kwargs = self ._s3_additional_kwargs
465
488
),
466
489
)
467
- elif self ._buffer .tell () > 0 :
490
+ _logger .debug ("complete_multipart_upload done!" )
491
+ elif self ._buffer .tell () > 0 or self ._raw_buffer is True :
468
492
_logger .debug ("put_object" )
469
493
_utils .try_it (
470
494
f = self ._client .put_object ,
@@ -482,43 +506,21 @@ def close(self) -> None:
482
506
self ._buffer .seek (0 )
483
507
self ._buffer .truncate (0 )
484
508
self ._upload_proxy .close ()
509
+ self ._buffer .close ()
485
510
elif self .readable ():
486
511
self ._cache = b""
487
512
else :
488
513
raise RuntimeError (f"Invalid mode: { self ._mode } " )
489
514
self .closed = True
490
515
return None
491
516
517
def get_raw_buffer(self) -> io.BytesIO:
    """Return the underlying raw buffer.

    Returns
    -------
    io.BytesIO
        The object's internal ``_buffer``, not a copy.

    Raises
    ------
    exceptions.InvalidArgumentValue
        If the object was created with ``raw_buffer=False``.
    """
    if self._raw_buffer is False:
        raise exceptions.InvalidArgumentValue("Trying to get raw buffer with raw_buffer=False.")
    return self._buffer
492
522
493
- class _S3ObjectWriter (_S3ObjectBase ):
494
- def write (self , data : bytes ) -> int :
495
- """Write data to buffer and only upload on close() or if buffer is greater than or equal to _MIN_WRITE_BLOCK."""
496
- if self .writable () is False :
497
- raise RuntimeError ("File not in write mode." )
498
- if self .closed :
499
- raise RuntimeError ("I/O operation on closed file." )
500
- n : int = self ._buffer .write (data )
501
- self ._loc += n
502
- if self ._buffer .tell () >= _MIN_WRITE_BLOCK :
503
- self .flush ()
504
- return n
505
-
506
-
507
- class _S3ObjectReader (_S3ObjectBase ):
508
- def __next__ (self ) -> Union [bytes , str ]:
509
- """Next line."""
510
- out : Union [bytes , str , None ] = self .readline ()
511
- if not out :
512
- raise StopIteration
513
- return out
514
-
515
- next = __next__
516
-
517
- def __iter__ (self ) -> "_S3ObjectReader" :
518
- """Iterate over lines."""
519
- return self
520
-
521
- def read (self , length : int = - 1 ) -> Union [bytes , str ]:
523
+ def read (self , length : int = - 1 ) -> bytes :
522
524
"""Return cached data and fetch on demand chunks."""
523
525
if self .readable () is False :
524
526
raise ValueError ("File not in read mode." )
@@ -532,7 +534,7 @@ def read(self, length: int = -1) -> Union[bytes, str]:
532
534
self ._loc += len (out )
533
535
return out
534
536
535
- def readline (self , length : int = - 1 ) -> Union [ bytes , str ] :
537
+ def readline (self , length : int = - 1 ) -> bytes :
536
538
"""Read until the next line terminator."""
537
539
end : int = self ._loc + self ._s3_block_size
538
540
end = self ._size if end > self ._size else end
@@ -551,11 +553,25 @@ def readline(self, length: int = -1) -> Union[bytes, str]:
551
553
end = self ._size if end > self ._size else end
552
554
self ._fetch (self ._loc , end )
553
555
554
def readlines(self) -> List[bytes]:
    """Return all remaining lines as a list.

    The list is built by iterating over the object itself, so each element
    is whatever the object's iterator yields.
    """
    return list(self)
557
559
558
560
561
class _S3ObjectWriter(_S3ObjectBase):
    """Write-mode S3 object that accumulates bytes in an in-memory buffer."""

    def write(self, data: bytes) -> int:
        """Write data to the buffer; upload happens on close() or at the size threshold.

        Parameters
        ----------
        data
            Bytes to append to the internal buffer.

        Returns
        -------
        int
            The count reported by the buffer's own ``write`` call.

        Raises
        ------
        RuntimeError
            If the object is not writable, or if it has already been closed.
        """
        if self.writable() is False:
            raise RuntimeError("File not in write mode.")
        if self.closed:
            raise RuntimeError("I/O operation on closed file.")
        written: int = self._buffer.write(data)
        self._loc += written
        # Flush as soon as the buffered amount reaches _MIN_WRITE_BLOCK.
        if self._buffer.tell() >= _MIN_WRITE_BLOCK:
            self.flush()
        return written
573
+
574
+
559
575
@contextmanager
560
576
@apply_configs
561
577
def open_s3_object (
@@ -567,11 +583,12 @@ def open_s3_object(
567
583
boto3_session : Optional [boto3 .Session ] = None ,
568
584
newline : Optional [str ] = "\n " ,
569
585
encoding : Optional [str ] = "utf-8" ,
570
- ) -> Iterator [Union [_S3ObjectReader , _S3ObjectWriter , io .TextIOWrapper ]]:
586
+ raw_buffer : bool = False ,
587
+ ) -> Iterator [Union [_S3ObjectBase , _S3ObjectWriter , io .TextIOWrapper , io .BytesIO ]]:
571
588
"""Return a _S3Object or TextIOWrapper based in the received mode."""
572
- s3obj : Optional [Union [_S3ObjectReader , _S3ObjectWriter ]] = None
589
+ s3obj : Optional [Union [_S3ObjectBase , _S3ObjectWriter ]] = None
573
590
text_s3obj : Optional [io .TextIOWrapper ] = None
574
- s3_class : Union [Type [_S3ObjectReader ], Type [_S3ObjectWriter ]] = _S3ObjectWriter if "w" in mode else _S3ObjectReader
591
+ s3_class : Union [Type [_S3ObjectBase ], Type [_S3ObjectWriter ]] = _S3ObjectWriter if "w" in mode else _S3ObjectBase
575
592
try :
576
593
s3obj = s3_class (
577
594
path = path ,
@@ -582,8 +599,11 @@ def open_s3_object(
582
599
boto3_session = boto3_session ,
583
600
encoding = encoding ,
584
601
newline = newline ,
602
+ raw_buffer = raw_buffer ,
585
603
)
586
- if "b" in mode : # binary
604
+ if raw_buffer is True : # Only useful for plain io.BytesIO write
605
+ yield s3obj .get_raw_buffer ()
606
+ elif "b" in mode : # binary
587
607
yield s3obj
588
608
else : # text
589
609
text_s3obj = io .TextIOWrapper (
0 commit comments