Skip to content

Commit 82e3f0e

Browse files
[proxy/authorize]: improve JWKS reliability (#9676)
While setting up some tests, I noticed that we didn't support keycloak. They make use of encryption JWKs as well as signature ones. Our current jwks crate does not support parsing encryption keys which caused the entire jwk set to fail to parse. Switching to lazy parsing fixes this. Also while setting up tests, I couldn't use localhost jwks server as we require HTTPS and we were using webpki so it was impossible to add a custom CA. Enabling native roots addresses this possibility. I saw some of our current e2e tests against our custom JWKS in s3 were taking a while to fetch. I've added a timeout + retries to address this.
1 parent 75aa19a commit 82e3f0e

File tree

7 files changed

+168
-25
lines changed

7 files changed

+168
-25
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

proxy/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ prometheus.workspace = true
6060
rand.workspace = true
6161
regex.workspace = true
6262
remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
63-
reqwest.workspace = true
63+
reqwest = { workspace = true, features = ["rustls-tls-native-roots"] }
6464
reqwest-middleware = { workspace = true, features = ["json"] }
6565
reqwest-retry.workspace = true
6666
reqwest-tracing.workspace = true

proxy/src/auth/backend/jwt.rs

Lines changed: 149 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@ use arc_swap::ArcSwapOption;
77
use dashmap::DashMap;
88
use jose_jwk::crypto::KeyInfo;
99
use reqwest::{redirect, Client};
10+
use reqwest_retry::policies::ExponentialBackoff;
11+
use reqwest_retry::RetryTransientMiddleware;
1012
use serde::de::Visitor;
1113
use serde::{Deserialize, Deserializer};
14+
use serde_json::value::RawValue;
1215
use signature::Verifier;
1316
use thiserror::Error;
1417
use tokio::time::Instant;
1518

1619
use crate::auth::backend::ComputeCredentialKeys;
1720
use crate::context::RequestMonitoring;
1821
use crate::control_plane::errors::GetEndpointJwksError;
19-
use crate::http::parse_json_body_with_limit;
22+
use crate::http::read_body_with_limit;
2023
use crate::intern::RoleNameInt;
2124
use crate::types::{EndpointId, RoleName};
2225

@@ -28,6 +31,10 @@ const MAX_RENEW: Duration = Duration::from_secs(3600);
2831
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
2932
const JWKS_USER_AGENT: &str = "neon-proxy";
3033

34+
const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
35+
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
36+
const JWKS_FETCH_RETRIES: u32 = 3;
37+
3138
/// How to get the JWT auth rules
3239
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
3340
fn fetch_auth_rules(
@@ -55,7 +62,7 @@ pub(crate) struct AuthRule {
5562
}
5663

5764
pub struct JwkCache {
58-
client: reqwest::Client,
65+
client: reqwest_middleware::ClientWithMiddleware,
5966

6067
map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
6168
}
@@ -117,6 +124,14 @@ impl Default for JwkCacheEntryLock {
117124
}
118125
}
119126

127+
#[derive(Deserialize)]
128+
struct JwkSet<'a> {
129+
/// we parse into raw-value because not all keys in a JWKS are ones
130+
/// we can parse directly, so we parse them lazily.
131+
#[serde(borrow)]
132+
keys: Vec<&'a RawValue>,
133+
}
134+
120135
impl JwkCacheEntryLock {
121136
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
122137
JwkRenewalPermit::acquire_permit(self).await
@@ -130,7 +145,7 @@ impl JwkCacheEntryLock {
130145
&self,
131146
_permit: JwkRenewalPermit<'_>,
132147
ctx: &RequestMonitoring,
133-
client: &reqwest::Client,
148+
client: &reqwest_middleware::ClientWithMiddleware,
134149
endpoint: EndpointId,
135150
auth_rules: &F,
136151
) -> Result<Arc<JwkCacheEntry>, JwtError> {
@@ -154,22 +169,73 @@ impl JwkCacheEntryLock {
154169
let req = client.get(rule.jwks_url.clone());
155170
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
156171
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
157-
match req.send().await.and_then(|r| r.error_for_status()) {
172+
match req.send().await.and_then(|r| {
173+
r.error_for_status()
174+
.map_err(reqwest_middleware::Error::Reqwest)
175+
}) {
158176
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
159177
// I expect these failures would be quite sparse.
160178
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
161179
Ok(r) => {
162180
let resp: http::Response<reqwest::Body> = r.into();
163-
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
164-
resp.into_body(),
165-
MAX_JWK_BODY_SIZE,
166-
)
167-
.await
181+
182+
let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
183+
.await
168184
{
185+
Ok(bytes) => bytes,
186+
Err(e) => {
187+
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
188+
continue;
189+
}
190+
};
191+
192+
match serde_json::from_slice::<JwkSet>(&bytes) {
169193
Err(e) => {
170194
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
171195
}
172196
Ok(jwks) => {
197+
// size_of::<&RawValue>() == 16
198+
// size_of::<jose_jwk::Jwk>() == 288
199+
// better to not pre-allocate this as it might be pretty large - especially if it has many
200+
// keys we don't want or need.
201+
// trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
202+
// this would consume 8MiB just like that!
203+
let mut keys = vec![];
204+
let mut failed = 0;
205+
for key in jwks.keys {
206+
match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
207+
Ok(key) => {
208+
// if `use` (called `cls` in rust) is specified to be something other than signing,
209+
// we can skip storing it.
210+
if key
211+
.prm
212+
.cls
213+
.as_ref()
214+
.is_some_and(|c| *c != jose_jwk::Class::Signing)
215+
{
216+
continue;
217+
}
218+
219+
keys.push(key);
220+
}
221+
Err(e) => {
222+
tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
223+
failed += 1;
224+
}
225+
}
226+
}
227+
keys.shrink_to_fit();
228+
229+
if failed > 0 {
230+
tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
231+
}
232+
233+
if keys.is_empty() {
234+
tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
235+
continue;
236+
}
237+
238+
let jwks = jose_jwk::JwkSet { keys };
173239
key_sets.insert(
174240
rule.id,
175241
KeySet {
@@ -179,7 +245,7 @@ impl JwkCacheEntryLock {
179245
},
180246
);
181247
}
182-
}
248+
};
183249
}
184250
}
185251
}
@@ -196,7 +262,7 @@ impl JwkCacheEntryLock {
196262
async fn get_or_update_jwk_cache<F: FetchAuthRules>(
197263
self: &Arc<Self>,
198264
ctx: &RequestMonitoring,
199-
client: &reqwest::Client,
265+
client: &reqwest_middleware::ClientWithMiddleware,
200266
endpoint: EndpointId,
201267
fetch: &F,
202268
) -> Result<Arc<JwkCacheEntry>, JwtError> {
@@ -250,7 +316,7 @@ impl JwkCacheEntryLock {
250316
self: &Arc<Self>,
251317
ctx: &RequestMonitoring,
252318
jwt: &str,
253-
client: &reqwest::Client,
319+
client: &reqwest_middleware::ClientWithMiddleware,
254320
endpoint: EndpointId,
255321
role_name: &RoleName,
256322
fetch: &F,
@@ -369,8 +435,19 @@ impl Default for JwkCache {
369435
let client = Client::builder()
370436
.user_agent(JWKS_USER_AGENT)
371437
.redirect(redirect::Policy::none())
438+
.tls_built_in_native_certs(true)
439+
.connect_timeout(JWKS_CONNECT_TIMEOUT)
440+
.timeout(JWKS_FETCH_TIMEOUT)
372441
.build()
373-
.expect("using &str and standard redirect::Policy");
442+
.expect("client config should be valid");
443+
444+
// Retry up to 3 times with increasing intervals between attempts.
445+
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
446+
447+
let client = reqwest_middleware::ClientBuilder::new(client)
448+
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
449+
.build();
450+
374451
JwkCache {
375452
client,
376453
map: DashMap::default(),
@@ -1209,4 +1286,63 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
12091286
}
12101287
}
12111288
}
1289+
1290+
#[tokio::test]
1291+
async fn check_jwk_keycloak_regression() {
1292+
let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
1293+
let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
1294+
1295+
// This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
1296+
// This is taken directly from keycloak.
1297+
let invalid_jwk = serde_json::json! {
1298+
{
1299+
"kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
1300+
"kty": "RSA",
1301+
"alg": "RSA-OAEP",
1302+
"use": "enc",
1303+
"n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
1304+
"e": "AQAB",
1305+
"x5c": [
1306+
"MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
1307+
],
1308+
"x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
1309+
"x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
1310+
}
1311+
};
1312+
1313+
let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
1314+
let jwks_addr = jwks_server(move |path| match path {
1315+
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
1316+
_ => None,
1317+
})
1318+
.await;
1319+
1320+
let role_name = RoleName::from("anonymous");
1321+
let role = RoleNameInt::from(&role_name);
1322+
1323+
let rules = vec![AuthRule {
1324+
id: "foo".to_owned(),
1325+
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1326+
audience: None,
1327+
role_names: vec![role],
1328+
}];
1329+
1330+
let fetch = Fetch(rules);
1331+
let jwk_cache = JwkCache::default();
1332+
1333+
let endpoint = EndpointId::from("ep");
1334+
1335+
let token = new_rsa_jwt("rs1".into(), rs);
1336+
1337+
jwk_cache
1338+
.check_jwt(
1339+
&RequestMonitoring::test(),
1340+
endpoint.clone(),
1341+
&role_name,
1342+
&fetch,
1343+
&token,
1344+
)
1345+
.await
1346+
.unwrap();
1347+
}
12121348
}

proxy/src/http/mod.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ pub mod health_server;
66

77
use std::time::Duration;
88

9-
use anyhow::bail;
109
use bytes::Bytes;
1110
use http::Method;
1211
use http_body_util::BodyExt;
@@ -16,7 +15,7 @@ use reqwest_middleware::RequestBuilder;
1615
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
1716
pub(crate) use reqwest_retry::policies::ExponentialBackoff;
1817
pub(crate) use reqwest_retry::RetryTransientMiddleware;
19-
use serde::de::DeserializeOwned;
18+
use thiserror::Error;
2019

2120
use crate::metrics::{ConsoleRequest, Metrics};
2221
use crate::url::ApiUrl;
@@ -122,31 +121,40 @@ impl Endpoint {
122121
}
123122
}
124123

125-
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
124+
#[derive(Error, Debug)]
125+
pub(crate) enum ReadBodyError {
126+
#[error("Content length exceeds limit of {limit} bytes")]
127+
BodyTooLarge { limit: usize },
128+
129+
#[error(transparent)]
130+
Read(#[from] reqwest::Error),
131+
}
132+
133+
pub(crate) async fn read_body_with_limit(
126134
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
127135
limit: usize,
128-
) -> anyhow::Result<D> {
136+
) -> Result<Vec<u8>, ReadBodyError> {
129137
// We could use `b.limited().collect().await.to_bytes()` here
130138
// but this ends up being slightly more efficient as far as I can tell.
131139

132140
// check the lower bound of the size hint.
133141
// in reqwest, this value is influenced by the Content-Length header.
134142
let lower_bound = match usize::try_from(b.size_hint().lower()) {
135143
Ok(bound) if bound <= limit => bound,
136-
_ => bail!("Content length exceeds limit of {limit} bytes"),
144+
_ => return Err(ReadBodyError::BodyTooLarge { limit }),
137145
};
138146
let mut bytes = Vec::with_capacity(lower_bound);
139147

140148
while let Some(frame) = b.frame().await.transpose()? {
141149
if let Ok(data) = frame.into_data() {
142150
if bytes.len() + data.len() > limit {
143-
bail!("Content length exceeds limit of {limit} bytes")
151+
return Err(ReadBodyError::BodyTooLarge { limit });
144152
}
145153
bytes.extend_from_slice(&data);
146154
}
147155
}
148156

149-
Ok(serde_json::from_slice::<D>(&bytes)?)
157+
Ok(bytes)
150158
}
151159

152160
#[cfg(test)]

proxy/src/serverless/conn_pool_lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ use super::http_conn_pool::ClientDataHttp;
1616
use super::local_conn_pool::ClientDataLocal;
1717
use crate::auth::backend::ComputeUserInfo;
1818
use crate::context::RequestMonitoring;
19-
use crate::control_plane::messages::ColdStartInfo;
20-
use crate::control_plane::messages::MetricsAuxInfo;
19+
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
2120
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
2221
use crate::types::{DbName, EndpointCacheKey, RoleName};
2322
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};

proxy/src/serverless/http_conn_pool.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use hyper::client::conn::http2;
77
use hyper_util::rt::{TokioExecutor, TokioIo};
88
use parking_lot::RwLock;
99
use rand::Rng;
10-
use std::result::Result::Ok;
1110
use tokio::net::TcpStream;
1211
use tracing::{debug, error, info, info_span, Instrument};
1312

workspace_hack/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ rand = { version = "0.8", features = ["small_rng"] }
6464
regex = { version = "1" }
6565
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
6666
regex-syntax = { version = "0.8" }
67-
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
67+
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] }
6868
rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] }
6969
scopeguard = { version = "1" }
7070
serde = { version = "1", features = ["alloc", "derive"] }

0 commit comments

Comments
 (0)