Skip to content

Commit 19dc0b1

Browse files
committed
formatted
1 parent bafef75 commit 19dc0b1

File tree

4 files changed

+87
-59
lines changed

4 files changed

+87
-59
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class AuthType(Enum):
1515
AZURE_OAUTH = "azure-oauth"
1616
CLIENT_CREDENTIALS = "client-credentials"
1717

18+
1819
class ClientContext:
1920
def __init__(
2021
self,
@@ -48,31 +49,42 @@ def __init__(
4849
self.oauth_client_secret = oauth_client_secret
4950
self.tenant_id = tenant_id
5051

51-
def _create_azure_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider:
52+
53+
def _create_azure_client_credentials_provider(
54+
cfg: ClientContext,
55+
) -> ClientCredentialsProvider:
5256
"""Create an Azure client credentials provider."""
5357
if not cfg.oauth_client_id or not cfg.oauth_client_secret or not cfg.tenant_id:
54-
raise ValueError("Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id")
55-
56-
token_endpoint = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(cfg.tenant_id)
58+
raise ValueError(
59+
"Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id"
60+
)
61+
62+
token_endpoint = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(
63+
cfg.tenant_id
64+
)
5765
return ClientCredentialsProvider(
5866
client_id=cfg.oauth_client_id,
5967
client_secret=cfg.oauth_client_secret,
6068
token_endpoint=token_endpoint,
61-
auth_type_value="azure-client-credentials"
69+
auth_type_value="azure-client-credentials",
6270
)
6371

6472

65-
def _create_databricks_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider:
73+
def _create_databricks_client_credentials_provider(
74+
cfg: ClientContext,
75+
) -> ClientCredentialsProvider:
6676
"""Create a Databricks client credentials provider for service principals."""
6777
if not cfg.oauth_client_id or not cfg.oauth_client_secret:
68-
raise ValueError("Databricks client credentials flow requires oauth_client_id and oauth_client_secret")
69-
78+
raise ValueError(
79+
"Databricks client credentials flow requires oauth_client_id and oauth_client_secret"
80+
)
81+
7082
token_endpoint = "{}oidc/v1/token".format(cfg.hostname)
7183
return ClientCredentialsProvider(
7284
client_id=cfg.oauth_client_id,
7385
client_secret=cfg.oauth_client_secret,
7486
token_endpoint=token_endpoint,
75-
auth_type_value="client-credentials"
87+
auth_type_value="client-credentials",
7688
)
7789

7890

@@ -102,9 +114,9 @@ def get_auth_provider(cfg: ClientContext):
102114
RuntimeError: If no valid authentication settings are provided
103115
"""
104116
from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider
105-
117+
106118
base_provider = None
107-
119+
108120
if cfg.credentials_provider:
109121
base_provider = ExternalAuthProvider(cfg.credentials_provider)
110122
elif cfg.access_token is not None:

src/databricks/sql/auth/authenticators.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
# Private API: this is an evolving interface and it will change in the future.
1212
# Please must not depend on it in your applications.
1313
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence
14-
from databricks.sql.auth.endpoint import AzureOAuthEndpointCollection, InHouseOAuthEndpointCollection
14+
from databricks.sql.auth.endpoint import (
15+
AzureOAuthEndpointCollection,
16+
InHouseOAuthEndpointCollection,
17+
)
18+
1519

1620
class AuthProvider:
1721
def add_headers(self, request_headers: Dict[str, str]):
@@ -56,8 +60,10 @@ def auth_type(self) -> str:
5660
def __call__(self, *args, **kwargs) -> HeaderFactory:
5761
def get_headers():
5862
return {"Authorization": self.__authorization_header_value}
63+
5964
return get_headers
6065

66+
6167
# Private API: this is an evolving interface and it will change in the future.
6268
# Please must not depend on it in your applications.
6369
class DatabricksOAuthProvider(AuthProvider, CredentialsProvider):
@@ -81,11 +87,8 @@ def __init__(
8187

8288
idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth")
8389
if not idp_endpoint:
84-
raise NotImplementedError(
85-
f"OAuth is not supported for host ${hostname}"
86-
)
90+
raise NotImplementedError(f"OAuth is not supported for host ${hostname}")
8791

88-
8992
cloud_scopes = idp_endpoint.get_scopes_mapping(scopes)
9093
self._scopes_as_str = self.SCOPE_DELIM.join(cloud_scopes)
9194

@@ -107,6 +110,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
107110
def get_headers():
108111
self._update_token_if_expired()
109112
return {"Authorization": "Bearer {}".format(self._access_token)}
113+
110114
return get_headers
111115

112116
def _initial_get_token(self):
@@ -170,25 +174,24 @@ def __init__(
170174
client_id: str,
171175
client_secret: str,
172176
token_endpoint: str,
173-
auth_type_value: str = "client-credentials"
177+
auth_type_value: str = "client-credentials",
174178
):
175179
"""
176180
Initialize a ClientCredentialsProvider.
177-
181+
178182
Args:
179183
client_id: OAuth client ID
180-
client_secret: OAuth client secret
184+
client_secret: OAuth client secret
181185
token_endpoint: OAuth token endpoint URL
182186
auth_type_value: Auth type identifier
183187
"""
184188
self.client_id = client_id
185189
self.client_secret = client_secret
186190
self.token_endpoint = token_endpoint
187191
self.auth_type_value = auth_type_value
188-
192+
189193
self._cached_token = None
190194
self._token_expires_at = None
191-
192195

193196
def auth_type(self) -> str:
194197
return self.auth_type_value
@@ -197,50 +200,54 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
197200
def get_headers() -> Dict[str, str]:
198201
token = self._get_access_token()
199202
return {"Authorization": "Bearer {}".format(token)}
203+
200204
return get_headers
201-
205+
202206
def add_headers(self, request_headers: Dict[str, str]):
203207
token = self._get_access_token()
204208
request_headers["Authorization"] = "Bearer {}".format(token)
205209

206210
def _get_access_token(self) -> str:
207211
"""Get a valid access token using client credentials flow, with caching."""
208212
# Check if we have a valid cached token (with 40 second buffer since azure doesn't respect a token with less than 30s expiry)
209-
if (self._cached_token and self._token_expires_at and
210-
time.time() < self._token_expires_at - 40):
213+
if (
214+
self._cached_token
215+
and self._token_expires_at
216+
and time.time() < self._token_expires_at - 40
217+
):
211218
return self._cached_token
212-
219+
213220
# Get new token using client credentials flow
214221
token_data = self._request_token()
215-
216-
self._cached_token = token_data['access_token']
222+
223+
self._cached_token = token_data["access_token"]
217224
# expires_in is in seconds, convert to absolute time
218-
self._token_expires_at = time.time() + token_data.get('expires_in', 3600)
219-
225+
self._token_expires_at = time.time() + token_data.get("expires_in", 3600)
226+
220227
return self._cached_token
221228

222229
def _request_token(self) -> dict:
223230
"""Request a new token using OAuth client credentials flow."""
224231
data = {
225-
'grant_type': 'client_credentials',
226-
'client_id': self.client_id,
227-
'client_secret': self.client_secret,
228-
'scope': self.AZURE_DATABRICKS_SCOPE,
232+
"grant_type": "client_credentials",
233+
"client_id": self.client_id,
234+
"client_secret": self.client_secret,
235+
"scope": self.AZURE_DATABRICKS_SCOPE,
229236
}
230-
231-
headers = {'Content-Type': 'application/x-www-form-urlencoded'}
232-
237+
238+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
239+
233240
try:
234241
response = requests.post(self.token_endpoint, data=data, headers=headers)
235242
response.raise_for_status()
236-
243+
237244
token_data = response.json()
238-
239-
if 'access_token' not in token_data:
245+
246+
if "access_token" not in token_data:
240247
raise ValueError("No access_token in response: {}".format(token_data))
241-
248+
242249
return token_data
243-
250+
244251
except requests.exceptions.RequestException as e:
245252
raise RuntimeError("Token request failed: {}".format(e)) from e
246253
except ValueError as e:

src/databricks/sql/auth/token_federation.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class DatabricksTokenFederationProvider(CredentialsProvider):
2020
"""
2121
Token federation provider that exchanges external tokens for Databricks tokens.
22-
22+
2323
This implementation follows the JDBC pattern:
2424
1. Try token exchange without HTTP Basic authentication (per RFC 8693)
2525
2. Fall back to using external token directly if exchange fails
@@ -118,7 +118,9 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
118118
Dict[str, Any]: Parsed JWT claims
119119
"""
120120
try:
121-
return jwt.decode(token, options={"verify_signature": False, "verify_aud": False})
121+
return jwt.decode(
122+
token, options={"verify_signature": False, "verify_aud": False}
123+
)
122124
except Exception as e:
123125
logger.debug("Failed to parse JWT: %s", str(e))
124126
return {}
@@ -143,7 +145,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
143145
except (ValueError, TypeError) as e:
144146
logger.warning("Invalid JWT expiry value: %s", e)
145147

146-
147148
def refresh_token(self) -> Token:
148149
"""
149150
Refresh the token and return the new Token object.
@@ -184,7 +185,9 @@ def refresh_token(self) -> Token:
184185
self.current_token = new_token
185186
return new_token
186187
except Exception as e:
187-
logger.debug("Token exchange failed: %s. Using external token as fallback.", e)
188+
logger.debug(
189+
"Token exchange failed: %s. Using external token as fallback.", e
190+
)
188191
expiry = self._get_expiry_from_jwt(access_token)
189192
fallback_token = Token(access_token, token_type, "", expiry)
190193
self.current_token = fallback_token
@@ -223,7 +226,9 @@ def get_auth_headers(self) -> Dict[str, str]:
223226
# Always get the latest headers from the credentials provider
224227
header_factory = self.credentials_provider()
225228
headers = dict(header_factory()) if header_factory else {}
226-
headers["Authorization"] = "{} {}".format(token.token_type, token.access_token)
229+
headers["Authorization"] = "{} {}".format(
230+
token.token_type, token.access_token
231+
)
227232
return headers
228233
except Exception as e:
229234
return dict(self.external_headers) if self.external_headers else {}
@@ -233,7 +238,7 @@ def _send_token_exchange_request(
233238
) -> Dict[str, Any]:
234239
"""
235240
Send the token exchange request to the token endpoint.
236-
241+
237242
For M2M flows, this should include HTTP Basic authentication using client credentials.
238243
For U2M flows, token exchange is validated purely based on the JWT token and federation policies.
239244
@@ -250,23 +255,29 @@ def _send_token_exchange_request(
250255
raise ValueError("Token endpoint not initialized")
251256

252257
auth = None
253-
if hasattr(self.credentials_provider, 'client_id') and hasattr(self.credentials_provider, 'client_secret'):
258+
if hasattr(self.credentials_provider, "client_id") and hasattr(
259+
self.credentials_provider, "client_secret"
260+
):
254261
client_id = self.credentials_provider.client_id
255262
client_secret = self.credentials_provider.client_secret
256263
auth = (client_id, client_secret)
257264
else:
258-
logger.debug("No client credentials available, sending request without authentication")
259-
265+
logger.debug(
266+
"No client credentials available, sending request without authentication"
267+
)
268+
260269
response = requests.post(
261-
self.token_endpoint,
262-
data=token_exchange_data,
270+
self.token_endpoint,
271+
data=token_exchange_data,
263272
headers=self.EXCHANGE_HEADERS,
264-
auth=auth
273+
auth=auth,
265274
)
266275

267276
if response.status_code != 200:
268277
raise requests.HTTPError(
269-
"Token exchange failed with status code {}: {}".format(response.status_code, response.text),
278+
"Token exchange failed with status code {}: {}".format(
279+
response.status_code, response.text
280+
),
270281
response=response,
271282
)
272283

@@ -289,7 +300,7 @@ def _exchange_token(self, access_token: str) -> Token:
289300
self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(
290301
self.hostname
291302
)
292-
303+
293304
# Prepare the request data according to RFC 8693
294305
token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS)
295306
token_exchange_data["subject_token"] = access_token
@@ -319,4 +330,4 @@ def add_headers(self, request_headers: Dict[str, str]):
319330
"""
320331
headers = self.get_auth_headers()
321332
for k, v in headers.items():
322-
request_headers[k] = v
333+
request_headers[k] = v

tests/unit/test_token_federation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
import jwt
55

66
from databricks.sql.auth.token import Token
7-
from databricks.sql.auth.token_federation import (
8-
DatabricksTokenFederationProvider
9-
)
7+
from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider
108
from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil
119

1210

0 commit comments

Comments
 (0)