Skip to content

Commit 1f8cf73

Browse files
[PECO-1857] Use SSL options with HTTPS connection pool (#425)
* [PECO-1857] Use SSL options with HTTPS connection pool Signed-off-by: Levko Kravets <[email protected]> * Some cleanup Signed-off-by: Levko Kravets <[email protected]> * Resolve circular dependencies Signed-off-by: Levko Kravets <[email protected]> * Update existing tests Signed-off-by: Levko Kravets <[email protected]> * Fix MyPy issues Signed-off-by: Levko Kravets <[email protected]> * Fix `_tls_no_verify` handling Signed-off-by: Levko Kravets <[email protected]> * Add tests Signed-off-by: Levko Kravets <[email protected]> --------- Signed-off-by: Levko Kravets <[email protected]>
1 parent 2d2b3c1 commit 1f8cf73

11 files changed

+267
-159
lines changed

src/databricks/sql/auth/thrift_http_client.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import base64
22
import logging
33
import urllib.parse
4-
from typing import Dict, Union
4+
from typing import Dict, Union, Optional
55

66
import six
77
import thrift
88

9-
logger = logging.getLogger(__name__)
10-
119
import ssl
1210
import warnings
1311
from http.client import HTTPResponse
@@ -16,6 +14,9 @@
1614
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
1715
from urllib3.util import make_headers
1816
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
17+
from databricks.sql.types import SSLOptions
18+
19+
logger = logging.getLogger(__name__)
1920

2021

2122
class THttpClient(thrift.transport.THttpClient.THttpClient):
@@ -25,13 +26,12 @@ def __init__(
2526
uri_or_host,
2627
port=None,
2728
path=None,
28-
cafile=None,
29-
cert_file=None,
30-
key_file=None,
31-
ssl_context=None,
29+
ssl_options: Optional[SSLOptions] = None,
3230
max_connections: int = 1,
3331
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
3432
):
33+
self._ssl_options = ssl_options
34+
3535
if port is not None:
3636
warnings.warn(
3737
"Please use the THttpClient('http{s}://host:port/path') constructor",
@@ -48,13 +48,11 @@ def __init__(
4848
self.scheme = parsed.scheme
4949
assert self.scheme in ("http", "https")
5050
if self.scheme == "https":
51-
self.certfile = cert_file
52-
self.keyfile = key_file
53-
self.context = (
54-
ssl.create_default_context(cafile=cafile)
55-
if (cafile and not ssl_context)
56-
else ssl_context
57-
)
51+
if self._ssl_options is not None:
52+
# TODO: Not sure if those options are used anywhere - need to double-check
53+
self.certfile = self._ssl_options.tls_client_cert_file
54+
self.keyfile = self._ssl_options.tls_client_cert_key_file
55+
self.context = self._ssl_options.create_ssl_context()
5856
self.port = parsed.port
5957
self.host = parsed.hostname
6058
self.path = parsed.path
@@ -109,12 +107,23 @@ def startRetryTimer(self):
109107
def open(self):
110108

111109
# self.__pool replaces the self.__http used by the original THttpClient
110+
_pool_kwargs = {"maxsize": self.max_connections}
111+
112112
if self.scheme == "http":
113113
pool_class = HTTPConnectionPool
114114
elif self.scheme == "https":
115115
pool_class = HTTPSConnectionPool
116-
117-
_pool_kwargs = {"maxsize": self.max_connections}
116+
_pool_kwargs.update(
117+
{
118+
"cert_reqs": ssl.CERT_REQUIRED
119+
if self._ssl_options.tls_verify
120+
else ssl.CERT_NONE,
121+
"ca_certs": self._ssl_options.tls_trusted_ca_file,
122+
"cert_file": self._ssl_options.tls_client_cert_file,
123+
"key_file": self._ssl_options.tls_client_cert_key_file,
124+
"key_password": self._ssl_options.tls_client_cert_key_password,
125+
}
126+
)
118127

119128
if self.using_proxy():
120129
proxy_manager = ProxyManager(

src/databricks/sql/client.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
from databricks.sql.types import Row
38+
from databricks.sql.types import Row, SSLOptions
3939
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
4040
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
4141

@@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]:
178178
# _tls_trusted_ca_file
179179
# Set to the path of the file containing trusted CA certificates for server certificate
180180
# verification. If not provide, uses system truststore.
181-
# _tls_client_cert_file, _tls_client_cert_key_file
181+
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
182182
# Set client SSL certificate.
183+
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
183184
# _retry_stop_after_attempts_count
184185
# The maximum number of attempts during a request retry sequence (defaults to 24)
185186
# _socket_timeout
@@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]:
220221

221222
base_headers = [("User-Agent", useragent_header)]
222223

224+
self._ssl_options = SSLOptions(
225+
# Double negation is generally a bad thing, but we have to keep backward compatibility
226+
tls_verify=not kwargs.get(
227+
"_tls_no_verify", False
228+
), # by default - verify cert and host
229+
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
230+
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
231+
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
232+
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
233+
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
234+
)
235+
223236
self.thrift_backend = ThriftBackend(
224237
self.host,
225238
self.port,
226239
http_path,
227240
(http_headers or []) + base_headers,
228241
auth_provider,
242+
ssl_options=self._ssl_options,
229243
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
230244
**kwargs,
231245
)

src/databricks/sql/cloudfetch/download_manager.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22

3-
from ssl import SSLContext
43
from concurrent.futures import ThreadPoolExecutor, Future
54
from typing import List, Union
65

@@ -9,6 +8,8 @@
98
DownloadableResultSettings,
109
DownloadedFile,
1110
)
11+
from databricks.sql.types import SSLOptions
12+
1213
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1314

1415
logger = logging.getLogger(__name__)
@@ -20,7 +21,7 @@ def __init__(
2021
links: List[TSparkArrowResultLink],
2122
max_download_threads: int,
2223
lz4_compressed: bool,
23-
ssl_context: SSLContext,
24+
ssl_options: SSLOptions,
2425
):
2526
self._pending_links: List[TSparkArrowResultLink] = []
2627
for link in links:
@@ -38,7 +39,7 @@ def __init__(
3839
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
3940

4041
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
41-
self._ssl_context = ssl_context
42+
self._ssl_options = ssl_options
4243

4344
def get_next_downloaded_file(
4445
self, next_row_offset: int
@@ -95,7 +96,7 @@ def _schedule_downloads(self):
9596
handler = ResultSetDownloadHandler(
9697
settings=self._downloadable_result_settings,
9798
link=link,
98-
ssl_context=self._ssl_context,
99+
ssl_options=self._ssl_options,
99100
)
100101
task = self._thread_pool.submit(handler.run)
101102
self._download_tasks.append(task)

src/databricks/sql/cloudfetch/downloader.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
import requests
55
from requests.adapters import HTTPAdapter, Retry
6-
from ssl import SSLContext, CERT_NONE
76
import lz4.frame
87
import time
98

109
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
11-
1210
from databricks.sql.exc import Error
11+
from databricks.sql.types import SSLOptions
1312

1413
logger = logging.getLogger(__name__)
1514

@@ -66,11 +65,11 @@ def __init__(
6665
self,
6766
settings: DownloadableResultSettings,
6867
link: TSparkArrowResultLink,
69-
ssl_context: SSLContext,
68+
ssl_options: SSLOptions,
7069
):
7170
self.settings = settings
7271
self.link = link
73-
self._ssl_context = ssl_context
72+
self._ssl_options = ssl_options
7473

7574
def run(self) -> DownloadedFile:
7675
"""
@@ -95,14 +94,13 @@ def run(self) -> DownloadedFile:
9594
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
9695
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
9796

98-
ssl_verify = self._ssl_context.verify_mode != CERT_NONE
99-
10097
try:
10198
# Get the file via HTTP request
10299
response = session.get(
103100
self.link.fileLink,
104101
timeout=self.settings.download_timeout,
105-
verify=ssl_verify,
102+
verify=self._ssl_options.tls_verify,
103+
# TODO: Pass cert from `self._ssl_options`
106104
)
107105
response.raise_for_status()
108106

src/databricks/sql/thrift_backend.py

+6-37
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import time
66
import uuid
77
import threading
8-
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
98
from typing import List, Union
109

1110
import pyarrow
@@ -36,6 +35,7 @@
3635
convert_decimals_in_arrow_table,
3736
convert_column_based_set_to_arrow_table,
3837
)
38+
from databricks.sql.types import SSLOptions
3939

4040
logger = logging.getLogger(__name__)
4141

@@ -85,6 +85,7 @@ def __init__(
8585
http_path: str,
8686
http_headers,
8787
auth_provider: AuthProvider,
88+
ssl_options: SSLOptions,
8889
staging_allowed_local_path: Union[None, str, List[str]] = None,
8990
**kwargs,
9091
):
@@ -93,16 +94,6 @@ def __init__(
9394
# Tag to add to User-Agent header. For use by partners.
9495
# _username, _password
9596
# Username and password Basic authentication (no official support)
96-
# _tls_no_verify
97-
# Set to True (Boolean) to completely disable SSL verification.
98-
# _tls_verify_hostname
99-
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
100-
# _tls_trusted_ca_file
101-
# Set to the path of the file containing trusted CA certificates for server certificate
102-
# verification. If not provide, uses system truststore.
103-
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
104-
# Set client SSL certificate.
105-
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
10697
# _connection_uri
10798
# Overrides server_hostname and http_path.
10899
# RETRY/ATTEMPT POLICY
@@ -162,29 +153,7 @@ def __init__(
162153
# Cloud fetch
163154
self.max_download_threads = kwargs.get("max_download_threads", 10)
164155

165-
# Configure tls context
166-
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
167-
if kwargs.get("_tls_no_verify") is True:
168-
ssl_context.check_hostname = False
169-
ssl_context.verify_mode = CERT_NONE
170-
elif kwargs.get("_tls_verify_hostname") is False:
171-
ssl_context.check_hostname = False
172-
ssl_context.verify_mode = CERT_REQUIRED
173-
else:
174-
ssl_context.check_hostname = True
175-
ssl_context.verify_mode = CERT_REQUIRED
176-
177-
tls_client_cert_file = kwargs.get("_tls_client_cert_file")
178-
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
179-
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
180-
if tls_client_cert_file:
181-
ssl_context.load_cert_chain(
182-
certfile=tls_client_cert_file,
183-
keyfile=tls_client_cert_key_file,
184-
password=tls_client_cert_key_password,
185-
)
186-
187-
self._ssl_context = ssl_context
156+
self._ssl_options = ssl_options
188157

189158
self._auth_provider = auth_provider
190159

@@ -225,7 +194,7 @@ def __init__(
225194
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
226195
auth_provider=self._auth_provider,
227196
uri_or_host=uri,
228-
ssl_context=self._ssl_context,
197+
ssl_options=self._ssl_options,
229198
**additional_transport_args, # type: ignore
230199
)
231200

@@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
776745
max_download_threads=self.max_download_threads,
777746
lz4_compressed=lz4_compressed,
778747
description=description,
779-
ssl_context=self._ssl_context,
748+
ssl_options=self._ssl_options,
780749
)
781750
else:
782751
arrow_queue_opt = None
@@ -1008,7 +977,7 @@ def fetch_results(
1008977
max_download_threads=self.max_download_threads,
1009978
lz4_compressed=lz4_compressed,
1010979
description=description,
1011-
ssl_context=self._ssl_context,
980+
ssl_options=self._ssl_options,
1012981
)
1013982

1014983
return queue, resp.hasMoreRows

src/databricks/sql/types.py

+48
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,54 @@
1919
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
2020
import datetime
2121
import decimal
22+
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
23+
24+
25+
class SSLOptions:
26+
tls_verify: bool
27+
tls_verify_hostname: bool
28+
tls_trusted_ca_file: Optional[str]
29+
tls_client_cert_file: Optional[str]
30+
tls_client_cert_key_file: Optional[str]
31+
tls_client_cert_key_password: Optional[str]
32+
33+
def __init__(
34+
self,
35+
tls_verify: bool = True,
36+
tls_verify_hostname: bool = True,
37+
tls_trusted_ca_file: Optional[str] = None,
38+
tls_client_cert_file: Optional[str] = None,
39+
tls_client_cert_key_file: Optional[str] = None,
40+
tls_client_cert_key_password: Optional[str] = None,
41+
):
42+
self.tls_verify = tls_verify
43+
self.tls_verify_hostname = tls_verify_hostname
44+
self.tls_trusted_ca_file = tls_trusted_ca_file
45+
self.tls_client_cert_file = tls_client_cert_file
46+
self.tls_client_cert_key_file = tls_client_cert_key_file
47+
self.tls_client_cert_key_password = tls_client_cert_key_password
48+
49+
def create_ssl_context(self) -> SSLContext:
50+
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
51+
52+
if self.tls_verify is False:
53+
ssl_context.check_hostname = False
54+
ssl_context.verify_mode = CERT_NONE
55+
elif self.tls_verify_hostname is False:
56+
ssl_context.check_hostname = False
57+
ssl_context.verify_mode = CERT_REQUIRED
58+
else:
59+
ssl_context.check_hostname = True
60+
ssl_context.verify_mode = CERT_REQUIRED
61+
62+
if self.tls_client_cert_file:
63+
ssl_context.load_cert_chain(
64+
certfile=self.tls_client_cert_file,
65+
keyfile=self.tls_client_cert_key_file,
66+
password=self.tls_client_cert_key_password,
67+
)
68+
69+
return ssl_context
2270

2371

2472
class Row(tuple):

0 commit comments

Comments
 (0)