Skip to content

Commit bafef75

Browse files
committed
Add Databricks SQL Token Federation examples and enhance authentication with ClientCredentialsProvider
- Introduced a new script for demonstrating various token federation flows in Databricks SQL. - Implemented ClientCredentialsProvider for machine-to-machine authentication, supporting Azure and Databricks service principal flows. - Refactored token federation handling to allow integration with existing authentication methods. - Updated the DatabricksTokenFederationProvider to improve token exchange logic and error handling.
1 parent 4c5bce1 commit bafef75

File tree

5 files changed

+341
-139
lines changed

5 files changed

+341
-139
lines changed

examples/token_federation_examples.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Databricks SQL Token Federation Examples
3+
4+
This script demonstrates the following token federation flows:
5+
1. U2M + Account-wide federation
6+
2. U2M + Workflow-level federation
7+
3. M2M + Account-wide federation
8+
4. M2M + Workflow-level federation
9+
5. Access Token + Workflow-level federation
10+
6. Access Token + Account-wide federation
11+
12+
Token Federation Documentation:
13+
------------------------------
14+
For detailed setup instructions, refer to the official Databricks documentation:
15+
16+
- General Token Federation Overview:
17+
https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation.html
18+
19+
- Token Exchange Process:
20+
https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation-howto.html
21+
22+
- Azure OAuth Token Federation:
23+
https://learn.microsoft.com/en-us/azure/databricks/dev-tools/auth/oauth-federation
24+
25+
Environment variables required:
26+
- DATABRICKS_HOST: Databricks workspace hostname
27+
- DATABRICKS_HTTP_PATH: HTTP path for the SQL warehouse
28+
- AZURE_TENANT_ID: Azure tenant ID
29+
- AZURE_CLIENT_ID: Azure client ID for service principal
30+
- AZURE_CLIENT_SECRET: Azure client secret
31+
- DATABRICKS_SERVICE_PRINCIPAL_ID: Databricks service principal ID for workflow federation
32+
"""
33+
34+
import os
35+
from databricks import sql
36+
37+
def run_query(connection, description):
    """Run a trivial query on ``connection`` and print its result.

    Args:
        connection: An open connection object exposing ``cursor()``.
        description: Human-readable label for the federation flow being
            demonstrated. It is printed ahead of the result so the output
            of the different flows can be told apart.
    """
    print(f"--- {description} ---")
    cursor = connection.cursor()
    try:
        cursor.execute("SELECT 1+1 AS result")
        result = cursor.fetchall()
        print(f"Query result: {result[0][0]}")
    finally:
        # Release the cursor even when the query or the fetch fails.
        cursor.close()
44+
45+
def demonstrate_m2m_federation(env_vars, use_workflow_federation=False):
    """Demonstrate M2M (service principal) token federation.

    Args:
        env_vars: Mapping of configuration values. DATABRICKS_HOST,
            DATABRICKS_HTTP_PATH, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and
            AZURE_TENANT_ID must be present; DATABRICKS_SERVICE_PRINCIPAL_ID
            is optional.
        use_workflow_federation: When true and a service principal ID is
            configured, pass it as ``identity_federation_client_id`` to
            request workflow-level federation; otherwise account-wide
            federation is used.
    """
    connection_params = {
        "server_hostname": env_vars["DATABRICKS_HOST"],
        "http_path": env_vars["DATABRICKS_HTTP_PATH"],
        "auth_type": "client-credentials",
        "oauth_client_id": env_vars["AZURE_CLIENT_ID"],
        # The connector reads the secret from the ``oauth_client_secret``
        # keyword; a plain ``client_secret`` key would be ignored.
        "oauth_client_secret": env_vars["AZURE_CLIENT_SECRET"],
        "tenant_id": env_vars["AZURE_TENANT_ID"],
        "use_token_federation": True,
    }

    # The service principal ID is optional, so a missing key must not raise.
    service_principal_id = env_vars.get("DATABRICKS_SERVICE_PRINCIPAL_ID")
    if use_workflow_federation and service_principal_id:
        connection_params["identity_federation_client_id"] = service_principal_id
        description = "M2M + Workflow-level Federation"
    else:
        description = "M2M + Account-wide Federation"

    with sql.connect(**connection_params) as connection:
        run_query(connection, description)
66+
67+
68+
def demonstrate_u2m_federation(env_vars, use_workflow_federation=False):
    """Demonstrate U2M (interactive) token federation.

    Args:
        env_vars: Mapping of configuration values. DATABRICKS_HOST and
            DATABRICKS_HTTP_PATH must be present;
            DATABRICKS_SERVICE_PRINCIPAL_ID is optional.
        use_workflow_federation: When true and a service principal ID is
            configured, pass it as ``identity_federation_client_id`` to
            request workflow-level federation; otherwise account-wide
            federation is used.
    """
    connection_params = {
        "server_hostname": env_vars["DATABRICKS_HOST"],
        "http_path": env_vars["DATABRICKS_HTTP_PATH"],
        "auth_type": "databricks-oauth",  # Will open browser for interactive auth
        "use_token_federation": True,
    }

    # The service principal ID is optional, so a missing key must not raise.
    service_principal_id = env_vars.get("DATABRICKS_SERVICE_PRINCIPAL_ID")
    if use_workflow_federation and service_principal_id:
        connection_params["identity_federation_client_id"] = service_principal_id
        description = "U2M + Workflow-level Federation (Interactive)"
    else:
        description = "U2M + Account-wide Federation (Interactive)"

    # This will open a browser for interactive auth
    with sql.connect(**connection_params) as connection:
        run_query(connection, description)
87+
88+
def demonstrate_access_token_federation(env_vars):
    """Demonstrate token federation starting from an existing access token.

    The token is read from the ACCESS_TOKEN environment variable and stands
    in for a token obtained from an identity provider.

    Args:
        env_vars: Mapping of configuration values. DATABRICKS_HOST and
            DATABRICKS_HTTP_PATH must be present;
            DATABRICKS_SERVICE_PRINCIPAL_ID is optional and, when set,
            selects workflow-level federation.

    Raises:
        ValueError: If ACCESS_TOKEN is unset or empty.
    """
    # This is to demonstrate a token obtained from an identity provider
    access_token = os.environ.get("ACCESS_TOKEN")
    if not access_token:
        # Fail early: passing ``access_token=None`` would not select the
        # access-token auth path at all.
        raise ValueError(
            "ACCESS_TOKEN environment variable must be set for the access token federation example"
        )

    connection_params = {
        "server_hostname": env_vars["DATABRICKS_HOST"],
        "http_path": env_vars["DATABRICKS_HTTP_PATH"],
        "access_token": access_token,
        "use_token_federation": True,
    }

    # Add workflow federation if available; the key is optional.
    service_principal_id = env_vars.get("DATABRICKS_SERVICE_PRINCIPAL_ID")
    if service_principal_id:
        connection_params["identity_federation_client_id"] = service_principal_id
        description = "Access Token + Workflow-level Federation"
    else:
        description = "Access Token + Account-wide Federation"

    with sql.connect(**connection_params) as connection:
        run_query(connection, description)
109+

src/databricks/sql/auth/auth.py

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,15 @@
55
AuthProvider,
66
AccessTokenAuthProvider,
77
ExternalAuthProvider,
8-
CredentialsProvider,
98
DatabricksOAuthProvider,
9+
ClientCredentialsProvider,
1010
)
1111

1212

1313
class AuthType(Enum):
    """Values accepted for the ``auth_type`` connection setting."""

    # Interactive OAuth flows, used for user authentication.
    DATABRICKS_OAUTH = "databricks-oauth"
    AZURE_OAUTH = "azure-oauth"
    # OAuth client-credentials flow, used for machine-to-machine authentication.
    CLIENT_CREDENTIALS = "client-credentials"
2317

2418
class ClientContext:
2519
def __init__(
@@ -34,8 +28,10 @@ def __init__(
3428
tls_client_cert_file: Optional[str] = None,
3529
oauth_persistence=None,
3630
credentials_provider=None,
37-
identity_federation_client_id: Optional[str] = None,
31+
oauth_client_secret: Optional[str] = None,
32+
tenant_id: Optional[str] = None,
3833
use_token_federation: bool = False,
34+
identity_federation_client_id: Optional[str] = None,
3935
):
4036
self.hostname = hostname
4137
self.access_token = access_token
@@ -49,20 +45,52 @@ def __init__(
4945
self.credentials_provider = credentials_provider
5046
self.identity_federation_client_id = identity_federation_client_id
5147
self.use_token_federation = use_token_federation
48+
self.oauth_client_secret = oauth_client_secret
49+
self.tenant_id = tenant_id
50+
51+
def _create_azure_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider:
    """Build a client credentials provider that targets the Azure AD token endpoint.

    Args:
        cfg: Client context supplying ``oauth_client_id``,
            ``oauth_client_secret`` and ``tenant_id``.

    Returns:
        A ClientCredentialsProvider pointed at the tenant's v2.0 token endpoint.

    Raises:
        ValueError: If any of the three required settings is missing or empty.
    """
    has_all_settings = cfg.oauth_client_id and cfg.oauth_client_secret and cfg.tenant_id
    if not has_all_settings:
        raise ValueError("Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id")

    # The token endpoint is scoped to the configured tenant.
    azure_token_url = f"https://login.microsoftonline.com/{cfg.tenant_id}/oauth2/v2.0/token"
    return ClientCredentialsProvider(
        client_id=cfg.oauth_client_id,
        client_secret=cfg.oauth_client_secret,
        token_endpoint=azure_token_url,
        auth_type_value="azure-client-credentials",
    )
63+
64+
65+
def _create_databricks_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider:
    """Create a Databricks client credentials provider for service principals.

    Args:
        cfg: Client context supplying ``hostname``, ``oauth_client_id`` and
            ``oauth_client_secret``.

    Returns:
        A ClientCredentialsProvider pointed at ``<hostname>/oidc/v1/token``.

    Raises:
        ValueError: If the hostname, client ID or client secret is missing
            or empty.
    """
    if not cfg.oauth_client_id or not cfg.oauth_client_secret:
        raise ValueError("Databricks client credentials flow requires oauth_client_id and oauth_client_secret")
    if not cfg.hostname:
        raise ValueError("Databricks client credentials flow requires hostname")

    # Join on exactly one slash so the endpoint is well-formed whether or not
    # the hostname carries a trailing slash.
    token_endpoint = "{}/oidc/v1/token".format(cfg.hostname.rstrip("/"))
    return ClientCredentialsProvider(
        client_id=cfg.oauth_client_id,
        client_secret=cfg.oauth_client_secret,
        token_endpoint=token_endpoint,
        auth_type_value="client-credentials",
    )
5277

5378

5479
def get_auth_provider(cfg: ClientContext):
5580
"""
5681
Get an appropriate auth provider based on the provided configuration.
5782
83+
OAuth Flow Support:
84+
This function supports multiple OAuth flows:
85+
1. Interactive OAuth (databricks-oauth, azure-oauth) - for user authentication
86+
2. Client Credentials (client-credentials) - for machine-to-machine authentication
87+
3. Token Federation - implemented as a feature flag that wraps any auth type
88+
5889
Token Federation Support:
5990
-----------------------
60-
Currently, token federation is implemented as a separate auth type, but the goal is to
61-
refactor it as a feature that can work with any auth type. The current implementation
62-
is maintained for backward compatibility while the refactoring is planned.
63-
64-
Future refactoring will introduce a `use_token_federation` flag that can be combined
65-
with any auth type to enable token federation.
91+
Token federation is implemented as a feature flag (`use_token_federation=True`) that
92+
can be combined with any auth type. When enabled, it wraps the base auth provider
93+
in a DatabricksTokenFederationProvider for token exchange functionality.
6694
6795
Args:
6896
cfg: The client context containing configuration parameters
@@ -74,21 +102,31 @@ def get_auth_provider(cfg: ClientContext):
74102
RuntimeError: If no valid authentication settings are provided
75103
"""
76104
from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider
105+
106+
base_provider = None
107+
77108
if cfg.credentials_provider:
78109
base_provider = ExternalAuthProvider(cfg.credentials_provider)
79110
elif cfg.access_token is not None:
80111
base_provider = AccessTokenAuthProvider(cfg.access_token)
112+
elif cfg.auth_type == AuthType.CLIENT_CREDENTIALS.value:
113+
if cfg.tenant_id:
114+
# Azure client credentials flow
115+
base_provider = _create_azure_client_credentials_provider(cfg)
116+
else:
117+
# Databricks service principal client credentials flow
118+
base_provider = _create_databricks_client_credentials_provider(cfg)
81119
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
82120
assert cfg.oauth_redirect_port_range is not None
83121
assert cfg.oauth_client_id is not None
84122
assert cfg.oauth_scopes is not None
85123
base_provider = DatabricksOAuthProvider(
86-
cfg.hostname,
87-
cfg.oauth_persistence,
88-
cfg.oauth_redirect_port_range,
89-
cfg.oauth_client_id,
90-
cfg.oauth_scopes,
91-
cfg.auth_type,
124+
hostname=cfg.hostname,
125+
oauth_persistence=cfg.oauth_persistence,
126+
redirect_port_range=cfg.oauth_redirect_port_range,
127+
client_id=cfg.oauth_client_id,
128+
scopes=cfg.oauth_scopes,
129+
auth_type=cfg.auth_type,
92130
)
93131
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
94132
base_provider = AuthProvider()
@@ -99,11 +137,11 @@ def get_auth_provider(cfg: ClientContext):
99137
and cfg.oauth_scopes is not None
100138
):
101139
base_provider = DatabricksOAuthProvider(
102-
cfg.hostname,
103-
cfg.oauth_persistence,
104-
cfg.oauth_redirect_port_range,
105-
cfg.oauth_client_id,
106-
cfg.oauth_scopes,
140+
hostname=cfg.hostname,
141+
oauth_persistence=cfg.oauth_persistence,
142+
redirect_port_range=cfg.oauth_redirect_port_range,
143+
client_id=cfg.oauth_client_id,
144+
scopes=cfg.oauth_scopes,
107145
)
108146
else:
109147
raise RuntimeError("No valid authentication settings!")
@@ -126,7 +164,7 @@ def get_auth_provider(cfg: ClientContext):
126164
def normalize_host_name(hostname: str):
    """Return ``hostname`` with an ``https://`` prefix and a trailing slash.

    Either part is added only when ``hostname`` does not already have it.
    """
    prefix = "" if hostname.startswith("https://") else "https://"
    suffix = "" if hostname.endswith("/") else "/"
    return prefix + hostname + suffix
130168

131169

132170
def get_client_id_and_redirect_port(use_azure_auth: bool):
@@ -144,14 +182,25 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
144182
This function is the main entry point for authentication in the SQL connector.
145183
It processes the parameters and creates an appropriate auth provider.
146184
147-
TODO: Future refactoring needed:
148-
1. Add a use_token_federation flag that can be combined with any auth type
149-
2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility
150-
3. Create a token federation wrapper that can wrap any existing auth provider
185+
Supported Authentication Methods:
186+
--------------------------------
187+
1. Access Token: Provide 'access_token' parameter
188+
2. Interactive OAuth: Set 'auth_type' to 'databricks-oauth' or 'azure-oauth'
189+
3. Client Credentials: Set 'auth_type' to 'client-credentials' with oauth_client_id, oauth_client_secret, and (for the Azure flow) tenant_id
190+
4. External Provider: Provide 'credentials_provider' parameter
191+
5. Token Federation: Set 'use_token_federation=True' with any of the above
151192
152193
Args:
153194
hostname: The Databricks server hostname
154-
**kwargs: Additional configuration parameters
195+
**kwargs: Additional configuration parameters including:
196+
- auth_type: Authentication type
197+
- access_token: Static access token
198+
- oauth_client_id: OAuth client ID
199+
- oauth_client_secret: OAuth client secret
200+
- tenant_id: Azure AD tenant ID (for Azure flows)
201+
- credentials_provider: External credentials provider
202+
- use_token_federation: Enable token federation
203+
- identity_federation_client_id: Federation client ID
155204
156205
Returns:
157206
An appropriate AuthProvider instance
@@ -182,6 +231,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
182231
else redirect_port_range,
183232
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
184233
credentials_provider=kwargs.get("credentials_provider"),
234+
oauth_client_secret=kwargs.get("oauth_client_secret"),
235+
tenant_id=kwargs.get("tenant_id"),
185236
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
186237
use_token_federation=kwargs.get("use_token_federation", False),
187238
)

0 commit comments

Comments
 (0)