forked from databricks/databricks-sql-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathendpoint.py
121 lines (93 loc) · 3.93 KB
/
endpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#
# It implements all the cloud specific OAuth configuration/metadata
#
# Azure: It uses AAD
# AWS: It uses Databricks internal IdP
# GCP: It uses Databricks internal IdP (same endpoints as AWS)
#
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, List
import os
# Path segment appended to the workspace URL to reach the Databricks
# unified OIDC redirector endpoint.
OIDC_REDIRECTOR_PATH = "oidc"
class OAuthScope:
    """Namespace of OAuth scope names used by this module, as plain strings."""

    # Scope name that the Azure scope mapping passes through when requested.
    OFFLINE_ACCESS = "offline_access"
    # Scope name for SQL access.
    SQL = "sql"
class CloudType(Enum):
    """Cloud providers that a Databricks workspace hostname can map to."""

    AWS = "aws"
    AZURE = "azure"
    GCP = "gcp"
# Hostname suffixes used to recognise the cloud a workspace runs on.
# Each entry starts with a dot so that only whole domain labels match.

# Suffixes classified as AWS.
DATABRICKS_AWS_DOMAINS = [
    ".cloud.databricks.com",
    ".cloud.databricks.us",
    ".dev.databricks.com",
]

# Suffixes classified as Azure.
DATABRICKS_AZURE_DOMAINS = [
    ".azuredatabricks.net",
    ".databricks.azure.cn",
    ".databricks.azure.us",
]

# Suffixes classified as GCP.
DATABRICKS_GCP_DOMAINS = [
    ".gcp.databricks.com",
]
def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
    """Infer the cloud type from a Databricks SQL instance hostname.

    The input may be a bare hostname or a URL. Matching is case-insensitive
    and ignores surrounding whitespace, an ``http://``/``https://`` scheme,
    an explicit port and any path.

    Args:
        hostname: Workspace hostname or URL, for example
            ``"https://adb-123.azuredatabricks.net:443/sql"``.

    Returns:
        The ``CloudType`` whose domain suffix list matches the host, or
        ``None`` when the host ends with none of the known suffixes.
    """
    host = hostname.strip().lower()

    # Drop a leading scheme so that the first "/" below really marks the path.
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme):]
            break

    # Keep only the authority part, then remove an explicit ":port" which
    # would otherwise prevent every suffix comparison from matching.
    host = host.split("/", 1)[0].split(":", 1)[0]

    # str.endswith accepts a tuple of candidate suffixes.
    if host.endswith(tuple(DATABRICKS_AZURE_DOMAINS)):
        return CloudType.AZURE
    if host.endswith(tuple(DATABRICKS_AWS_DOMAINS)):
        return CloudType.AWS
    if host.endswith(tuple(DATABRICKS_GCP_DOMAINS)):
        return CloudType.GCP
    return None
def get_databricks_oidc_url(hostname: str):
    """Return the OIDC redirector URL for the given workspace hostname.

    ``https://`` is prepended unless ``hostname`` already starts with it, and
    a single ``/`` is inserted before the redirector path unless ``hostname``
    already ends with one.
    """
    prefix = "" if hostname.startswith("https://") else "https://"
    separator = "" if hostname.endswith("/") else "/"
    return "".join((prefix, hostname, separator, OIDC_REDIRECTOR_PATH))
class OAuthEndpointCollection(ABC):
    """Abstract description of the OAuth endpoints and scopes of one cloud."""

    @abstractmethod
    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Translate the requested scopes into the scopes sent to the IdP."""
        raise NotImplementedError()

    @abstractmethod
    def get_authorization_url(self, hostname: str) -> str:
        """Return the OAuth2 authorization endpoint.

        Example: https://idp.example.com/oauth2/v2.0/authorize
        """
        raise NotImplementedError()

    @abstractmethod
    def get_openid_config_url(self, hostname: str) -> str:
        """Return the well-known OpenID configuration endpoint.

        Example: https://idp.example.com/oauth2/.well-known/openid-configuration
        """
        raise NotImplementedError()
class AzureOAuthEndpointCollection(OAuthEndpointCollection):
    """OAuth endpoints and scope mapping for Azure workspaces (AAD)."""

    # Default identifier used to build the Azure scope. The attribute name is
    # misspelled ("DATATRICKS"); it is kept so existing references keep working.
    DATATRICKS_AZURE_APP = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
    # Correctly spelled alias for the same value.
    DATABRICKS_AZURE_APP = DATATRICKS_AZURE_APP

    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Map the requested scopes to the scopes sent to Azure.

        Azure has no scopes corresponding to the Databricks ones; access
        control is delegated to Databricks, so a single
        ``<id>/user_impersonation`` scope is always requested.
        ``offline_access`` is appended only when the caller asked for it.

        Args:
            scopes: Scopes requested by the caller.

        Returns:
            A new list with the Azure scope, optionally followed by
            ``offline_access``.
        """
        # `or` instead of a getenv default: a variable that is set but empty
        # would otherwise yield the invalid scope "/user_impersonation".
        # NOTE(review): the variable is named TENANT_ID while its default is
        # the application id constant - confirm the intended semantics.
        resource_id = (
            os.getenv("DATABRICKS_AZURE_TENANT_ID")
            or AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP
        )
        mapped_scopes = [f"{resource_id}/user_impersonation"]
        if OAuthScope.OFFLINE_ACCESS in scopes:
            mapped_scopes.append(OAuthScope.OFFLINE_ACCESS)
        return mapped_scopes

    def get_authorization_url(self, hostname: str) -> str:
        """Return the authorization endpoint under the workspace OIDC URL.

        The account-specific URL is needed; the Databricks unified OIDC
        endpoint can redirect to it.
        """
        return f"{get_databricks_oidc_url(hostname)}/oauth2/v2.0/authorize"

    def get_openid_config_url(self, hostname: str) -> str:
        """Return the fixed Microsoft OpenID configuration URL.

        ``hostname`` is accepted for interface compatibility and is not used.
        """
        return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
    """OAuth endpoints served from the workspace's own OIDC URL."""

    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Return a copy of ``scopes``; no scope mapping is applied here."""
        return scopes.copy()

    def get_authorization_url(self, hostname: str):
        """Return the authorization endpoint under the workspace OIDC URL."""
        return get_databricks_oidc_url(hostname) + "/oauth2/v2.0/authorize"

    def get_openid_config_url(self, hostname: str):
        """Return the authorization-server metadata URL under the OIDC URL."""
        return get_databricks_oidc_url(hostname) + "/.well-known/oauth-authorization-server"
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
    """Return the endpoint collection for ``cloud``.

    AWS and GCP share the in-house collection and Azure has its own.
    Any other value yields ``None``.
    """
    uses_in_house_idp = cloud == CloudType.AWS or cloud == CloudType.GCP
    if uses_in_house_idp:
        return InHouseOAuthEndpointCollection()
    if cloud == CloudType.AZURE:
        return AzureOAuthEndpointCollection()
    return None