forked from databricks/databricks-sql-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauth.py
executable file
·129 lines (112 loc) · 4.66 KB
/
auth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from enum import Enum
from typing import Optional, List
from databricks.sql.auth.authenticators import (
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
)
class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
# other supported types (access_token) can be inferred
# we can add more types as needed later
class ClientContext:
def __init__(
self,
hostname: str,
access_token: Optional[str] = None,
auth_type: Optional[str] = None,
oauth_scopes: Optional[List[str]] = None,
oauth_client_id: Optional[str] = None,
oauth_redirect_port_range: Optional[List[int]] = None,
use_cert_as_auth: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.access_token = access_token
self.auth_type = auth_type
self.oauth_scopes = oauth_scopes
self.oauth_client_id = oauth_client_id
self.oauth_redirect_port_range = oauth_redirect_port_range
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider
def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
assert cfg.oauth_scopes is not None
return DatabricksOAuthProvider(
cfg.hostname,
cfg.oauth_persistence,
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
cfg.auth_type,
)
elif cfg.access_token is not None:
return AccessTokenAuthProvider(cfg.access_token)
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
# no op authenticator. authentication is performed using ssl certificate outside of headers
return AuthProvider()
else:
if (
cfg.oauth_redirect_port_range is not None
and cfg.oauth_client_id is not None
and cfg.oauth_scopes is not None
):
return DatabricksOAuthProvider(
cfg.hostname,
cfg.oauth_persistence,
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
)
else:
raise RuntimeError("No valid authentication settings!")
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
PYSQL_OAUTH_AZURE_CLIENT_ID = "96eecda7-19ea-49cc-abb5-240097d554f5"
PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025))
PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE = [8030]
def normalize_host_name(hostname: str):
maybe_scheme = "https://" if not hostname.startswith("https://") else ""
maybe_trailing_slash = "/" if not hostname.endswith("/") else ""
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}"
def get_client_id_and_redirect_port(use_azure_auth: bool):
return (
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
if not use_azure_auth
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
)
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)
if kwargs.get("username") or kwargs.get("password"):
raise ValueError(
"Username/password authentication is no longer supported. "
"Please use OAuth or access token instead."
)
cfg = ClientContext(
hostname=normalize_host_name(hostname),
auth_type=auth_type,
access_token=kwargs.get("access_token"),
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
oauth_scopes=PYSQL_OAUTH_SCOPES,
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
else redirect_port_range,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)