Skip to content

Commit a55cf9d

Browse files
committed
Retry GetOperationStatus if an OSError was raised during execution
Add retry_delay_default to use in this case. Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent a0d340e commit a55cf9d

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

src/databricks/sql/thrift_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
1717
from databricks.sql import *
18+
from databricks.sql.thrift_api.TCLIService.TCLIService import (
19+
Client as TCLIServiceClient,
20+
)
1821
from databricks.sql.utils import (
1922
ArrowQueue,
2023
ExecuteResponse,
@@ -39,6 +42,7 @@
3942
"_retry_delay_max": (float, 60, 5, 3600),
4043
"_retry_stop_after_attempts_count": (int, 30, 1, 60),
4144
"_retry_stop_after_attempts_duration": (float, 900, 1, 86400),
45+
"_retry_delay_default": (float, 5, 1, 60)
4246
}
4347

4448

@@ -71,6 +75,8 @@ def __init__(
7175
# _retry_delay_min (default: 1)
7276
# _retry_delay_max (default: 60)
7377
# {min,max} pre-retry delay bounds
78+
# _retry_delay_default (default: 5)
79+
# Only used when GetOperationStatus fails due to a TCP/OS Error.
7480
# _retry_stop_after_attempts_count (default: 30)
7581
# total max attempts during retry sequence
7682
# _retry_stop_after_attempts_duration (default: 900)
@@ -291,6 +297,13 @@ def attempt_request(attempt):
291297
response = method(request)
292298
logger.debug("Received response: {}".format(response))
293299
return response
300+
except OSError as err:
301+
error = err
302+
error_message = str(err)
303+
304+
gos_name = TCLIServiceClient.GetOperationStatus.__name__
305+
if method.__name__ == gos_name:
306+
retry_delay = bound_retry_delay(attempt, self._retry_delay_default)
294307
except Exception as err:
295308
error = err
296309
retry_delay = extract_retry_delay(attempt)

tests/unit/test_thrift_backend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def retry_policy_factory():
1919
"_retry_delay_max": (float, 60, None, None),
2020
"_retry_stop_after_attempts_count": (int, 30, None, None),
2121
"_retry_stop_after_attempts_duration": (float, 900, None, None),
22+
"_retry_delay_default": (float, 5, 1, 60)
2223
}
2324

2425

@@ -968,6 +969,44 @@ def test_handle_execute_response_sets_active_op_handle(self):
968969

969970
self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle)
970971

972+
@patch("thrift.transport.THttpClient.THttpClient")
973+
@patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus")
974+
@patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory)
975+
def test_make_request_will_retry_GetOperationStatus(
976+
self, mock_retry_policy, mock_GetOperationStatus, t_transport_class):
977+
978+
import thrift
979+
from databricks.sql.thrift_api.TCLIService.TCLIService import Client
980+
from databricks.sql.exc import RequestError
981+
from databricks.sql.utils import NoRetryReason
982+
983+
mock_GetOperationStatus.__name__ = "GetOperationStatus"
984+
mock_GetOperationStatus.side_effect = TimeoutError(110)
985+
986+
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class)
987+
client = Client(protocol)
988+
989+
req = ttypes.TGetOperationStatusReq(
990+
operationHandle=self.operation_handle,
991+
getProgressUpdate=False,
992+
)
993+
994+
EXPECTED_RETRIES = 2
995+
996+
thrift_backend = ThriftBackend(
997+
"foobar",
998+
443,
999+
"path", [],
1000+
_retry_stop_after_attempts_count=EXPECTED_RETRIES,
1001+
_retry_delay_default=0.1)
1002+
1003+
with self.assertRaises(RequestError) as cm:
1004+
thrift_backend.make_request(client.GetOperationStatus, req)
1005+
1006+
self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"])
1007+
self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"])
1008+
1009+
9711010
@patch("thrift.transport.THttpClient.THttpClient")
9721011
def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class):
9731012
t_transport_instance = t_transport_class.return_value

0 commit comments

Comments
 (0)