Skip to content

Commit 9ef50e8

Browse files
authored
Default socket timeout to 15 min (#137)
Signed-off-by: Matthew Kim <[email protected]>
1 parent 5a3f83e commit 9ef50e8

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

src/databricks/sql/thrift_backend.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
DATABRICKS_REASON_HEADER = "x-databricks-reason-phrase"
3838

3939
TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString"
40+
DEFAULT_SOCKET_TIMEOUT = float(900)
4041

4142
# see Connection.__init__ for parameter descriptions.
4243
# - Min/Max avoids unsustainable configs (sane values are far more constrained)
@@ -99,8 +100,8 @@ def __init__(
99100
# _retry_stop_after_attempts_count
100101
# The maximum number of times we should retry retryable requests (defaults to 24)
101102
# _socket_timeout
102-
# The timeout in seconds for socket send, recv and connect operations. Defaults to None for
103-
# no timeout. Should be a positive float or integer.
103+
# The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
104+
# (defaults to 900, i.e. 15 minutes; pass None explicitly for no timeout)
104105

105106
port = port or 443
106107
if kwargs.get("_connection_uri"):
@@ -152,8 +153,8 @@ def __init__(
152153
ssl_context=ssl_context,
153154
)
154155

155-
timeout = kwargs.get("_socket_timeout")
156-
# setTimeout defaults to None (i.e. no timeout), and is expected in ms
156+
timeout = kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
157+
# _socket_timeout is in seconds and defaults to 15 minutes; setTimeout expects ms
157158
self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
158159

159160
self._transport.setCustomHeaders(dict(http_headers))

tests/e2e/driver_tests.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919

2020
import databricks.sql as sql
21-
from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError
21+
from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError, RequestError
2222
from tests.e2e.common.predicates import pysql_has_version, pysql_supports_arrow, compare_dbr_versions, is_thrift_v5_plus
2323
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
2424
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
@@ -460,14 +460,25 @@ def test_temp_view_fetch(self):
460460
@skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
461461
@skipIf(True, "Unclear the purpose of this test since urllib3 does not complain when timeout == 0")
462462
def test_socket_timeout(self):
463-
# We we expect to see a BlockingIO error when the socket is opened
463+
# We expect to see a BlockingIO error when the socket is opened
464464
# in non-blocking mode, since no poll is done before the read
465465
with self.assertRaises(OperationalError) as cm:
466466
with self.cursor({"_socket_timeout": 0}):
467467
pass
468468

469469
self.assertIsInstance(cm.exception.args[1], io.BlockingIOError)
470470

471+
@skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
472+
def test_socket_timeout_user_defined(self):
473+
# We expect a RequestError wrapping a TimeoutError when the socket timeout is only
474+
# 1 sec for a query that takes longer than that to process
475+
with self.assertRaises(RequestError) as cm:
476+
with self.cursor({"_socket_timeout": 1}) as cursor:
477+
query = "select * from range(10000000)"
478+
cursor.execute(query)
479+
480+
self.assertIsInstance(cm.exception.args[1], TimeoutError)
481+
471482
def test_ssp_passthrough(self):
472483
for enable_ansi in (True, False):
473484
with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor:

tests/unit/test_thrift_backend.py

+2
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ def test_socket_timeout_is_propagated(self, t_http_client_class):
217217
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000)
218218
ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0)
219219
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0)
220+
ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider())
221+
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000)
220222
ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None)
221223
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None)
222224

0 commit comments

Comments
 (0)