Skip to content

Commit a2336d6

Browse files
NiallEgan authored and
susodapop committed
Add structured type support to PySQL
This PR adds support for complex types (timestamps, decimals, structs & arrays) being returned as native Arrow types in PySQL.
1 parent 8d32245 commit a2336d6

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ def __init__(self,
7272
# _disable_pandas
7373
# In case the deserialisation through pandas causes any issues, it can be disabled with
7474
# this flag.
75+
# _use_arrow_native_complex_types
76+
# Databricks runtime (DBR) will return native Arrow types for structs, arrays and maps instead of Arrow strings
77+
# (True by default)
78+
# _use_arrow_native_decimals
79+
# Databricks runtime will return native Arrow types for decimals instead of Arrow strings
80+
# (True by default)
81+
# _use_arrow_native_timestamps
82+
# Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
83+
# (True by default)
7584

7685
self.host = server_hostname
7786
self.port = kwargs.get("_port", 443)

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
8787
raise ValueError("No valid connection settings.")
8888

8989
self._initialize_retry_args(kwargs)
90+
self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True)
91+
self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
92+
self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True)
9093

9194
# Configure tls context
9295
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
@@ -602,6 +605,13 @@ def _check_direct_results_for_error(t_spark_direct_results):
602605
def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor):
603606
assert (session_handle is not None)
604607

608+
spark_arrow_types = ttypes.TSparkArrowTypes(
609+
timestampAsArrow=self._use_arrow_native_timestamps,
610+
decimalAsArrow=self._use_arrow_native_decimals,
611+
complexTypesAsArrow=self._use_arrow_native_complex_types,
612+
# TODO: The current Arrow type used for intervals can not be deserialised in PyArrow
613+
# DBR should be changed to use month_day_nano_interval
614+
intervalTypesAsArrow=False)
605615
req = ttypes.TExecuteStatementReq(
606616
sessionHandle=session_handle,
607617
statement=operation,
@@ -613,7 +623,8 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
613623
confOverlay={
614624
# We want to receive proper Timestamp arrow types.
615625
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
616-
})
626+
},
627+
useArrowNativeTypes=spark_arrow_types)
617628
resp = self.make_request(self._client.ExecuteStatement, req)
618629
return self._handle_execute_response(resp, cursor)
619630

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,35 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
12511251
self.assertIn("Setting initial namespace not supported by the DBR version",
12521252
str(cm.exception))
12531253

1254+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
1255+
@patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response")
1256+
def test_execute_command_sets_complex_type_fields_correctly(self, mock_handle_execute_response,
1257+
tcli_service_class):
1258+
tcli_service_instance = tcli_service_class.return_value
1259+
# Iterate through each possible combination of native types (True, False and unset)
1260+
for (complex, timestamp, decimals) in itertools.product(
1261+
[True, False, None], [True, False, None], [True, False, None]):
1262+
complex_arg_types = {}
1263+
if complex is not None:
1264+
complex_arg_types["_use_arrow_native_complex_types"] = complex
1265+
if timestamp is not None:
1266+
complex_arg_types["_use_arrow_native_timestamps"] = timestamp
1267+
if decimals is not None:
1268+
complex_arg_types["_use_arrow_native_decimals"] = decimals
1269+
1270+
thrift_backend = ThriftBackend("foobar", 443, "path", [], **complex_arg_types)
1271+
thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock())
1272+
1273+
t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0]
1274+
# If the value is unset, the native type should default to True
1275+
self.assertEqual(t_execute_statement_req.useArrowNativeTypes.timestampAsArrow,
1276+
complex_arg_types.get("_use_arrow_native_timestamps", True))
1277+
self.assertEqual(t_execute_statement_req.useArrowNativeTypes.decimalAsArrow,
1278+
complex_arg_types.get("_use_arrow_native_decimals", True))
1279+
self.assertEqual(t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow,
1280+
complex_arg_types.get("_use_arrow_native_complex_types", True))
1281+
self.assertFalse(t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow)
1282+
12541283

12551284
if __name__ == '__main__':
12561285
unittest.main()

0 commit comments

Comments
 (0)