|
4 | 4 | import unittest
|
5 | 5 | from unittest.mock import patch, MagicMock, Mock
|
6 | 6 | from ssl import CERT_NONE, CERT_REQUIRED
|
| 7 | +from urllib3 import HTTPSConnectionPool |
7 | 8 |
|
8 | 9 | import pyarrow
|
9 | 10 |
|
@@ -208,6 +209,61 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_
|
208 | 209 | self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED)
|
209 | 210 | self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options)
|
210 | 211 |
|
| 212 | + @patch("databricks.sql.types.create_default_context") |
| 213 | + def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): |
| 214 | + from databricks.sql.auth.thrift_http_client import THttpClient |
| 215 | + |
| 216 | + mock_cert_key_file = Mock() |
| 217 | + mock_cert_key_password = Mock() |
| 218 | + mock_trusted_ca_file = Mock() |
| 219 | + mock_cert_file = Mock() |
| 220 | + |
| 221 | + mock_ssl_options = SSLOptions( |
| 222 | + tls_verify=True, |
| 223 | + tls_client_cert_file=mock_cert_file, |
| 224 | + tls_client_cert_key_file=mock_cert_key_file, |
| 225 | + tls_client_cert_key_password=mock_cert_key_password, |
| 226 | + tls_trusted_ca_file=mock_trusted_ca_file, |
| 227 | + ) |
| 228 | + |
| 229 | + http_client = THttpClient( |
| 230 | + auth_provider=None, |
| 231 | + uri_or_host="https://example.com", |
| 232 | + ssl_options=mock_ssl_options, |
| 233 | + ) |
| 234 | + |
| 235 | + self.assertEqual(http_client.scheme, 'https') |
| 236 | + self.assertEqual(http_client.certfile, mock_ssl_options.tls_client_cert_file) |
| 237 | + self.assertEqual(http_client.keyfile, mock_ssl_options.tls_client_cert_key_file) |
| 238 | + self.assertIsNotNone(http_client.certfile) |
| 239 | + mock_create_default_context.assert_called() |
| 240 | + |
| 241 | + http_client.open() |
| 242 | + |
| 243 | + conn_pool = http_client._THttpClient__pool |
| 244 | + self.assertIsInstance(conn_pool, HTTPSConnectionPool) |
| 245 | + self.assertEqual(conn_pool.cert_reqs, CERT_REQUIRED) |
| 246 | + self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) |
| 247 | + self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) |
| 248 | + self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) |
| 249 | + self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) |
| 250 | + |
| 251 | + def test_tls_no_verify_is_respected_by_http_client(self): |
| 252 | + from databricks.sql.auth.thrift_http_client import THttpClient |
| 253 | + |
| 254 | + http_client = THttpClient( |
| 255 | + auth_provider=None, |
| 256 | + uri_or_host="https://example.com", |
| 257 | + ssl_options=SSLOptions(tls_verify=False), |
| 258 | + ) |
| 259 | + self.assertEqual(http_client.scheme, 'https') |
| 260 | + |
| 261 | + http_client.open() |
| 262 | + |
| 263 | + conn_pool = http_client._THttpClient__pool |
| 264 | + self.assertIsInstance(conn_pool, HTTPSConnectionPool) |
| 265 | + self.assertEqual(conn_pool.cert_reqs, CERT_NONE) |
| 266 | + |
211 | 267 | @patch("databricks.sql.auth.thrift_http_client.THttpClient")
|
212 | 268 | @patch("databricks.sql.types.create_default_context")
|
213 | 269 | def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class):
|
|
0 commit comments