@@ -7,16 +7,19 @@ use arc_swap::ArcSwapOption;
7
7
use dashmap:: DashMap ;
8
8
use jose_jwk:: crypto:: KeyInfo ;
9
9
use reqwest:: { redirect, Client } ;
10
+ use reqwest_retry:: policies:: ExponentialBackoff ;
11
+ use reqwest_retry:: RetryTransientMiddleware ;
10
12
use serde:: de:: Visitor ;
11
13
use serde:: { Deserialize , Deserializer } ;
14
+ use serde_json:: value:: RawValue ;
12
15
use signature:: Verifier ;
13
16
use thiserror:: Error ;
14
17
use tokio:: time:: Instant ;
15
18
16
19
use crate :: auth:: backend:: ComputeCredentialKeys ;
17
20
use crate :: context:: RequestMonitoring ;
18
21
use crate :: control_plane:: errors:: GetEndpointJwksError ;
19
- use crate :: http:: parse_json_body_with_limit ;
22
+ use crate :: http:: read_body_with_limit ;
20
23
use crate :: intern:: RoleNameInt ;
21
24
use crate :: types:: { EndpointId , RoleName } ;
22
25
@@ -28,6 +31,10 @@ const MAX_RENEW: Duration = Duration::from_secs(3600);
28
31
const MAX_JWK_BODY_SIZE : usize = 64 * 1024 ;
29
32
const JWKS_USER_AGENT : & str = "neon-proxy" ;
30
33
34
+ const JWKS_CONNECT_TIMEOUT : Duration = Duration :: from_secs ( 2 ) ;
35
+ const JWKS_FETCH_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
36
+ const JWKS_FETCH_RETRIES : u32 = 3 ;
37
+
31
38
/// How to get the JWT auth rules
32
39
pub ( crate ) trait FetchAuthRules : Clone + Send + Sync + ' static {
33
40
fn fetch_auth_rules (
@@ -55,7 +62,7 @@ pub(crate) struct AuthRule {
55
62
}
56
63
57
64
pub struct JwkCache {
58
- client : reqwest :: Client ,
65
+ client : reqwest_middleware :: ClientWithMiddleware ,
59
66
60
67
map : DashMap < ( EndpointId , RoleName ) , Arc < JwkCacheEntryLock > > ,
61
68
}
@@ -117,6 +124,14 @@ impl Default for JwkCacheEntryLock {
117
124
}
118
125
}
119
126
127
+ #[ derive( Deserialize ) ]
128
+ struct JwkSet < ' a > {
129
+ /// we parse into raw-value because not all keys in a JWKS are ones
130
+ /// we can parse directly, so we parse them lazily.
131
+ #[ serde( borrow) ]
132
+ keys : Vec < & ' a RawValue > ,
133
+ }
134
+
120
135
impl JwkCacheEntryLock {
121
136
async fn acquire_permit < ' a > ( self : & ' a Arc < Self > ) -> JwkRenewalPermit < ' a > {
122
137
JwkRenewalPermit :: acquire_permit ( self ) . await
@@ -130,7 +145,7 @@ impl JwkCacheEntryLock {
130
145
& self ,
131
146
_permit : JwkRenewalPermit < ' _ > ,
132
147
ctx : & RequestMonitoring ,
133
- client : & reqwest :: Client ,
148
+ client : & reqwest_middleware :: ClientWithMiddleware ,
134
149
endpoint : EndpointId ,
135
150
auth_rules : & F ,
136
151
) -> Result < Arc < JwkCacheEntry > , JwtError > {
@@ -154,22 +169,73 @@ impl JwkCacheEntryLock {
154
169
let req = client. get ( rule. jwks_url . clone ( ) ) ;
155
170
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
156
171
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
157
- match req. send ( ) . await . and_then ( |r| r. error_for_status ( ) ) {
172
+ match req. send ( ) . await . and_then ( |r| {
173
+ r. error_for_status ( )
174
+ . map_err ( reqwest_middleware:: Error :: Reqwest )
175
+ } ) {
158
176
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
159
177
// I expect these failures would be quite sparse.
160
178
Err ( e) => tracing:: warn!( url=?rule. jwks_url, error=?e, "could not fetch JWKs" ) ,
161
179
Ok ( r) => {
162
180
let resp: http:: Response < reqwest:: Body > = r. into ( ) ;
163
- match parse_json_body_with_limit :: < jose_jwk:: JwkSet > (
164
- resp. into_body ( ) ,
165
- MAX_JWK_BODY_SIZE ,
166
- )
167
- . await
181
+
182
+ let bytes = match read_body_with_limit ( resp. into_body ( ) , MAX_JWK_BODY_SIZE )
183
+ . await
168
184
{
185
+ Ok ( bytes) => bytes,
186
+ Err ( e) => {
187
+ tracing:: warn!( url=?rule. jwks_url, error=?e, "could not decode JWKs" ) ;
188
+ continue ;
189
+ }
190
+ } ;
191
+
192
+ match serde_json:: from_slice :: < JwkSet > ( & bytes) {
169
193
Err ( e) => {
170
194
tracing:: warn!( url=?rule. jwks_url, error=?e, "could not decode JWKs" ) ;
171
195
}
172
196
Ok ( jwks) => {
197
+ // size_of::<&RawValue>() == 16
198
+ // size_of::<jose_jwk::Jwk>() == 288
199
+ // better to not pre-allocate this as it might be pretty large - especially if it has many
200
+ // keys we don't want or need.
201
+ // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
202
+ // this would consume 8MiB just like that!
203
+ let mut keys = vec ! [ ] ;
204
+ let mut failed = 0 ;
205
+ for key in jwks. keys {
206
+ match serde_json:: from_str :: < jose_jwk:: Jwk > ( key. get ( ) ) {
207
+ Ok ( key) => {
208
+ // if `use` (called `cls` in rust) is specified to be something other than signing,
209
+ // we can skip storing it.
210
+ if key
211
+ . prm
212
+ . cls
213
+ . as_ref ( )
214
+ . is_some_and ( |c| * c != jose_jwk:: Class :: Signing )
215
+ {
216
+ continue ;
217
+ }
218
+
219
+ keys. push ( key) ;
220
+ }
221
+ Err ( e) => {
222
+ tracing:: debug!( url=?rule. jwks_url, failed=?e, "could not decode JWK" ) ;
223
+ failed += 1 ;
224
+ }
225
+ }
226
+ }
227
+ keys. shrink_to_fit ( ) ;
228
+
229
+ if failed > 0 {
230
+ tracing:: warn!( url=?rule. jwks_url, failed, "could not decode JWKs" ) ;
231
+ }
232
+
233
+ if keys. is_empty ( ) {
234
+ tracing:: warn!( url=?rule. jwks_url, "no valid JWKs found inside the response body" ) ;
235
+ continue ;
236
+ }
237
+
238
+ let jwks = jose_jwk:: JwkSet { keys } ;
173
239
key_sets. insert (
174
240
rule. id ,
175
241
KeySet {
@@ -179,7 +245,7 @@ impl JwkCacheEntryLock {
179
245
} ,
180
246
) ;
181
247
}
182
- }
248
+ } ;
183
249
}
184
250
}
185
251
}
@@ -196,7 +262,7 @@ impl JwkCacheEntryLock {
196
262
async fn get_or_update_jwk_cache < F : FetchAuthRules > (
197
263
self : & Arc < Self > ,
198
264
ctx : & RequestMonitoring ,
199
- client : & reqwest :: Client ,
265
+ client : & reqwest_middleware :: ClientWithMiddleware ,
200
266
endpoint : EndpointId ,
201
267
fetch : & F ,
202
268
) -> Result < Arc < JwkCacheEntry > , JwtError > {
@@ -250,7 +316,7 @@ impl JwkCacheEntryLock {
250
316
self : & Arc < Self > ,
251
317
ctx : & RequestMonitoring ,
252
318
jwt : & str ,
253
- client : & reqwest :: Client ,
319
+ client : & reqwest_middleware :: ClientWithMiddleware ,
254
320
endpoint : EndpointId ,
255
321
role_name : & RoleName ,
256
322
fetch : & F ,
@@ -369,8 +435,19 @@ impl Default for JwkCache {
369
435
let client = Client :: builder ( )
370
436
. user_agent ( JWKS_USER_AGENT )
371
437
. redirect ( redirect:: Policy :: none ( ) )
438
+ . tls_built_in_native_certs ( true )
439
+ . connect_timeout ( JWKS_CONNECT_TIMEOUT )
440
+ . timeout ( JWKS_FETCH_TIMEOUT )
372
441
. build ( )
373
- . expect ( "using &str and standard redirect::Policy" ) ;
442
+ . expect ( "client config should be valid" ) ;
443
+
444
+ // Retry up to 3 times with increasing intervals between attempts.
445
+ let retry_policy = ExponentialBackoff :: builder ( ) . build_with_max_retries ( JWKS_FETCH_RETRIES ) ;
446
+
447
+ let client = reqwest_middleware:: ClientBuilder :: new ( client)
448
+ . with ( RetryTransientMiddleware :: new_with_policy ( retry_policy) )
449
+ . build ( ) ;
450
+
374
451
JwkCache {
375
452
client,
376
453
map : DashMap :: default ( ) ,
@@ -1209,4 +1286,63 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
1209
1286
}
1210
1287
}
1211
1288
}
1289
+
1290
+ #[ tokio:: test]
1291
+ async fn check_jwk_keycloak_regression ( ) {
1292
+ let ( rs, valid_jwk) = new_rsa_jwk ( RS1 , "rs1" . into ( ) ) ;
1293
+ let valid_jwk = serde_json:: to_value ( valid_jwk) . unwrap ( ) ;
1294
+
1295
+ // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
1296
+ // This is taken directly from keycloak.
1297
+ let invalid_jwk = serde_json:: json! {
1298
+ {
1299
+ "kid" : "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ" ,
1300
+ "kty" : "RSA" ,
1301
+ "alg" : "RSA-OAEP" ,
1302
+ "use" : "enc" ,
1303
+ "n" : "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw" ,
1304
+ "e" : "AQAB" ,
1305
+ "x5c" : [
1306
+ "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
1307
+ ] ,
1308
+ "x5t" : "QhfzMMnuAfkReTgZ1HtrfyOeeZs" ,
1309
+ "x5t#S256" : "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
1310
+ }
1311
+ } ;
1312
+
1313
+ let jwks = serde_json:: json! { { "keys" : [ invalid_jwk, valid_jwk ] } } ;
1314
+ let jwks_addr = jwks_server ( move |path| match path {
1315
+ "/" => Some ( serde_json:: to_vec ( & jwks) . unwrap ( ) ) ,
1316
+ _ => None ,
1317
+ } )
1318
+ . await ;
1319
+
1320
+ let role_name = RoleName :: from ( "anonymous" ) ;
1321
+ let role = RoleNameInt :: from ( & role_name) ;
1322
+
1323
+ let rules = vec ! [ AuthRule {
1324
+ id: "foo" . to_owned( ) ,
1325
+ jwks_url: format!( "http://{jwks_addr}/" ) . parse( ) . unwrap( ) ,
1326
+ audience: None ,
1327
+ role_names: vec![ role] ,
1328
+ } ] ;
1329
+
1330
+ let fetch = Fetch ( rules) ;
1331
+ let jwk_cache = JwkCache :: default ( ) ;
1332
+
1333
+ let endpoint = EndpointId :: from ( "ep" ) ;
1334
+
1335
+ let token = new_rsa_jwt ( "rs1" . into ( ) , rs) ;
1336
+
1337
+ jwk_cache
1338
+ . check_jwt (
1339
+ & RequestMonitoring :: test ( ) ,
1340
+ endpoint. clone ( ) ,
1341
+ & role_name,
1342
+ & fetch,
1343
+ & token,
1344
+ )
1345
+ . await
1346
+ . unwrap ( ) ;
1347
+ }
1212
1348
}
0 commit comments