Skip to content

Commit 6f224b3

Browse files
committed
Resolve circular dependencies
Signed-off-by: Levko Kravets <[email protected]>
1 parent a7be4cc commit 6f224b3

File tree

7 files changed

+54
-56
lines changed

7 files changed

+54
-56
lines changed

src/databricks/sql/auth/thrift_http_client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
1515
from urllib3.util import make_headers
1616
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
17-
from databricks.sql.utils import SSLOptions
17+
from databricks.sql.types import SSLOptions
1818

1919
logger = logging.getLogger(__name__)
2020

src/databricks/sql/client.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
ParamEscaper,
2323
inject_parameters,
2424
transform_paramstyle,
25-
SSLOptions,
2625
)
2726
from databricks.sql.parameters.native import (
2827
DbsqlParameterBase,
@@ -36,7 +35,7 @@
3635
)
3736

3837

39-
from databricks.sql.types import Row
38+
from databricks.sql.types import Row, SSLOptions
4039
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
4140
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
4241

src/databricks/sql/cloudfetch/download_manager.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22

3-
from ssl import SSLContext
43
from concurrent.futures import ThreadPoolExecutor, Future
54
from typing import List, Union
65

@@ -9,7 +8,7 @@
98
DownloadableResultSettings,
109
DownloadedFile,
1110
)
12-
from databricks.sql.utils import SSLOptions
11+
from databricks.sql.types import SSLOptions
1312

1413
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1514

src/databricks/sql/cloudfetch/downloader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1010
from databricks.sql.exc import Error
11-
from databricks.sql.utils import SSLOptions
11+
from databricks.sql.types import SSLOptions
1212

1313
logger = logging.getLogger(__name__)
1414

src/databricks/sql/thrift_backend.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import time
66
import uuid
77
import threading
8-
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
98
from typing import List, Union
109

1110
import pyarrow
@@ -35,8 +34,8 @@
3534
convert_arrow_based_set_to_arrow_table,
3635
convert_decimals_in_arrow_table,
3736
convert_column_based_set_to_arrow_table,
38-
SSLOptions,
3937
)
38+
from databricks.sql.types import SSLOptions
4039

4140
logger = logging.getLogger(__name__)
4241

src/databricks/sql/types.py

+48
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,54 @@
1919
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
2020
import datetime
2121
import decimal
22+
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
23+
24+
25+
class SSLOptions:
26+
tls_verify: bool
27+
tls_verify_hostname: bool
28+
tls_trusted_ca_file: Optional[str]
29+
tls_client_cert_file: Optional[str]
30+
tls_client_cert_key_file: Optional[str]
31+
tls_client_cert_key_password: Optional[str]
32+
33+
def __init__(
34+
self,
35+
tls_verify: Optional[bool] = True,
36+
tls_verify_hostname: Optional[bool] = True,
37+
tls_trusted_ca_file: Optional[str] = None,
38+
tls_client_cert_file: Optional[str] = None,
39+
tls_client_cert_key_file: Optional[str] = None,
40+
tls_client_cert_key_password: Optional[str] = None,
41+
):
42+
self.tls_verify = tls_verify
43+
self.tls_verify_hostname = tls_verify_hostname
44+
self.tls_trusted_ca_file = tls_trusted_ca_file
45+
self.tls_client_cert_file = tls_client_cert_file
46+
self.tls_client_cert_key_file = tls_client_cert_key_file
47+
self.tls_client_cert_key_password = tls_client_cert_key_password
48+
49+
def create_ssl_context(self) -> SSLContext:
50+
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
51+
52+
if self.tls_verify is False:
53+
ssl_context.check_hostname = False
54+
ssl_context.verify_mode = CERT_NONE
55+
elif self.tls_verify_hostname is False:
56+
ssl_context.check_hostname = False
57+
ssl_context.verify_mode = CERT_REQUIRED
58+
else:
59+
ssl_context.check_hostname = True
60+
ssl_context.verify_mode = CERT_REQUIRED
61+
62+
if self.tls_client_cert_file:
63+
ssl_context.load_cert_chain(
64+
certfile=self.tls_client_cert_file,
65+
keyfile=self.tls_client_cert_key_file,
66+
password=self.tls_client_cert_key_password,
67+
)
68+
69+
return ssl_context
2270

2371

2472
class Row(tuple):

src/databricks/sql/utils.py

+1-48
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from enum import Enum
1010
from typing import Any, Dict, List, Optional, Union
1111
import re
12-
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
1312

1413
import lz4.frame
1514
import pyarrow
@@ -21,6 +20,7 @@
2120
TSparkArrowResultLink,
2221
TSparkRowSetType,
2322
)
23+
from databricks.sql.types import SSLOptions
2424

2525
from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
2626

@@ -31,53 +31,6 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34-
class SSLOptions:
35-
tls_verify: bool
36-
tls_verify_hostname: bool
37-
tls_trusted_ca_file: Optional[str]
38-
tls_client_cert_file: Optional[str]
39-
tls_client_cert_key_file: Optional[str]
40-
tls_client_cert_key_password: Optional[str]
41-
42-
def __init__(
43-
self,
44-
tls_verify: Optional[bool] = True,
45-
tls_verify_hostname: Optional[bool] = True,
46-
tls_trusted_ca_file: Optional[str] = None,
47-
tls_client_cert_file: Optional[str] = None,
48-
tls_client_cert_key_file: Optional[str] = None,
49-
tls_client_cert_key_password: Optional[str] = None,
50-
):
51-
self.tls_verify = tls_verify
52-
self.tls_verify_hostname = tls_verify_hostname
53-
self.tls_trusted_ca_file = tls_trusted_ca_file
54-
self.tls_client_cert_file = tls_client_cert_file
55-
self.tls_client_cert_key_file = tls_client_cert_key_file
56-
self.tls_client_cert_key_password = tls_client_cert_key_password
57-
58-
def create_ssl_context(self) -> SSLContext:
59-
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
60-
61-
if self.tls_verify is False:
62-
ssl_context.check_hostname = False
63-
ssl_context.verify_mode = CERT_NONE
64-
elif self.tls_verify_hostname is False:
65-
ssl_context.check_hostname = False
66-
ssl_context.verify_mode = CERT_REQUIRED
67-
else:
68-
ssl_context.check_hostname = True
69-
ssl_context.verify_mode = CERT_REQUIRED
70-
71-
if self.tls_client_cert_file:
72-
ssl_context.load_cert_chain(
73-
certfile=self.tls_client_cert_file,
74-
keyfile=self.tls_client_cert_key_file,
75-
password=self.tls_client_cert_key_password,
76-
)
77-
78-
return ssl_context
79-
80-
8134
class ResultSetQueue(ABC):
8235
@abstractmethod
8336
def next_n_rows(self, num_rows: int) -> pyarrow.Table:

0 commit comments

Comments
 (0)