
Commit e09a880

Resolved merge conflicts

2 parents: 87b1251 + 680b3b6

10 files changed: +373 -84 lines

CHANGELOG.md (+2 -1)

@@ -1,9 +1,10 @@
 # Release History
 
+
 # 4.0.0
 
 - Split the connector into two separate packages: `databricks-sql-connector` and `databricks-sqlalchemy`. The `databricks-sql-connector` package contains the core functionality of the connector, while the `databricks-sqlalchemy` package contains the SQLAlchemy dialect for the connector.
-- Pyarrow dependency is now optional in `databricks-sql-connector`. Users needing arrow are supposed to explicitly install pyarrow
+- Pyarrow dependency is now optional in `databricks-sql-connector`. Users needing arrow are supposed to explicitly install pyarrow
 
 # 3.6.0 (2024-10-25)
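Since pyarrow no longer ships by default, code paths that need Arrow must guard the import, as the connector itself does (see the try/except in the thrift_backend.py hunk below). A minimal sketch of the pattern; `rows_to_arrow_table` is a hypothetical helper, not part of the connector:

```python
# Optional-dependency guard: fall back to a sentinel when pyarrow is absent.
try:
    import pyarrow
except ImportError:
    pyarrow = None

def rows_to_arrow_table(rows):
    """Hypothetical helper: convert a list of dicts to an Arrow table."""
    if pyarrow is None:
        raise RuntimeError("Arrow support requires: pip install pyarrow")
    return pyarrow.Table.from_pylist(rows)
```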

docs/parameters.md (+1)

@@ -17,6 +17,7 @@ See `examples/parameters.py` in this repository for a runnable demo.
 
 - A query executed with native parameters can contain at most 255 parameter markers
 - The maximum size of all parameterized values cannot exceed 1MB
+- For volume operations such as PUT, native parameters are not supported
 
 ## SQL Syntax
 
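For reference, a native-parameter call that these limits apply to looks like the following hedged sketch (the table name and values are placeholders):

```python
# Hedged sketch of native parameter usage; `connection` is an open
# databricks.sql connection and the table name is a placeholder.
with connection.cursor() as cursor:
    cursor.execute(
        "SELECT * FROM <catalog>.<schema>.<table> WHERE id = :id AND name = :name",
        parameters={"id": 42, "name": "alice"},
    )
    print(cursor.fetchall())
```

Each `:marker` counts toward the 255-marker limit, and the bound values together count toward the 1MB cap.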

poetry.lock (+96 -47)

Generated lockfile; diff not rendered.

pyproject.toml (+2 -2)

@@ -18,8 +18,8 @@ lz4 = "^4.0.2"
 requests = "^2.18.1"
 oauthlib = "^3.1.0"
 numpy = [
-    { version = "^1.16.6", python = ">=3.8,<3.11" },
-    { version = "^1.23.4", python = ">=3.11" },
+    { version = ">=1.16.6", python = ">=3.8,<3.11" },
+    { version = ">=1.23.4", python = ">=3.11" },
 ]
 openpyxl = "^3.0.10"
 urllib3 = ">=1.26"
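For context on the constraint change: in Poetry, a caret requirement like `^1.23.4` means `>=1.23.4,<2.0.0`, so switching to a bare `>=1.23.4` drops the implicit `<2.0.0` ceiling and lets the resolver pick numpy 2.x.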

src/databricks/sql/auth/retry.py (+16 -10)

@@ -1,4 +1,5 @@
 import logging
+import random
 import time
 import typing
 from enum import Enum

@@ -285,25 +286,30 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
         """
         retry_after = self.get_retry_after(response)
         if retry_after:
-            backoff = self.get_backoff_time()
-            proposed_wait = max(backoff, retry_after)
-            self.check_proposed_wait(proposed_wait)
-            time.sleep(proposed_wait)
-            return True
+            proposed_wait = retry_after
+        else:
+            proposed_wait = self.get_backoff_time()
 
-        return False
+        proposed_wait = min(proposed_wait, self.delay_max)
+        self.check_proposed_wait(proposed_wait)
+        time.sleep(proposed_wait)
+        return True
 
     def get_backoff_time(self) -> float:
-        """Calls urllib3's built-in get_backoff_time.
+        """
+        This method implements the exponential backoff algorithm to calculate the delay between retries.
 
         Never returns a value larger than self.delay_max
         A MaxRetryDurationError will be raised if the calculated backoff would exceed self.max_attempts_duration
 
-        Note: within urllib3, a backoff is only calculated in cases where a Retry-After header is not present
-        in the previous unsuccessful request and `self.respect_retry_after_header` is True (which is always true)
+        :return:
         """
 
-        proposed_backoff = super().get_backoff_time()
+        current_attempt = self.stop_after_attempts_count - int(self.total or 0)
+        proposed_backoff = (2**current_attempt) * self.delay_min
+        if self.backoff_jitter != 0.0:
+            proposed_backoff += random.random() * self.backoff_jitter
+
         proposed_backoff = min(proposed_backoff, self.delay_max)
         self.check_proposed_wait(proposed_backoff)
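The new `get_backoff_time` drops the `super()` call and computes the delay directly. A minimal standalone sketch of the same calculation, assuming `attempt` stands in for the attempt count the policy derives from `self.total`:

```python
import random

def backoff_time(attempt: int, delay_min: float = 1.0,
                 delay_max: float = 30.0, backoff_jitter: float = 0.0) -> float:
    """Exponential backoff as in the diff above: double per attempt, jitter, cap."""
    proposed = (2**attempt) * delay_min  # 1s, 2s, 4s, 8s, ... when delay_min=1
    if backoff_jitter != 0.0:
        proposed += random.random() * backoff_jitter  # de-synchronize retry storms
    return min(proposed, delay_max)  # never exceed the configured ceiling

# With the new defaults (delay_min=1, delay_max=30), successive attempts
# propose roughly 1, 2, 4, 8, 16, then 30 (capped) seconds.
```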

src/databricks/sql/client.py (+105)

@@ -1,3 +1,4 @@
+import time
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
 
 import pandas

@@ -47,6 +48,7 @@
 
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
+    TOperationState,
 )
 
 

@@ -437,6 +439,8 @@ def __init__(
         self.escaper = ParamEscaper()
         self.lastrowid = None
 
+        self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
+
     # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
     def __enter__(self) -> "Cursor":
         return self

@@ -803,6 +807,7 @@ def execute(
             cursor=self,
             use_cloud_fetch=self.connection.use_cloud_fetch,
             parameters=prepared_params,
+            async_op=False,
         )
         self.active_result_set = ResultSet(
             self.connection,

@@ -819,6 +824,106 @@ def execute(
 
         return self
 
+    def execute_async(
+        self,
+        operation: str,
+        parameters: Optional[TParameterCollection] = None,
+    ) -> "Cursor":
+        """
+
+        Execute a query and do not wait for it to complete and just move ahead
+
+        :param operation:
+        :param parameters:
+        :return:
+        """
+        param_approach = self._determine_parameter_approach(parameters)
+        if param_approach == ParameterApproach.NONE:
+            prepared_params = NO_NATIVE_PARAMS
+            prepared_operation = operation
+
+        elif param_approach == ParameterApproach.INLINE:
+            prepared_operation, prepared_params = self._prepare_inline_parameters(
+                operation, parameters
+            )
+        elif param_approach == ParameterApproach.NATIVE:
+            normalized_parameters = self._normalize_tparametercollection(parameters)
+            param_structure = self._determine_parameter_structure(normalized_parameters)
+            transformed_operation = transform_paramstyle(
+                operation, normalized_parameters, param_structure
+            )
+            prepared_operation, prepared_params = self._prepare_native_parameters(
+                transformed_operation, normalized_parameters, param_structure
+            )
+
+        self._check_not_closed()
+        self._close_and_clear_active_result_set()
+        self.thrift_backend.execute_command(
+            operation=prepared_operation,
+            session_handle=self.connection._session_handle,
+            max_rows=self.arraysize,
+            max_bytes=self.buffer_size_bytes,
+            lz4_compression=self.connection.lz4_compression,
+            cursor=self,
+            use_cloud_fetch=self.connection.use_cloud_fetch,
+            parameters=prepared_params,
+            async_op=True,
+        )
+
+        return self
+
+    def get_query_state(self) -> "TOperationState":
+        """
+        Get the state of the async executing query or basically poll the status of the query
+
+        :return:
+        """
+        self._check_not_closed()
+        return self.thrift_backend.get_query_state(self.active_op_handle)
+
+    def get_async_execution_result(self):
+        """
+
+        Checks for the status of the async executing query and fetches the result if the query is finished
+        Otherwise it will keep polling the status of the query till there is a Not pending state
+        :return:
+        """
+        self._check_not_closed()
+
+        def is_executing(operation_state) -> "bool":
+            return not operation_state or operation_state in [
+                ttypes.TOperationState.RUNNING_STATE,
+                ttypes.TOperationState.PENDING_STATE,
+            ]
+
+        while is_executing(self.get_query_state()):
+            # Poll after some default time
+            time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
+
+        operation_state = self.get_query_state()
+        if operation_state == ttypes.TOperationState.FINISHED_STATE:
+            execute_response = self.thrift_backend.get_execution_result(
+                self.active_op_handle, self
+            )
+            self.active_result_set = ResultSet(
+                self.connection,
+                execute_response,
+                self.thrift_backend,
+                self.buffer_size_bytes,
+                self.arraysize,
+            )
+
+            if execute_response.is_staging_operation:
+                self._handle_staging_operation(
+                    staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
+                )
+
+            return self
+        else:
+            raise Error(
+                f"get_execution_result failed with Operation status {operation_state}"
+            )
+
     def executemany(self, operation, seq_of_parameters):
         """
         Execute the operation once for every set of passed in parameters.
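Taken together, `execute_async`, `get_query_state`, and `get_async_execution_result` enable a submit-then-poll workflow. A hedged usage sketch, where the connection arguments and table name are placeholders:

```python
from databricks import sql
from databricks.sql.thrift_api.TCLIService import ttypes

with sql.connect(server_hostname="<host>", http_path="<http-path>",
                 access_token="<token>") as conn:
    with conn.cursor() as cursor:
        # Submit the query and return immediately; the operation handle
        # is stored on the cursor instead of waiting for completion.
        cursor.execute_async("SELECT COUNT(*) FROM <catalog>.<schema>.<table>")

        # Optionally check progress while doing other work.
        state = cursor.get_query_state()
        if state == ttypes.TOperationState.RUNNING_STATE:
            pass  # still executing; check back later

        # Poll every ASYNC_DEFAULT_POLLING_INTERVAL seconds until the query
        # leaves PENDING/RUNNING, then populate the result set (raises Error
        # for any terminal state other than FINISHED).
        cursor.get_async_execution_result()
        print(cursor.fetchall())
```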

src/databricks/sql/thrift_backend.py (+77 -3)

@@ -7,6 +7,8 @@
 import threading
 from typing import List, Union
 
+from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
+
 try:
     import pyarrow
 except ImportError:

@@ -64,8 +66,8 @@
 # - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
 _retry_policy = { # (type, default, min, max)
     "_retry_delay_min": (float, 1, 0.1, 60),
-    "_retry_delay_max": (float, 60, 5, 3600),
-    "_retry_stop_after_attempts_count": (int, 30, 1, 60),
+    "_retry_delay_max": (float, 30, 5, 3600),
+    "_retry_stop_after_attempts_count": (int, 5, 1, 60),
     "_retry_stop_after_attempts_duration": (float, 900, 1, 86400),
     "_retry_delay_default": (float, 5, 1, 60),
 }

@@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
             arrow_schema_bytes=schema_bytes,
         )
 
+    def get_execution_result(self, op_handle, cursor):
+
+        assert op_handle is not None
+
+        req = ttypes.TFetchResultsReq(
+            operationHandle=ttypes.TOperationHandle(
+                op_handle.operationId,
+                op_handle.operationType,
+                False,
+                op_handle.modifiedRowCount,
+            ),
+            maxRows=cursor.arraysize,
+            maxBytes=cursor.buffer_size_bytes,
+            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+            includeResultSetMetadata=True,
+        )
+
+        resp = self.make_request(self._client.FetchResults, req)
+
+        t_result_set_metadata_resp = resp.resultSetMetadata
+
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
+        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
+        has_more_rows = resp.hasMoreRows
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema
+        )
+
+        schema_bytes = (
+            t_result_set_metadata_resp.arrowSchema
+            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+            .serialize()
+            .to_pybytes()
+        )
+
+        queue = ResultSetQueueFactory.build_queue(
+            row_set_type=resp.resultSetMetadata.resultFormat,
+            t_row_set=resp.results,
+            arrow_schema_bytes=schema_bytes,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=lz4_compressed,
+            description=description,
+            ssl_options=self._ssl_options,
+        )
+
+        return ExecuteResponse(
+            arrow_queue=queue,
+            status=resp.status,
+            has_been_closed_server_side=False,
+            has_more_rows=has_more_rows,
+            lz4_compressed=lz4_compressed,
+            is_staging_operation=is_staging_operation,
+            command_handle=op_handle,
+            description=description,
+            arrow_schema_bytes=schema_bytes,
+        )
+
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
         if initial_operation_status_resp:
             self._check_command_not_in_error_or_closed_state(

@@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
             self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
         return operation_state
 
+    def get_query_state(self, op_handle) -> "TOperationState":
+        poll_resp = self._poll_for_status(op_handle)
+        operation_state = poll_resp.operationState
+        self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
+        return operation_state
+
     @staticmethod
     def _check_direct_results_for_error(t_spark_direct_results):
         if t_spark_direct_results:

@@ -817,6 +882,7 @@ def execute_command(
        cursor,
        use_cloud_fetch=True,
        parameters=[],
+       async_op=False,
     ):
         assert session_handle is not None
 

@@ -846,7 +912,11 @@
             parameters=parameters,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)
-        return self._handle_execute_response(resp, cursor)
+
+        if async_op:
+            self._handle_execute_response_async(resp, cursor)
+        else:
+            return self._handle_execute_response(resp, cursor)
 
     def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
         assert session_handle is not None

@@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor):
 
         return self._results_message_to_execute_response(resp, final_operation_state)
 
+    def _handle_execute_response_async(self, resp, cursor):
+        cursor.active_op_handle = resp.operationHandle
+        self._check_direct_results_for_error(resp.directResults)
+
     def fetch_results(
         self,
         op_handle,
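A note on the async control flow above: in async mode `execute_command` returns nothing. `_handle_execute_response_async` only records the operation handle on the cursor and checks direct results for errors; all metadata and row fetching is deferred to `get_execution_result`, which issues a single `TFetchResultsReq` with `includeResultSetMetadata=True` after the caller has observed a FINISHED state via `get_query_state`.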

tests/e2e/common/retry_test_mixins.py (+4 -4)

@@ -174,7 +174,7 @@ def test_retry_max_count_not_exceeded(self):
     def test_retry_exponential_backoff(self):
         """GIVEN the retry policy is configured for reasonable exponential backoff
         WHEN the server sends nothing but 429 responses with retry-afters
-        THEN the connector will use those retry-afters as a floor
+        THEN the connector will use those retry-afters values as delay
         """
         retry_policy = self._retry_policy.copy()
         retry_policy["_retry_delay_min"] = 1

@@ -191,10 +191,10 @@ def test_retry_exponential_backoff(self):
         assert isinstance(cm.value.args[1], MaxRetryDurationError)
 
         # With setting delay_min to 1, the expected retry delays should be:
-        # 3, 3, 4
-        # The first 2 retries are allowed, the 3rd retry puts the total duration over the limit
+        # 3, 3, 3, 3
+        # The first 3 retries are allowed, the 4th retry puts the total duration over the limit
         # of 10 seconds
-        assert mock_obj.return_value.getresponse.call_count == 3
+        assert mock_obj.return_value.getresponse.call_count == 4
         assert duration > 6
 
         # Should be less than 7, but this is a safe margin for CI/CD slowness
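The arithmetic behind the updated expectations: the duration cap in this test is 10 seconds and every 429 carries a 3-second Retry-After. Under the old `max(backoff, retry_after)` rule the exponential backoff overtook the header by the third retry (delays 3, 3, 4, totalling 10s), so only two retries fit. With the new logic each delay is exactly the 3-second Retry-After, so three retries total 9s and fit under the cap, while the proposed fourth (12s cumulative) trips MaxRetryDurationError, hence four `getresponse` calls.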
