4
4
import math
5
5
import time
6
6
import threading
7
+ import lz4 .frame
7
8
from ssl import CERT_NONE , CERT_REQUIRED , create_default_context
8
9
9
10
import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
451
452
initial_namespace = None
452
453
453
454
open_session_req = ttypes .TOpenSessionReq (
454
- client_protocol_i64 = ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V5 ,
455
+ client_protocol_i64 = ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V6 ,
455
456
client_protocol = None ,
456
457
initialNamespace = initial_namespace ,
457
458
canUseMultipleCatalogs = True ,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
507
508
)
508
509
return self .make_request (self ._client .GetOperationStatus , req )
509
510
510
- def _create_arrow_table (self , t_row_set , schema_bytes , description ):
511
+ def _create_arrow_table (self , t_row_set , lz4_compressed , schema_bytes , description ):
511
512
if t_row_set .columns is not None :
512
513
(
513
514
arrow_table ,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
520
521
arrow_table ,
521
522
num_rows ,
522
523
) = ThriftBackend ._convert_arrow_based_set_to_arrow_table (
523
- t_row_set .arrowBatches , schema_bytes
524
+ t_row_set .arrowBatches , lz4_compressed , schema_bytes
524
525
)
525
526
else :
526
527
raise OperationalError ("Unsupported TRowSet instance {}" .format (t_row_set ))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
545
546
return table
546
547
547
548
@staticmethod
548
- def _convert_arrow_based_set_to_arrow_table (arrow_batches , schema_bytes ):
549
+ def _convert_arrow_based_set_to_arrow_table (
550
+ arrow_batches , lz4_compressed , schema_bytes
551
+ ):
549
552
ba = bytearray ()
550
553
ba += schema_bytes
551
554
n_rows = 0
552
- for arrow_batch in arrow_batches :
553
- n_rows += arrow_batch .rowCount
554
- ba += arrow_batch .batch
555
+ if lz4_compressed :
556
+ for arrow_batch in arrow_batches :
557
+ n_rows += arrow_batch .rowCount
558
+ ba += lz4 .frame .decompress (arrow_batch .batch )
559
+ else :
560
+ for arrow_batch in arrow_batches :
561
+ n_rows += arrow_batch .rowCount
562
+ ba += arrow_batch .batch
555
563
arrow_table = pyarrow .ipc .open_stream (ba ).read_all ()
556
564
return arrow_table , n_rows
557
565
@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
708
716
]
709
717
)
710
718
)
711
-
712
719
direct_results = resp .directResults
713
720
has_been_closed_server_side = direct_results and direct_results .closeOperation
714
721
has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
725
732
.serialize ()
726
733
.to_pybytes ()
727
734
)
728
-
735
+ lz4_compressed = t_result_set_metadata_resp . lz4Compressed
729
736
if direct_results and direct_results .resultSet :
730
737
assert direct_results .resultSet .results .startRowOffset == 0
731
738
assert direct_results .resultSetMetadata
739
+
732
740
arrow_results , n_rows = self ._create_arrow_table (
733
- direct_results .resultSet .results , schema_bytes , description
741
+ direct_results .resultSet .results ,
742
+ lz4_compressed ,
743
+ schema_bytes ,
744
+ description ,
734
745
)
735
746
arrow_queue_opt = ArrowQueue (arrow_results , n_rows , 0 )
736
747
else :
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
740
751
status = operation_state ,
741
752
has_been_closed_server_side = has_been_closed_server_side ,
742
753
has_more_rows = has_more_rows ,
754
+ lz4_compressed = lz4_compressed ,
743
755
command_handle = resp .operationHandle ,
744
756
description = description ,
745
757
arrow_schema_bytes = schema_bytes ,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
783
795
t_spark_direct_results .closeOperation
784
796
)
785
797
786
- def execute_command (self , operation , session_handle , max_rows , max_bytes , cursor ):
798
+ def execute_command (
799
+ self , operation , session_handle , max_rows , max_bytes , lz4_compression , cursor
800
+ ):
787
801
assert session_handle is not None
788
802
789
803
spark_arrow_types = ttypes .TSparkArrowTypes (
@@ -802,7 +816,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
802
816
maxRows = max_rows , maxBytes = max_bytes
803
817
),
804
818
canReadArrowResult = True ,
805
- canDecompressLZ4Result = False ,
819
+ canDecompressLZ4Result = lz4_compression ,
806
820
canDownloadResult = False ,
807
821
confOverlay = {
808
822
# We want to receive proper Timestamp arrow types.
@@ -916,6 +930,7 @@ def fetch_results(
916
930
max_rows ,
917
931
max_bytes ,
918
932
expected_row_start_offset ,
933
+ lz4_compressed ,
919
934
arrow_schema_bytes ,
920
935
description ,
921
936
):
@@ -941,7 +956,7 @@ def fetch_results(
941
956
)
942
957
)
943
958
arrow_results , n_rows = self ._create_arrow_table (
944
- resp .results , arrow_schema_bytes , description
959
+ resp .results , lz4_compressed , arrow_schema_bytes , description
945
960
)
946
961
arrow_queue = ArrowQueue (arrow_results , n_rows )
947
962
0 commit comments