Skip to content

Commit 49eab2a

Browse files
committed
fmt
1 parent 541e82f commit 49eab2a

File tree

3 files changed

+100
-53
lines changed

3 files changed

+100
-53
lines changed

src/databricks/sql/auth/token_federation.py

+40-28
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def get_headers() -> Dict[str, str]:
150150
else:
151151
# Token is from a different host, need to exchange
152152
logger.debug("Token from different host, attempting exchange")
153-
return self._try_token_exchange_or_fallback(
154-
access_token, token_type
155-
)
153+
return self._try_token_exchange_or_fallback(access_token, token_type)
156154
except Exception as e:
157155
logger.error(f"Error processing token: {str(e)}")
158156
# Fall back to original headers in case of error
@@ -172,9 +170,7 @@ def _init_oidc_discovery(self):
172170

173171
if idp_endpoints:
174172
# Get the OpenID configuration URL
175-
openid_config_url = idp_endpoints.get_openid_config_url(
176-
self.hostname
177-
)
173+
openid_config_url = idp_endpoints.get_openid_config_url(self.hostname)
178174

179175
# Fetch the OpenID configuration
180176
response = requests.get(openid_config_url)
@@ -185,7 +181,8 @@ def _init_oidc_discovery(self):
185181
logger.info(f"Discovered token endpoint: {self.token_endpoint}")
186182
else:
187183
logger.warning(
188-
f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}"
184+
f"Failed to fetch OpenID configuration from {openid_config_url}: "
185+
f"{response.status_code}"
189186
)
190187
except Exception as e:
191188
logger.warning(
@@ -282,9 +279,15 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
282279
self.last_external_token = access_token
283280

284281
# Update the headers with the new token
285-
return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"}
282+
return {
283+
"Authorization": (
284+
f"{exchanged_token.token_type} {exchanged_token.access_token}"
285+
)
286+
}
286287
except Exception as e:
287-
logger.error(f"Token refresh failed: {str(e)}, falling back to original token")
288+
logger.error(
289+
f"Token refresh failed: {str(e)}, falling back to original token"
290+
)
288291
return self.external_provider_headers
289292

290293
def _try_token_exchange_or_fallback(
@@ -305,12 +308,20 @@ def _try_token_exchange_or_fallback(
305308
self.last_exchanged_token = exchanged_token
306309
self.last_external_token = access_token
307310

308-
return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"}
311+
return {
312+
"Authorization": (
313+
f"{exchanged_token.token_type} {exchanged_token.access_token}"
314+
)
315+
}
309316
except Exception as e:
310-
logger.warning(f"Token exchange failed: {str(e)}, falling back to original token")
317+
logger.warning(
318+
f"Token exchange failed: {str(e)}, falling back to original token"
319+
)
311320
return self.external_provider_headers
312321

313-
def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> Dict[str, Any]:
322+
def _send_token_exchange_request(
323+
self, token_exchange_data: Dict[str, str]
324+
) -> Dict[str, Any]:
314325
"""
315326
Send the token exchange request to the token endpoint.
316327
@@ -325,20 +336,19 @@ def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> D
325336
"""
326337
if not self.token_endpoint:
327338
raise ValueError("Token endpoint not initialized")
328-
339+
329340
headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"}
330-
341+
331342
response = requests.post(
332-
self.token_endpoint,
333-
data=token_exchange_data,
334-
headers=headers
343+
self.token_endpoint, data=token_exchange_data, headers=headers
335344
)
336-
345+
337346
if response.status_code != 200:
338347
raise ValueError(
339-
f"Token exchange failed with status code {response.status_code}: {response.text}"
348+
f"Token exchange failed with status code {response.status_code}: "
349+
f"{response.text}"
340350
)
341-
351+
342352
return response.json()
343353

344354
def _exchange_token(self, access_token: str) -> Token:
@@ -365,26 +375,28 @@ def _exchange_token(self, access_token: str) -> Token:
365375
try:
366376
# Send the token exchange request
367377
resp_data = self._send_token_exchange_request(token_exchange_data)
368-
378+
369379
# Extract token information
370380
new_access_token = resp_data.get("access_token")
371381
if not new_access_token:
372382
raise ValueError("No access token in exchange response")
373-
383+
374384
token_type = resp_data.get("token_type", "Bearer")
375385
refresh_token = resp_data.get("refresh_token", "")
376-
386+
377387
# Parse expiry time from token claims if possible
378388
expiry = datetime.now(tz=timezone.utc)
379-
389+
380390
# First try to get expiry from the response's expires_in field
381391
if "expires_in" in resp_data and resp_data["expires_in"]:
382392
try:
383393
expires_in = int(resp_data["expires_in"])
384-
expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in)
394+
expiry = datetime.now(tz=timezone.utc) + timedelta(
395+
seconds=expires_in
396+
)
385397
except (ValueError, TypeError) as e:
386398
logger.warning(f"Invalid expires_in value: {str(e)}")
387-
399+
388400
# If that didn't work, try to parse JWT claims for expiry
389401
if expiry == datetime.now(tz=timezone.utc):
390402
token_claims = self._parse_jwt_claims(new_access_token)
@@ -394,9 +406,9 @@ def _exchange_token(self, access_token: str) -> Token:
394406
expiry = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
395407
except (ValueError, TypeError) as e:
396408
logger.warning(f"Invalid exp claim in token: {str(e)}")
397-
409+
398410
return Token(new_access_token, token_type, refresh_token, expiry)
399-
411+
400412
except Exception as e:
401413
logger.error(f"Token exchange failed: {str(e)}")
402414
raise

tests/token_federation/github_oidc_test.py

+51-21
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
import base64
1515
import logging
1616
from databricks import sql
17-
import jwt
1817

18+
try:
19+
import jwt
20+
21+
HAS_JWT_LIBRARY = True
22+
except ImportError:
23+
HAS_JWT_LIBRARY = False
1924

2025

2126
logging.basicConfig(
22-
level=logging.INFO,
23-
format="%(asctime)s - %(levelname)s - %(message)s"
27+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
2428
)
2529
logger = logging.getLogger(__name__)
2630

@@ -35,10 +39,29 @@ def decode_jwt(token):
3539
Returns:
3640
dict: The decoded token claims or None if decoding fails
3741
"""
42+
if HAS_JWT_LIBRARY:
43+
try:
44+
# Using PyJWT library (preferred method)
45+
# Note: we're not verifying the signature as this is just for debugging
46+
return jwt.decode(token, options={"verify_signature": False})
47+
except Exception as e:
48+
logger.error(f"Failed to decode token with PyJWT: {str(e)}")
49+
50+
# Fallback to manual decoding
3851
try:
39-
return jwt.decode(token, options={"verify_signature": False})
52+
parts = token.split(".")
53+
if len(parts) != 3:
54+
raise ValueError("Invalid JWT format")
55+
56+
payload = parts[1]
57+
# Add padding if needed
58+
padding = "=" * (4 - len(payload) % 4)
59+
payload += padding
60+
61+
decoded = base64.b64decode(payload)
62+
return json.loads(decoded)
4063
except Exception as e:
41-
logger.error(f"Failed to decode token with PyJWT: {str(e)}")
64+
logger.error(f"Failed to decode token: {str(e)}")
4265
return {}
4366

4467

@@ -53,7 +76,7 @@ def get_environment_variables():
5376
host = os.environ.get("DATABRICKS_HOST_FOR_TF")
5477
http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
5578
identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
56-
79+
5780
return github_token, host, http_path, identity_federation_client_id
5881

5982

@@ -62,7 +85,7 @@ def display_token_info(claims):
6285
if not claims:
6386
logger.warning("No token claims available to display")
6487
return
65-
88+
6689
logger.info("=== GitHub OIDC Token Claims ===")
6790
logger.info(f"Token issuer: {claims.get('iss')}")
6891
logger.info(f"Token subject: {claims.get('sub')}")
@@ -74,7 +97,9 @@ def display_token_info(claims):
7497
logger.info("===============================")
7598

7699

77-
def test_databricks_connection(host, http_path, github_token, identity_federation_client_id):
100+
def test_databricks_connection(
101+
host, http_path, github_token, identity_federation_client_id
102+
):
78103
"""
79104
Test connection to Databricks using token federation.
80105
@@ -90,30 +115,30 @@ def test_databricks_connection(host, http_path, github_token, identity_federatio
90115
logger.info("=== Testing Connection via Connector ===")
91116
logger.info(f"Connecting to Databricks at {host}{http_path}")
92117
logger.info(f"Using client ID: {identity_federation_client_id}")
93-
118+
94119
connection_params = {
95120
"server_hostname": host,
96121
"http_path": http_path,
97122
"access_token": github_token,
98123
"auth_type": "token-federation",
99124
"identity_federation_client_id": identity_federation_client_id,
100125
}
101-
126+
102127
try:
103128
with sql.connect(**connection_params) as connection:
104129
logger.info("Connection established successfully")
105-
130+
106131
# Execute a simple query
107132
cursor = connection.cursor()
108133
cursor.execute("SELECT 1 + 1 as result")
109134
result = cursor.fetchall()
110135
logger.info(f"Query result: {result[0][0]}")
111-
136+
112137
# Show current user
113138
cursor.execute("SELECT current_user() as user")
114139
result = cursor.fetchall()
115140
logger.info(f"Connected as user: {result[0][0]}")
116-
141+
117142
logger.info("Token federation test successful!")
118143
return True
119144
except Exception as e:
@@ -125,29 +150,34 @@ def main():
125150
"""Main entry point for the test script."""
126151
try:
127152
# Get environment variables
128-
github_token, host, http_path, identity_federation_client_id = get_environment_variables()
129-
153+
github_token, host, http_path, identity_federation_client_id = (
154+
get_environment_variables()
155+
)
156+
130157
if not github_token:
131158
logger.error("Missing GitHub OIDC token (OIDC_TOKEN)")
132159
sys.exit(1)
133-
160+
134161
if not host or not http_path:
135-
logger.error("Missing Databricks connection parameters (DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)")
162+
logger.error(
163+
"Missing Databricks connection parameters "
164+
"(DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)"
165+
)
136166
sys.exit(1)
137-
167+
138168
# Display token claims
139169
claims = decode_jwt(github_token)
140170
display_token_info(claims)
141-
171+
142172
# Test Databricks connection
143173
success = test_databricks_connection(
144174
host, http_path, github_token, identity_federation_client_id
145175
)
146-
176+
147177
if not success:
148178
logger.error("Token federation test failed")
149179
sys.exit(1)
150-
180+
151181
except Exception as e:
152182
logger.error(f"Unexpected error: {str(e)}")
153183
sys.exit(1)

tests/unit/test_token_federation.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Token,
1414
DatabricksTokenFederationProvider,
1515
SimpleCredentialsProvider,
16-
create_token_federation_provider,
1716
TOKEN_REFRESH_BUFFER_SECONDS,
1817
)
1918

@@ -136,19 +135,25 @@ def test_init_oidc_discovery(mock_request_get, mock_get_oauth_endpoints):
136135

137136
@pytest.fixture
138137
def mock_parse_jwt_claims():
139-
with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims") as mock:
138+
with patch(
139+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims"
140+
) as mock:
140141
yield mock
141142

142143

143144
@pytest.fixture
144145
def mock_exchange_token():
145-
with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token") as mock:
146+
with patch(
147+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token"
148+
) as mock:
146149
yield mock
147150

148151

149152
@pytest.fixture
150153
def mock_is_same_host():
151-
with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host") as mock:
154+
with patch(
155+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host"
156+
) as mock:
152157
yield mock
153158

154159

0 commit comments

Comments
 (0)