Skip to content

Commit ae28649

Browse files
committed
remove idp detection
1 parent aeeca66 commit ae28649

File tree

2 files changed

+108
-98
lines changed

2 files changed

+108
-98
lines changed

src/databricks/sql/auth/token_federation.py

+31-45
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@ def __init__(
3939
self.access_token = access_token
4040
self.token_type = token_type
4141
self.refresh_token = refresh_token
42-
self.expiry = expiry or datetime.now(tz=timezone.utc)
42+
43+
# Ensure expiry is timezone-aware
44+
if expiry is None:
45+
self.expiry = datetime.now(tz=timezone.utc)
46+
elif expiry.tzinfo is None:
47+
# Convert naive datetime to aware datetime
48+
self.expiry = expiry.replace(tzinfo=timezone.utc)
49+
else:
50+
self.expiry = expiry
4351

4452
def is_expired(self) -> bool:
4553
"""Check if the token is expired."""
@@ -129,7 +137,9 @@ def get_headers() -> Dict[str, str]:
129137
and self.last_exchanged_token.needs_refresh()
130138
):
131139
# The token is approaching expiry, try to refresh
132-
logger.debug("Exchanged token approaching expiry, refreshing...")
140+
logger.info(
141+
"Exchanged token approaching expiry, refreshing with fresh external token..."
142+
)
133143
return self._refresh_token(access_token, token_type)
134144

135145
# Parse the JWT to get claims
@@ -138,14 +148,16 @@ def get_headers() -> Dict[str, str]:
138148
# Check if token needs to be exchanged
139149
if self._is_same_host(token_claims.get("iss", ""), self.hostname):
140150
# Token is from the same host, no need to exchange
151+
logger.debug("Token from same host, no exchange needed")
141152
return self.external_provider_headers
142153
else:
143154
# Token is from a different host, need to exchange
155+
logger.debug("Token from different host, attempting exchange")
144156
return self._try_token_exchange_or_fallback(
145157
access_token, token_type
146158
)
147159
except Exception as e:
148-
logger.error(f"Failed to process token: {str(e)}")
160+
logger.error(f"Error processing token: {str(e)}")
149161
# Fall back to original headers in case of error
150162
return self.external_provider_headers
151163

@@ -238,25 +250,6 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
238250
logger.error(f"Failed to parse JWT: {str(e)}")
239251
raise
240252

241-
def _detect_idp_from_claims(self, token_claims: Dict[str, Any]) -> str:
242-
"""
243-
Detect the identity provider type from token claims.
244-
245-
This can be used to adjust token exchange parameters based on the IdP.
246-
"""
247-
issuer = token_claims.get("iss", "")
248-
249-
if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer:
250-
return "azure"
251-
elif "token.actions.githubusercontent.com" in issuer:
252-
return "github"
253-
elif "accounts.google.com" in issuer:
254-
return "google"
255-
elif "cognito-idp" in issuer and "amazonaws.com" in issuer:
256-
return "aws"
257-
else:
258-
return "unknown"
259-
260253
def _is_same_host(self, url1: str, url2: str) -> bool:
261254
"""Check if two URLs have the same host."""
262255
try:
@@ -283,7 +276,9 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
283276
The headers with the fresh token
284277
"""
285278
try:
286-
logger.info("Refreshing expired token by getting a new external token")
279+
logger.info(
280+
"Refreshing token using proactive approach (getting fresh external token first)"
281+
)
287282

288283
# Get a fresh token from the underlying credentials provider
289284
# instead of reusing the same access_token
@@ -303,14 +298,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
303298
fresh_token_type = parts[0]
304299
fresh_access_token = parts[1]
305300

306-
logger.debug("Got fresh external token")
307-
308-
# Now process the fresh token
309-
token_claims = self._parse_jwt_claims(fresh_access_token)
310-
idp_type = self._detect_idp_from_claims(token_claims)
301+
# Check if we got the same token back
302+
if fresh_access_token == access_token:
303+
logger.warning(
304+
"Credentials provider returned the same token during refresh"
305+
)
311306

312307
# Perform a new token exchange with the fresh token
313-
refreshed_token = self._exchange_token(fresh_access_token, idp_type)
308+
refreshed_token = self._exchange_token(fresh_access_token)
314309

315310
# Update the stored token
316311
self.last_exchanged_token = refreshed_token
@@ -321,6 +316,10 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
321316
headers[
322317
"Authorization"
323318
] = f"{refreshed_token.token_type} {refreshed_token.access_token}"
319+
320+
logger.info(
321+
f"Successfully refreshed token, new expiry: {refreshed_token.expiry}"
322+
)
324323
return headers
325324
except Exception as e:
326325
logger.error(
@@ -334,12 +333,8 @@ def _try_token_exchange_or_fallback(
334333
) -> Dict[str, str]:
335334
"""Try to exchange the token or fall back to the original token."""
336335
try:
337-
# Parse the token to get claims for IdP-specific adjustments
338-
token_claims = self._parse_jwt_claims(access_token)
339-
idp_type = self._detect_idp_from_claims(token_claims)
340-
341336
# Exchange the token
342-
exchanged_token = self._exchange_token(access_token, idp_type)
337+
exchanged_token = self._exchange_token(access_token)
343338

344339
# Store the exchanged token for potential refresh later
345340
self.last_exchanged_token = exchanged_token
@@ -358,13 +353,12 @@ def _try_token_exchange_or_fallback(
358353
# Fall back to original headers
359354
return self.external_provider_headers
360355

361-
def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token:
356+
def _exchange_token(self, access_token: str) -> Token:
362357
"""
363358
Exchange an external token for a Databricks token.
364359
365360
Args:
366361
access_token: The external token to exchange
367-
idp_type: The detected identity provider type (azure, github, etc.)
368362
369363
Returns:
370364
A Token object containing the exchanged token
@@ -384,14 +378,6 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
384378
if self.identity_federation_client_id:
385379
params["client_id"] = self.identity_federation_client_id
386380

387-
# Make IdP-specific adjustments
388-
if idp_type == "azure":
389-
# For Azure AD, add special handling if needed
390-
pass
391-
elif idp_type == "github":
392-
# For GitHub Actions, add special handling if needed
393-
pass
394-
395381
# Set up headers
396382
headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"}
397383

@@ -441,7 +427,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
441427
return token
442428
except RequestException as e:
443429
logger.error(f"Failed to perform token exchange: {str(e)}")
444-
raise
430+
raise ValueError(f"Request error during token exchange: {str(e)}")
445431

446432

447433
class SimpleCredentialsProvider(CredentialsProvider):

0 commit comments

Comments
 (0)