Skip to content

[PECO-1411] Support Databricks OAuth on GCP #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# 3.0.3 (TBD)

- Add support for in-house OAuth on GCP (#338)
- Revised docstrings and examples for OAuth (#339)

# 3.0.2 (2024-01-25)
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ def normalize_host_name(hostname: str):


def get_client_id_and_redirect_port(hostname: str):
cloud_type = infer_cloud_from_host(hostname)
return (
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
if infer_cloud_from_host(hostname) == CloudType.AWS
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
)

Expand Down
10 changes: 7 additions & 3 deletions src/databricks/sql/auth/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class OAuthScope:
class CloudType(Enum):
AWS = "aws"
AZURE = "azure"
GCP = "gcp"


DATABRICKS_AWS_DOMAINS = [
Expand All @@ -34,6 +35,7 @@ class CloudType(Enum):
".databricks.azure.cn",
".databricks.azure.us",
]
DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"]


# Infer cloud type from Databricks SQL instance hostname
Expand All @@ -45,6 +47,8 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
return CloudType.AZURE
elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)):
return CloudType.AWS
elif any(e for e in DATABRICKS_GCP_DOMAINS if host.endswith(e)):
return CloudType.GCP
else:
return None

Expand Down Expand Up @@ -94,7 +98,7 @@ def get_openid_config_url(self, hostname: str):
return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"


class AwsOAuthEndpointCollection(OAuthEndpointCollection):
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
# No scope mapping for in-house OAuth (used for both AWS and GCP)
return scopes.copy()
Expand All @@ -109,8 +113,8 @@ def get_openid_config_url(self, hostname: str):


def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
if cloud == CloudType.AWS:
return AwsOAuthEndpointCollection()
if cloud == CloudType.AWS or cloud == CloudType.GCP:
return InHouseOAuthEndpointCollection()
elif cloud == CloudType.AZURE:
return AzureOAuthEndpointCollection()
else:
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.oauth import OAuthManager
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection
from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache

Expand Down Expand Up @@ -55,9 +55,10 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
mock_get_tokens.return_value = (access_token, refresh_token)
mock_check_and_refresh.return_value = (access_token, refresh_token, False)

params = [(CloudType.AWS, "foo.cloud.databricks.com", AwsOAuthEndpointCollection, "offline_access sql"),
params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"),
(CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection,
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access")]
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"),
(CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")]

for cloud_type, host, expected_endpoint_type, expected_scopes in params:
with self.subTest(cloud_type.value):
Expand Down