|
9 | 9 | from enum import Enum
|
10 | 10 | from typing import Any, Dict, List, Optional, Union
|
11 | 11 | import re
|
12 |
| -from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context |
13 | 12 |
|
14 | 13 | import lz4.frame
|
15 | 14 | import pyarrow
|
|
21 | 20 | TSparkArrowResultLink,
|
22 | 21 | TSparkRowSetType,
|
23 | 22 | )
|
| 23 | +from databricks.sql.types import SSLOptions |
24 | 24 |
|
25 | 25 | from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
|
26 | 26 |
|
|
31 | 31 | logger = logging.getLogger(__name__)
|
32 | 32 |
|
33 | 33 |
|
34 |
| -class SSLOptions: |
35 |
| - tls_verify: bool |
36 |
| - tls_verify_hostname: bool |
37 |
| - tls_trusted_ca_file: Optional[str] |
38 |
| - tls_client_cert_file: Optional[str] |
39 |
| - tls_client_cert_key_file: Optional[str] |
40 |
| - tls_client_cert_key_password: Optional[str] |
41 |
| - |
42 |
| - def __init__( |
43 |
| - self, |
44 |
| - tls_verify: Optional[bool] = True, |
45 |
| - tls_verify_hostname: Optional[bool] = True, |
46 |
| - tls_trusted_ca_file: Optional[str] = None, |
47 |
| - tls_client_cert_file: Optional[str] = None, |
48 |
| - tls_client_cert_key_file: Optional[str] = None, |
49 |
| - tls_client_cert_key_password: Optional[str] = None, |
50 |
| - ): |
51 |
| - self.tls_verify = tls_verify |
52 |
| - self.tls_verify_hostname = tls_verify_hostname |
53 |
| - self.tls_trusted_ca_file = tls_trusted_ca_file |
54 |
| - self.tls_client_cert_file = tls_client_cert_file |
55 |
| - self.tls_client_cert_key_file = tls_client_cert_key_file |
56 |
| - self.tls_client_cert_key_password = tls_client_cert_key_password |
57 |
| - |
58 |
| - def create_ssl_context(self) -> SSLContext: |
59 |
| - ssl_context = create_default_context(cafile=self.tls_trusted_ca_file) |
60 |
| - |
61 |
| - if self.tls_verify is False: |
62 |
| - ssl_context.check_hostname = False |
63 |
| - ssl_context.verify_mode = CERT_NONE |
64 |
| - elif self.tls_verify_hostname is False: |
65 |
| - ssl_context.check_hostname = False |
66 |
| - ssl_context.verify_mode = CERT_REQUIRED |
67 |
| - else: |
68 |
| - ssl_context.check_hostname = True |
69 |
| - ssl_context.verify_mode = CERT_REQUIRED |
70 |
| - |
71 |
| - if self.tls_client_cert_file: |
72 |
| - ssl_context.load_cert_chain( |
73 |
| - certfile=self.tls_client_cert_file, |
74 |
| - keyfile=self.tls_client_cert_key_file, |
75 |
| - password=self.tls_client_cert_key_password, |
76 |
| - ) |
77 |
| - |
78 |
| - return ssl_context |
79 |
| - |
80 |
| - |
81 | 34 | class ResultSetQueue(ABC):
|
82 | 35 | @abstractmethod
|
83 | 36 | def next_n_rows(self, num_rows: int) -> pyarrow.Table:
|
|
0 commit comments