Skip to content

Commit dbf183b

Browse files
Disable SSL verification for CloudFetch links (databricks#414)
* Disable SSL verification for CloudFetch links

  Signed-off-by: Levko Kravets <[email protected]>

* Use existing `_tls_no_verify` option in CloudFetch downloader

  Signed-off-by: Levko Kravets <[email protected]>

* Update tests

  Signed-off-by: Levko Kravets <[email protected]>

---------

Signed-off-by: Levko Kravets <[email protected]>
1 parent 134b21d commit dbf183b

File tree

9 files changed

+147
-29
lines changed

9 files changed

+147
-29
lines changed

src/databricks/sql/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def read(self) -> Optional[OAuthToken]:
171171
# Which port to connect to
172172
# _skip_routing_headers:
173173
# Don't set routing headers if set to True (for use when connecting directly to server)
174+
# _tls_no_verify
175+
# Set to True (Boolean) to completely disable SSL verification.
174176
# _tls_verify_hostname
175177
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
176178
# _tls_trusted_ca_file

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from ssl import SSLContext
34
from concurrent.futures import ThreadPoolExecutor, Future
45
from typing import List, Union
56

@@ -19,6 +20,7 @@ def __init__(
1920
links: List[TSparkArrowResultLink],
2021
max_download_threads: int,
2122
lz4_compressed: bool,
23+
ssl_context: SSLContext,
2224
):
2325
self._pending_links: List[TSparkArrowResultLink] = []
2426
for link in links:
@@ -36,6 +38,7 @@ def __init__(
3638
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
3739

3840
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
41+
self._ssl_context = ssl_context
3942

4043
def get_next_downloaded_file(
4144
self, next_row_offset: int
@@ -89,7 +92,11 @@ def _schedule_downloads(self):
8992
logger.debug(
9093
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
9194
)
92-
handler = ResultSetDownloadHandler(self._downloadable_result_settings, link)
95+
handler = ResultSetDownloadHandler(
96+
settings=self._downloadable_result_settings,
97+
link=link,
98+
ssl_context=self._ssl_context,
99+
)
93100
task = self._thread_pool.submit(handler.run)
94101
self._download_tasks.append(task)
95102

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import requests
55
from requests.adapters import HTTPAdapter, Retry
6+
from ssl import SSLContext, CERT_NONE
67
import lz4.frame
78
import time
89

@@ -65,9 +66,11 @@ def __init__(
6566
self,
6667
settings: DownloadableResultSettings,
6768
link: TSparkArrowResultLink,
69+
ssl_context: SSLContext,
6870
):
6971
self.settings = settings
7072
self.link = link
73+
self._ssl_context = ssl_context
7174

7275
def run(self) -> DownloadedFile:
7376
"""
@@ -92,10 +95,14 @@ def run(self) -> DownloadedFile:
9295
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
9396
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
9497

98+
ssl_verify = self._ssl_context.verify_mode != CERT_NONE
99+
95100
try:
96101
# Get the file via HTTP request
97102
response = session.get(
98-
self.link.fileLink, timeout=self.settings.download_timeout
103+
self.link.fileLink,
104+
timeout=self.settings.download_timeout,
105+
verify=ssl_verify,
99106
)
100107
response.raise_for_status()
101108

src/databricks/sql/thrift_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def __init__(
184184
password=tls_client_cert_key_password,
185185
)
186186

187+
self._ssl_context = ssl_context
188+
187189
self._auth_provider = auth_provider
188190

189191
# Connector version 3 retry approach
@@ -223,7 +225,7 @@ def __init__(
223225
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
224226
auth_provider=self._auth_provider,
225227
uri_or_host=uri,
226-
ssl_context=ssl_context,
228+
ssl_context=self._ssl_context,
227229
**additional_transport_args, # type: ignore
228230
)
229231

@@ -774,6 +776,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
774776
max_download_threads=self.max_download_threads,
775777
lz4_compressed=lz4_compressed,
776778
description=description,
779+
ssl_context=self._ssl_context,
777780
)
778781
else:
779782
arrow_queue_opt = None
@@ -1005,6 +1008,7 @@ def fetch_results(
10051008
max_download_threads=self.max_download_threads,
10061009
lz4_compressed=lz4_compressed,
10071010
description=description,
1011+
ssl_context=self._ssl_context,
10081012
)
10091013

10101014
return queue, resp.hasMoreRows

src/databricks/sql/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from enum import Enum
1010
from typing import Any, Dict, List, Optional, Union
1111
import re
12+
from ssl import SSLContext
1213

1314
import lz4.frame
1415
import pyarrow
@@ -47,6 +48,7 @@ def build_queue(
4748
t_row_set: TRowSet,
4849
arrow_schema_bytes: bytes,
4950
max_download_threads: int,
51+
ssl_context: SSLContext,
5052
lz4_compressed: bool = True,
5153
description: Optional[List[List[Any]]] = None,
5254
) -> ResultSetQueue:
@@ -60,6 +62,7 @@ def build_queue(
6062
lz4_compressed (bool): Whether result data has been lz4 compressed.
6163
description (List[List[Any]]): Hive table schema description.
6264
max_download_threads (int): Maximum number of downloader thread pool threads.
65+
ssl_context (SSLContext): SSLContext object for CloudFetchQueue
6366
6467
Returns:
6568
ResultSetQueue
@@ -82,12 +85,13 @@ def build_queue(
8285
return ArrowQueue(converted_arrow_table, n_valid_rows)
8386
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
8487
return CloudFetchQueue(
85-
arrow_schema_bytes,
88+
schema_bytes=arrow_schema_bytes,
8689
start_row_offset=t_row_set.startRowOffset,
8790
result_links=t_row_set.resultLinks,
8891
lz4_compressed=lz4_compressed,
8992
description=description,
9093
max_download_threads=max_download_threads,
94+
ssl_context=ssl_context,
9195
)
9296
else:
9397
raise AssertionError("Row set type is not valid")
@@ -133,6 +137,7 @@ def __init__(
133137
self,
134138
schema_bytes,
135139
max_download_threads: int,
140+
ssl_context: SSLContext,
136141
start_row_offset: int = 0,
137142
result_links: Optional[List[TSparkArrowResultLink]] = None,
138143
lz4_compressed: bool = True,
@@ -155,6 +160,7 @@ def __init__(
155160
self.result_links = result_links
156161
self.lz4_compressed = lz4_compressed
157162
self.description = description
163+
self._ssl_context = ssl_context
158164

159165
logger.debug(
160166
"Initialize CloudFetch loader, row set start offset: {}, file list:".format(
@@ -169,7 +175,10 @@ def __init__(
169175
)
170176
)
171177
self.download_manager = ResultFileDownloadManager(
172-
result_links or [], self.max_download_threads, self.lz4_compressed
178+
links=result_links or [],
179+
max_download_threads=self.max_download_threads,
180+
lz4_compressed=self.lz4_compressed,
181+
ssl_context=self._ssl_context,
173182
)
174183

175184
self.table = self._create_next_table()

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def test_cancel_command_calls_the_backend(self):
361361
mock_op_handle = Mock()
362362
cursor.active_op_handle = mock_op_handle
363363
cursor.cancel()
364-
self.assertTrue(mock_thrift_backend.cancel_command.called_with(mock_op_handle))
364+
mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle)
365365

366366
@patch("databricks.sql.client.logger")
367367
def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(

0 commit comments

Comments (0)