diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c15f2f9..18b2208c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 3.1.0 (TBD) + +- Fix: `server_hostname` URIs that included `https://` would raise an exception + ## 3.0.1 (2023-12-01) - Other: updated docstring comment about default parameterization approach (#287) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 288c3e10..69ac760a 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -141,9 +141,11 @@ def __init__( if kwargs.get("_connection_uri"): uri = kwargs.get("_connection_uri") elif server_hostname and http_path: - uri = "https://{host}:{port}/{path}".format( - host=server_hostname, port=port, path=http_path.lstrip("/") + uri = "{host}:{port}/{path}".format( + host=server_hostname.rstrip("/"), port=port, path=http_path.lstrip("/") ) + if not uri.startswith("https://"): + uri = "https://" + uri else: raise ValueError("No valid connection settings.") diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index e641ed21..92c664a0 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -212,6 +212,18 @@ def test_port_and_host_are_respected(self, t_http_client_class): self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") + def test_host_with_https_does_not_duplicate(self, t_http_client_class): + ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider()) + self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value") + + @patch("databricks.sql.auth.thrift_http_client.THttpClient") + def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): + ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider()) + self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129)