Skip to content

Commit cd0370b

Browse files
committed
[PECO-1411] Support OAuth InHouse on GCP
Signed-off-by: Jacky Hu <[email protected]>
1 parent 3f6834c commit cd0370b

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ def normalize_host_name(hostname: str):
8888

8989

9090
def get_client_id_and_redirect_port(hostname: str):
91+
cloud_type = infer_cloud_from_host(hostname)
9192
return (
9293
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
93-
if infer_cloud_from_host(hostname) == CloudType.AWS
94+
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
9495
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
9596
)
9697

src/databricks/sql/auth/endpoint.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class OAuthScope:
2121
class CloudType(Enum):
2222
AWS = "aws"
2323
AZURE = "azure"
24+
GCP = "gcp"
2425

2526

2627
DATABRICKS_AWS_DOMAINS = [
@@ -34,6 +35,9 @@ class CloudType(Enum):
3435
".databricks.azure.cn",
3536
".databricks.azure.us",
3637
]
38+
DATABRICKS_GCP_DOMAINS = [
39+
".gcp.databricks.com"
40+
]
3741

3842

3943
# Infer cloud type from Databricks SQL instance hostname
@@ -45,6 +49,8 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
4549
return CloudType.AZURE
4650
elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)):
4751
return CloudType.AWS
52+
elif any(e for e in DATABRICKS_GCP_DOMAINS if host.endswith(e)):
53+
return CloudType.GCP
4854
else:
4955
return None
5056

@@ -94,7 +100,7 @@ def get_openid_config_url(self, hostname: str):
94100
return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
95101

96102

97-
class AwsOAuthEndpointCollection(OAuthEndpointCollection):
103+
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
98104
def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
99105
# No scope mapping in AWS
100106
return scopes.copy()
@@ -109,8 +115,8 @@ def get_openid_config_url(self, hostname: str):
109115

110116

111117
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
112-
if cloud == CloudType.AWS:
113-
return AwsOAuthEndpointCollection()
118+
if cloud == CloudType.AWS or cloud == CloudType.GCP:
119+
return InHouseOAuthEndpointCollection()
114120
elif cloud == CloudType.AZURE:
115121
return AzureOAuthEndpointCollection()
116122
else:

tests/unit/test_auth.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
88
from databricks.sql.auth.oauth import OAuthManager
99
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
10-
from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection
10+
from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection
1111
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
1212
from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
1313

@@ -55,9 +55,10 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
5555
mock_get_tokens.return_value = (access_token, refresh_token)
5656
mock_check_and_refresh.return_value = (access_token, refresh_token, False)
5757

58-
params = [(CloudType.AWS, "foo.cloud.databricks.com", AwsOAuthEndpointCollection, "offline_access sql"),
58+
params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"),
5959
(CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection,
60-
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access")]
60+
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"),
61+
(CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")]
6162

6263
for cloud_type, host, expected_endpoint_type, expected_scopes in params:
6364
with self.subTest(cloud_type.value):

0 commit comments

Comments
 (0)