Skip to content

Commit 92ebff2

Browse files
authored
Remove generics from authentication code (#9504)
This prepares for moving some of the implementation into axum extractors, which only get access to `Parts`.
1 parent 202cbb5 commit 92ebff2

File tree

5 files changed

+43
-41
lines changed

5 files changed

+43
-41
lines changed

src/auth.rs

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::util::errors::{
1111
use crate::util::token::HashedToken;
1212
use chrono::Utc;
1313
use http::header;
14+
use http::request::Parts;
1415

1516
#[derive(Debug, Clone)]
1617
pub struct AuthCheck {
@@ -57,18 +58,14 @@ impl AuthCheck {
5758
}
5859

5960
#[instrument(name = "auth.check", skip_all)]
60-
pub fn check<T: RequestPartsExt>(
61-
&self,
62-
request: &T,
63-
conn: &mut impl Conn,
64-
) -> AppResult<Authentication> {
65-
let auth = authenticate(request, conn)?;
61+
pub fn check(&self, parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
62+
let auth = authenticate(parts, conn)?;
6663

6764
if let Some(token) = auth.api_token() {
6865
if !self.allow_token {
6966
let error_message =
7067
"API Token authentication was explicitly disallowed for this API";
71-
request.request_log().add("cause", error_message);
68+
parts.request_log().add("cause", error_message);
7269

7370
return Err(forbidden(
7471
"this action can only be performed on the crates.io website",
@@ -77,7 +74,7 @@ impl AuthCheck {
7774

7875
if !self.endpoint_scope_matches(token.endpoint_scopes.as_ref()) {
7976
let error_message = "Endpoint scope mismatch";
80-
request.request_log().add("cause", error_message);
77+
parts.request_log().add("cause", error_message);
8178

8279
return Err(forbidden(
8380
"this token does not have the required permissions to perform this action",
@@ -86,7 +83,7 @@ impl AuthCheck {
8683

8784
if !self.crate_scope_matches(token.crate_scopes.as_ref()) {
8885
let error_message = "Crate scope mismatch";
89-
request.request_log().add("cause", error_message);
86+
parts.request_log().add("cause", error_message);
9087

9188
return Err(forbidden(
9289
"this token does not have the required permissions to perform this action",
@@ -171,11 +168,11 @@ impl Authentication {
171168
}
172169

173170
#[instrument(skip_all)]
174-
fn authenticate_via_cookie<T: RequestPartsExt>(
175-
req: &T,
171+
fn authenticate_via_cookie(
172+
parts: &Parts,
176173
conn: &mut impl Conn,
177174
) -> AppResult<Option<CookieAuthentication>> {
178-
let user_id_from_session = req
175+
let user_id_from_session = parts
179176
.session()
180177
.get("user_id")
181178
.and_then(|s| s.parse::<i32>().ok());
@@ -185,23 +182,23 @@ fn authenticate_via_cookie<T: RequestPartsExt>(
185182
};
186183

187184
let user = User::find(conn, id).map_err(|err| {
188-
req.request_log().add("cause", err);
185+
parts.request_log().add("cause", err);
189186
internal("user_id from cookie not found in database")
190187
})?;
191188

192189
ensure_not_locked(&user)?;
193190

194-
req.request_log().add("uid", id);
191+
parts.request_log().add("uid", id);
195192

196193
Ok(Some(CookieAuthentication { user }))
197194
}
198195

199196
#[instrument(skip_all)]
200-
fn authenticate_via_token<T: RequestPartsExt>(
201-
req: &T,
197+
fn authenticate_via_token(
198+
parts: &Parts,
202199
conn: &mut impl Conn,
203200
) -> AppResult<Option<TokenAuthentication>> {
204-
let maybe_authorization = req
201+
let maybe_authorization = parts
205202
.headers()
206203
.get(header::AUTHORIZATION)
207204
.and_then(|h| h.to_str().ok());
@@ -215,43 +212,43 @@ fn authenticate_via_token<T: RequestPartsExt>(
215212

216213
let token = ApiToken::find_by_api_token(conn, &token).map_err(|e| {
217214
let cause = format!("invalid token caused by {e}");
218-
req.request_log().add("cause", cause);
215+
parts.request_log().add("cause", cause);
219216

220217
forbidden("authentication failed")
221218
})?;
222219

223220
let user = User::find(conn, token.user_id).map_err(|err| {
224-
req.request_log().add("cause", err);
221+
parts.request_log().add("cause", err);
225222
internal("user_id from token not found in database")
226223
})?;
227224

228225
ensure_not_locked(&user)?;
229226

230-
req.request_log().add("uid", token.user_id);
231-
req.request_log().add("tokenid", token.id);
227+
parts.request_log().add("uid", token.user_id);
228+
parts.request_log().add("tokenid", token.id);
232229

233230
Ok(Some(TokenAuthentication { user, token }))
234231
}
235232

236233
#[instrument(skip_all)]
237-
fn authenticate<T: RequestPartsExt>(req: &T, conn: &mut impl Conn) -> AppResult<Authentication> {
238-
controllers::util::verify_origin(req)?;
234+
fn authenticate(parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
235+
controllers::util::verify_origin(parts)?;
239236

240-
match authenticate_via_cookie(req, conn) {
237+
match authenticate_via_cookie(parts, conn) {
241238
Ok(None) => {}
242239
Ok(Some(auth)) => return Ok(Authentication::Cookie(auth)),
243240
Err(err) => return Err(err),
244241
}
245242

246-
match authenticate_via_token(req, conn) {
243+
match authenticate_via_token(parts, conn) {
247244
Ok(None) => {}
248245
Ok(Some(auth)) => return Ok(Authentication::Token(auth)),
249246
Err(err) => return Err(err),
250247
}
251248

252249
// Unable to authenticate the user
253250
let cause = "no cookie session or auth header found";
254-
req.request_log().add("cause", cause);
251+
parts.request_log().add("cause", cause);
255252

256253
return Err(forbidden("this action requires authentication"));
257254
}

src/controllers/crate_owner_invitation.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,18 @@ struct OwnerInvitation {
275275

276276
/// Handles the `PUT /api/v1/me/crate_owner_invitations/:crate_id` route.
277277
pub async fn handle_invite(state: AppState, req: BytesRequest) -> AppResult<Json<Value>> {
278+
let (parts, body) = req.0.into_parts();
279+
278280
let crate_invite: OwnerInvitation =
279-
serde_json::from_slice(req.body()).map_err(|_| bad_request("invalid json request"))?;
281+
serde_json::from_slice(&body).map_err(|_| bad_request("invalid json request"))?;
280282

281283
let crate_invite = crate_invite.crate_owner_invite;
282284

283285
let conn = state.db_write().await?;
284286
spawn_blocking(move || {
285287
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
286288

287-
let auth = AuthCheck::default().check(&req, conn)?;
289+
let auth = AuthCheck::default().check(&parts, conn)?;
288290
let user_id = auth.user_id();
289291

290292
let config = &state.config;

src/controllers/token.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ pub async fn list(
6767

6868
/// Handles the `PUT /me/tokens` route.
6969
pub async fn new(app: AppState, req: BytesRequest) -> AppResult<Json<Value>> {
70+
let (parts, body) = req.0.into_parts();
71+
7072
let conn = app.db_write().await?;
7173
spawn_blocking(move || {
7274
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
@@ -87,15 +89,15 @@ pub async fn new(app: AppState, req: BytesRequest) -> AppResult<Json<Value>> {
8789
api_token: NewApiToken,
8890
}
8991

90-
let new: NewApiTokenRequest = json::from_slice(req.body())
92+
let new: NewApiTokenRequest = json::from_slice(&body)
9193
.map_err(|e| bad_request(format!("invalid new token request: {e:?}")))?;
9294

9395
let name = &new.api_token.name;
9496
if name.is_empty() {
9597
return Err(bad_request("name must have a value"));
9698
}
9799

98-
let auth = AuthCheck::default().check(&req, conn)?;
100+
let auth = AuthCheck::default().check(&parts, conn)?;
99101
if auth.api_token_id().is_some() {
100102
return Err(bad_request(
101103
"cannot use an API token to create a new API token",

src/controllers/user/me.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,26 +129,27 @@ pub async fn confirm_user_email(state: AppState, Path(token): Path<String>) -> A
129129

130130
/// Handles `PUT /me/email_notifications` route
131131
pub async fn update_email_notifications(app: AppState, req: BytesRequest) -> AppResult<Response> {
132+
let (parts, body) = req.0.into_parts();
133+
132134
#[derive(Deserialize)]
133135
struct CrateEmailNotifications {
134136
id: i32,
135137
email_notifications: bool,
136138
}
137139

138-
let updates: HashMap<i32, bool> =
139-
serde_json::from_slice::<Vec<CrateEmailNotifications>>(req.body())
140-
.map_err(|_| bad_request("invalid json request"))?
141-
.iter()
142-
.map(|c| (c.id, c.email_notifications))
143-
.collect();
140+
let updates: HashMap<i32, bool> = serde_json::from_slice::<Vec<CrateEmailNotifications>>(&body)
141+
.map_err(|_| bad_request("invalid json request"))?
142+
.iter()
143+
.map(|c| (c.id, c.email_notifications))
144+
.collect();
144145

145146
let conn = app.db_write().await?;
146147
spawn_blocking(move || {
147148
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
148149

149150
use diesel::pg::upsert::excluded;
150151

151-
let user_id = AuthCheck::default().check(&req, conn)?.user_id();
152+
let user_id = AuthCheck::default().check(&parts, conn)?.user_id();
152153

153154
// Build inserts from existing crates belonging to the current user
154155
let to_insert = CrateOwner::by_owner_kind(OwnerKind::User)

src/controllers/util.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use http::{header, Extensions, HeaderMap, HeaderValue, Method, Request, Uri, Ver
1010
/// We don't want to accept authenticated requests that originated from other sites, so this
1111
/// function returns an error if the Origin header doesn't match what we expect "this site" to
1212
/// be: <https://crates.io> in production, or <http://localhost:port/> in development.
13-
pub fn verify_origin<T: RequestPartsExt>(req: &T) -> AppResult<()> {
14-
let headers = req.headers();
15-
let allowed_origins = &req.app().config.allowed_origins;
13+
pub fn verify_origin(parts: &Parts) -> AppResult<()> {
14+
let headers = parts.headers();
15+
let allowed_origins = &parts.app().config.allowed_origins;
1616

1717
let bad_origin = headers
1818
.get_all(header::ORIGIN)
@@ -23,7 +23,7 @@ pub fn verify_origin<T: RequestPartsExt>(req: &T) -> AppResult<()> {
2323
let error_message =
2424
format!("only same-origin requests can be authenticated. got {bad_origin:?}");
2525

26-
req.request_log().add("cause", error_message);
26+
parts.request_log().add("cause", error_message);
2727

2828
return Err(forbidden("invalid origin header"));
2929
}

0 commit comments

Comments
 (0)