
Commit f99123c

[SC-110400] Enabling compression in Python SQL Connector (#49)
Signed-off-by: Mohit Singla <[email protected]>
Co-authored-by: Moe Derakhshani <[email protected]>
1 parent 2e681b5 commit f99123c

File tree: 9 files changed, +325 -57 lines


poetry.lock (+207 -23)

(Generated lockfile; diff not rendered by default.)

pyproject.toml (+1)
@@ -13,6 +13,7 @@ python = "^3.7.1"
 thrift = "^0.13.0"
 pandas = "^1.3.0"
 pyarrow = "^9.0.0"
+lz4 = "^4.0.2"
 requests=">2.18.1"
 oauthlib=">=3.1.0"
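
The new lz4 dependency provides the frame-format codec used below in thrift_backend.py to decompress Arrow batches. As a quick orientation, here is a minimal, standalone round trip through the lz4.frame API (illustrative only, not connector code):

    import lz4.frame

    # The LZ4 frame format is self-describing, so decompress()
    # needs no explicit size or dictionary.
    payload = b"example arrow batch bytes" * 1000
    compressed = lz4.frame.compress(payload)
    assert lz4.frame.decompress(compressed) == payload
    print(len(payload), "->", len(compressed), "bytes")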

src/databricks/sql/client.py (+4)
@@ -152,6 +152,7 @@ def read(self) -> Optional[OAuthToken]:
         self.host = server_hostname
         self.port = kwargs.get("_port", 443)
         self.disable_pandas = kwargs.get("_disable_pandas", False)
+        self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)

         auth_provider = get_python_sql_connector_auth_provider(
             server_hostname, **kwargs
@@ -318,6 +319,7 @@ def execute(
             session_handle=self.connection._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
+            lz4_compression=self.connection.lz4_compression,
             cursor=self,
         )
         self.active_result_set = ResultSet(
@@ -614,6 +616,7 @@ def __init__(
         self.has_been_closed_server_side = execute_response.has_been_closed_server_side
         self.has_more_rows = execute_response.has_more_rows
         self.buffer_size_bytes = result_buffer_size_bytes
+        self.lz4_compressed = execute_response.lz4_compressed
         self.arraysize = arraysize
         self.thrift_backend = thrift_backend
         self.description = execute_response.description
@@ -642,6 +645,7 @@ def _fill_results_buffer(self):
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             expected_row_start_offset=self._next_row_index,
+            lz4_compressed=self.lz4_compressed,
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
         )
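
On the client side, compression is now on by default and is controlled per connection through the enable_query_result_lz4_compression keyword read above. A usage sketch (hostname, HTTP path, and token are placeholders):

    from databricks import sql

    # Pass False to opt out of LZ4-compressed results, e.g. when
    # inspecting raw Arrow payloads; omit the kwarg to keep the
    # default of True.
    with sql.connect(
        server_hostname="<workspace-host>",      # placeholder
        http_path="<warehouse-http-path>",       # placeholder
        access_token="<personal-access-token>",  # placeholder
        enable_query_result_lz4_compression=False,
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1")
            print(cursor.fetchall())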

src/databricks/sql/thrift_backend.py (+28 -13)
@@ -4,6 +4,7 @@
 import math
 import time
 import threading
+import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context

 import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
             initial_namespace = None

         open_session_req = ttypes.TOpenSessionReq(
-            client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
+            client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
             client_protocol=None,
             initialNamespace=initial_namespace,
             canUseMultipleCatalogs=True,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
         )
         return self.make_request(self._client.GetOperationStatus, req)

-    def _create_arrow_table(self, t_row_set, schema_bytes, description):
+    def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
         if t_row_set.columns is not None:
             (
                 arrow_table,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
                 arrow_table,
                 num_rows,
             ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
-                t_row_set.arrowBatches, schema_bytes
+                t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
         return table

     @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
+    def _convert_arrow_based_set_to_arrow_table(
+        arrow_batches, lz4_compressed, schema_bytes
+    ):
         ba = bytearray()
         ba += schema_bytes
         n_rows = 0
-        for arrow_batch in arrow_batches:
-            n_rows += arrow_batch.rowCount
-            ba += arrow_batch.batch
+        if lz4_compressed:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += lz4.frame.decompress(arrow_batch.batch)
+        else:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += arrow_batch.batch
         arrow_table = pyarrow.ipc.open_stream(ba).read_all()
         return arrow_table, n_rows
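
The bytearray concatenation above works because an Arrow IPC stream is simply a schema message followed by record-batch messages, and each arrowBatch.batch is one serialized batch message (LZ4-framed when compression is on). A standalone sketch of the same reassembly, without Thrift or LZ4:

    import pyarrow

    # Serialize the schema and a record batch separately, mirroring
    # what the server sends (schema bytes plus arrowBatches).
    batch = pyarrow.RecordBatch.from_arrays(
        [pyarrow.array([1, 2, 3])], names=["id"]
    )
    schema_bytes = batch.schema.serialize().to_pybytes()

    ba = bytearray()
    ba += schema_bytes
    ba += batch.serialize().to_pybytes()

    # Plain concatenation reconstructs a readable IPC stream.
    table = pyarrow.ipc.open_stream(ba).read_all()
    assert table.num_rows == 3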

@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 ]
             )
         )
-
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
         has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
             .serialize()
             .to_pybytes()
         )
-
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         if direct_results and direct_results.resultSet:
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
+
             arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results, schema_bytes, description
+                direct_results.resultSet.results,
+                lz4_compressed,
+                schema_bytes,
+                description,
             )
             arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             status=operation_state,
             has_been_closed_server_side=has_been_closed_server_side,
             has_more_rows=has_more_rows,
+            lz4_compressed=lz4_compressed,
             command_handle=resp.operationHandle,
             description=description,
             arrow_schema_bytes=schema_bytes,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
             t_spark_direct_results.closeOperation
         )

-    def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor):
+    def execute_command(
+        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+    ):
         assert session_handle is not None

         spark_arrow_types = ttypes.TSparkArrowTypes(
@@ -802,7 +816,7 @@ def execute_command(
                 maxRows=max_rows, maxBytes=max_bytes
             ),
             canReadArrowResult=True,
-            canDecompressLZ4Result=False,
+            canDecompressLZ4Result=lz4_compression,
             canDownloadResult=False,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
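
Note that canDecompressLZ4Result only advertises a client capability; whether batches actually arrive compressed is decided by the server and reported back in the result-set metadata (read above as lz4_compressed). A small illustrative helper, with hypothetical names, capturing that rule:

    import lz4.frame

    def reassemble(batch_bytes, metadata_says_compressed):
        # Hypothetical helper: trust the server's metadata flag, not
        # the client's request, since a server may decline compression
        # even when the client offered it.
        if metadata_says_compressed:
            return lz4.frame.decompress(batch_bytes)
        return batch_bytes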
@@ -916,6 +930,7 @@ def fetch_results(
         max_rows,
         max_bytes,
         expected_row_start_offset,
+        lz4_compressed,
         arrow_schema_bytes,
         description,
     ):
@@ -941,7 +956,7 @@ def fetch_results(
             )
         )
         arrow_results, n_rows = self._create_arrow_table(
-            resp.results, arrow_schema_bytes, description
+            resp.results, lz4_compressed, arrow_schema_bytes, description
         )
         arrow_queue = ArrowQueue(arrow_results, n_rows)

src/databricks/sql/utils.py (+1 -1)
@@ -40,7 +40,7 @@ def remaining_rows(self) -> pyarrow.Table:

 ExecuteResponse = namedtuple(
     "ExecuteResponse",
-    "status has_been_closed_server_side has_more_rows description "
+    "status has_been_closed_server_side has_more_rows description lz4_compressed "
     "command_handle arrow_queue arrow_schema_bytes",
 )
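
Because namedtuple constructors require every field, adding lz4_compressed here breaks any ExecuteResponse(...) call that omits it, which is why the unit-test fixtures below gain lz4_compressed=Mock(). A toy illustration:

    from collections import namedtuple

    Resp = namedtuple("Resp", "status lz4_compressed")
    Resp(status="FINISHED", lz4_compressed=True)  # ok
    try:
        Resp(status="FINISHED")                   # missing the new field
    except TypeError as err:
        print(err)

namedtuple's defaults= argument (Python 3.7+) could have kept old call sites working; the commit instead updates each call site explicitly.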

tests/e2e/common/large_queries_mixin.py (+13 -9)
@@ -49,11 +49,15 @@ def test_query_with_large_wide_result_set(self):
         # This is used by PyHive tests to determine the buffer size
         self.arraysize = 1000
         with self.cursor() as cursor:
-            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
-            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
-            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
-                self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
-                self.assertEqual(len(row[1]), 36)
+            for lz4_compression in [False, True]:
+                cursor.connection.lz4_compression=lz4_compression
+                uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
+                cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
+                self.assertEqual(lz4_compression, cursor.active_result_set.lz4_compressed)
+                for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
+                    self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
+                    self.assertEqual(len(row[1]), 36)
+

     def test_query_with_large_narrow_result_set(self):
         resultSize = 300 * 1000 * 1000 # 300 MB
@@ -85,10 +89,10 @@ def test_long_running_query(self):
         start = time.time()

         cursor.execute("""SELECT count(*)
-                FROM RANGE({scale}) x
-                JOIN RANGE({scale0}) y
-                ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
-                """.format(scale=scale_factor * scale0, scale0=scale0))
+            FROM RANGE({scale}) x
+            JOIN RANGE({scale0}) y
+            ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
+            """.format(scale=scale_factor * scale0, scale0=scale0))

         n, = cursor.fetchone()
         self.assertEqual(n, 0)

tests/e2e/driver_tests.py (+14)
@@ -510,6 +510,20 @@ def test_timezone_with_timestamp(self):
             self.assertEqual(arrow_result_table.field(0).type, ts_type)
             self.assertEqual(arrow_result_value, expected.timestamp() * 1000000)

+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_can_flip_compression(self):
+        with self.cursor() as cursor:
+            cursor.execute("SELECT array(1,2,3,4)")
+            cursor.fetchall()
+            lz4_compressed = cursor.active_result_set.lz4_compressed
+            #The endpoint should support compression
+            self.assertEqual(lz4_compressed, True)
+            cursor.connection.lz4_compression=False
+            cursor.execute("SELECT array(1,2,3,4)")
+            cursor.fetchall()
+            lz4_compressed = cursor.active_result_set.lz4_compressed
+            self.assertEqual(lz4_compressed, False)
+
     def _should_have_native_complex_types(self):
         return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)

tests/unit/test_fetches.py (+3 -1)
@@ -38,6 +38,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
                 has_been_closed_server_side=True,
                 has_more_rows=False,
                 description=Mock(),
+                lz4_compressed=Mock(),
                 command_handle=None,
                 arrow_queue=arrow_queue,
                 arrow_schema_bytes=schema.serialize().to_pybytes()))
@@ -50,7 +51,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
     def make_dummy_result_set_from_batch_list(batch_list):
         batch_index = 0

-        def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
+        def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed,
                           arrow_schema_bytes, description):
             nonlocal batch_index
             results = FetchTests.make_arrow_queue(batch_list[batch_index])
@@ -71,6 +72,7 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed,
                 has_more_rows=True,
                 description=[(f'col{col_id}', 'integer', None, None, None, None, None)
                              for col_id in range(num_cols)],
+                lz4_compressed=Mock(),
                 command_handle=None,
                 arrow_queue=None,
                 arrow_schema_bytes=None))
