@@ -39,7 +39,15 @@ def __init__(
39
39
self .access_token = access_token
40
40
self .token_type = token_type
41
41
self .refresh_token = refresh_token
42
- self .expiry = expiry or datetime .now (tz = timezone .utc )
42
+
43
+ # Ensure expiry is timezone-aware
44
+ if expiry is None :
45
+ self .expiry = datetime .now (tz = timezone .utc )
46
+ elif expiry .tzinfo is None :
47
+ # Convert naive datetime to aware datetime
48
+ self .expiry = expiry .replace (tzinfo = timezone .utc )
49
+ else :
50
+ self .expiry = expiry
43
51
44
52
def is_expired (self ) -> bool :
45
53
"""Check if the token is expired."""
@@ -129,7 +137,9 @@ def get_headers() -> Dict[str, str]:
129
137
and self .last_exchanged_token .needs_refresh ()
130
138
):
131
139
# The token is approaching expiry, try to refresh
132
- logger .debug ("Exchanged token approaching expiry, refreshing..." )
140
+ logger .info (
141
+ "Exchanged token approaching expiry, refreshing with fresh external token..."
142
+ )
133
143
return self ._refresh_token (access_token , token_type )
134
144
135
145
# Parse the JWT to get claims
@@ -138,14 +148,16 @@ def get_headers() -> Dict[str, str]:
138
148
# Check if token needs to be exchanged
139
149
if self ._is_same_host (token_claims .get ("iss" , "" ), self .hostname ):
140
150
# Token is from the same host, no need to exchange
151
+ logger .debug ("Token from same host, no exchange needed" )
141
152
return self .external_provider_headers
142
153
else :
143
154
# Token is from a different host, need to exchange
155
+ logger .debug ("Token from different host, attempting exchange" )
144
156
return self ._try_token_exchange_or_fallback (
145
157
access_token , token_type
146
158
)
147
159
except Exception as e :
148
- logger .error (f"Failed to process token: { str (e )} " )
160
+ logger .error (f"Error processing token: { str (e )} " )
149
161
# Fall back to original headers in case of error
150
162
return self .external_provider_headers
151
163
@@ -238,25 +250,6 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
238
250
logger .error (f"Failed to parse JWT: { str (e )} " )
239
251
raise
240
252
241
- def _detect_idp_from_claims (self , token_claims : Dict [str , Any ]) -> str :
242
- """
243
- Detect the identity provider type from token claims.
244
-
245
- This can be used to adjust token exchange parameters based on the IdP.
246
- """
247
- issuer = token_claims .get ("iss" , "" )
248
-
249
- if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer :
250
- return "azure"
251
- elif "token.actions.githubusercontent.com" in issuer :
252
- return "github"
253
- elif "accounts.google.com" in issuer :
254
- return "google"
255
- elif "cognito-idp" in issuer and "amazonaws.com" in issuer :
256
- return "aws"
257
- else :
258
- return "unknown"
259
-
260
253
def _is_same_host (self , url1 : str , url2 : str ) -> bool :
261
254
"""Check if two URLs have the same host."""
262
255
try :
@@ -283,7 +276,9 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
283
276
The headers with the fresh token
284
277
"""
285
278
try :
286
- logger .info ("Refreshing expired token by getting a new external token" )
279
+ logger .info (
280
+ "Refreshing token using proactive approach (getting fresh external token first)"
281
+ )
287
282
288
283
# Get a fresh token from the underlying credentials provider
289
284
# instead of reusing the same access_token
@@ -303,14 +298,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
303
298
fresh_token_type = parts [0 ]
304
299
fresh_access_token = parts [1 ]
305
300
306
- logger . debug ( "Got fresh external token" )
307
-
308
- # Now process the fresh token
309
- token_claims = self . _parse_jwt_claims ( fresh_access_token )
310
- idp_type = self . _detect_idp_from_claims ( token_claims )
301
+ # Check if we got the same token back
302
+ if fresh_access_token == access_token :
303
+ logger . warning (
304
+ "Credentials provider returned the same token during refresh"
305
+ )
311
306
312
307
# Perform a new token exchange with the fresh token
313
- refreshed_token = self ._exchange_token (fresh_access_token , idp_type )
308
+ refreshed_token = self ._exchange_token (fresh_access_token )
314
309
315
310
# Update the stored token
316
311
self .last_exchanged_token = refreshed_token
@@ -321,6 +316,10 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
321
316
headers [
322
317
"Authorization"
323
318
] = f"{ refreshed_token .token_type } { refreshed_token .access_token } "
319
+
320
+ logger .info (
321
+ f"Successfully refreshed token, new expiry: { refreshed_token .expiry } "
322
+ )
324
323
return headers
325
324
except Exception as e :
326
325
logger .error (
@@ -334,12 +333,8 @@ def _try_token_exchange_or_fallback(
334
333
) -> Dict [str , str ]:
335
334
"""Try to exchange the token or fall back to the original token."""
336
335
try :
337
- # Parse the token to get claims for IdP-specific adjustments
338
- token_claims = self ._parse_jwt_claims (access_token )
339
- idp_type = self ._detect_idp_from_claims (token_claims )
340
-
341
336
# Exchange the token
342
- exchanged_token = self ._exchange_token (access_token , idp_type )
337
+ exchanged_token = self ._exchange_token (access_token )
343
338
344
339
# Store the exchanged token for potential refresh later
345
340
self .last_exchanged_token = exchanged_token
@@ -358,13 +353,12 @@ def _try_token_exchange_or_fallback(
358
353
# Fall back to original headers
359
354
return self .external_provider_headers
360
355
361
- def _exchange_token (self , access_token : str , idp_type : str = "unknown" ) -> Token :
356
+ def _exchange_token (self , access_token : str ) -> Token :
362
357
"""
363
358
Exchange an external token for a Databricks token.
364
359
365
360
Args:
366
361
access_token: The external token to exchange
367
- idp_type: The detected identity provider type (azure, github, etc.)
368
362
369
363
Returns:
370
364
A Token object containing the exchanged token
@@ -384,14 +378,6 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
384
378
if self .identity_federation_client_id :
385
379
params ["client_id" ] = self .identity_federation_client_id
386
380
387
- # Make IdP-specific adjustments
388
- if idp_type == "azure" :
389
- # For Azure AD, add special handling if needed
390
- pass
391
- elif idp_type == "github" :
392
- # For GitHub Actions, add special handling if needed
393
- pass
394
-
395
381
# Set up headers
396
382
headers = {"Accept" : "*/*" , "Content-Type" : "application/x-www-form-urlencoded" }
397
383
@@ -441,7 +427,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
441
427
return token
442
428
except RequestException as e :
443
429
logger .error (f"Failed to perform token exchange: { str (e )} " )
444
- raise
430
+ raise ValueError ( f"Request error during token exchange: { str ( e ) } " )
445
431
446
432
447
433
class SimpleCredentialsProvider (CredentialsProvider ):
0 commit comments