From 7a6a6102546203e88af8ebf897a3707ce2c6d28a Mon Sep 17 00:00:00 2001 From: David Calavera Date: Tue, 20 Jun 2023 03:02:45 -0700 Subject: [PATCH] Implement custom deserializer for LambdaRequest This deserializer gives us full control over the error message that we return for invalid payloads. The default message that Serde returns is usually very confusing, and it's been reported many times as something people don't understand. This code is a copy of the code that Serde generates when it expands the Deserialize macro. Signed-off-by: David Calavera --- lambda-events/Cargo.toml | 2 +- lambda-events/src/event/alb/mod.rs | 1 + lambda-events/src/event/apigw/mod.rs | 23 ++- .../src/fixtures/example-apigw-request.json | 145 +++++++++++------- lambda-http/Cargo.toml | 2 +- lambda-http/src/deserializer.rs | 117 ++++++++++++++ lambda-http/src/lib.rs | 1 + lambda-http/src/request.rs | 5 +- 8 files changed, 238 insertions(+), 58 deletions(-) create mode 100644 lambda-http/src/deserializer.rs diff --git a/lambda-events/Cargo.toml b/lambda-events/Cargo.toml index b1108c63..28df6b4a 100644 --- a/lambda-events/Cargo.toml +++ b/lambda-events/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws_lambda_events" -version = "0.10.0" +version = "0.11.0" description = "AWS Lambda event definitions" authors = [ "Christian Legnitto ", diff --git a/lambda-events/src/event/alb/mod.rs b/lambda-events/src/event/alb/mod.rs index 259dce23..7bb1eb7f 100644 --- a/lambda-events/src/event/alb/mod.rs +++ b/lambda-events/src/event/alb/mod.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; /// `AlbTargetGroupRequest` contains data originating from the ALB Lambda target group integration #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] pub struct AlbTargetGroupRequest { #[serde(with = "http_method")] pub http_method: Method, diff --git a/lambda-events/src/event/apigw/mod.rs b/lambda-events/src/event/apigw/mod.rs index 917f06aa..b595d825 100644 --- a/lambda-events/src/event/apigw/mod.rs +++ b/lambda-events/src/event/apigw/mod.rs @@ -13,6 +13,7 @@ use std::collections::HashMap; /// `ApiGatewayProxyRequest` contains data coming from the API Gateway proxy #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] pub struct ApiGatewayProxyRequest where T1: DeserializeOwned, @@ -118,12 +119,25 @@ where /// `ApiGatewayV2httpRequest` contains data coming from the new HTTP API Gateway #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] pub struct ApiGatewayV2httpRequest { + #[serde(default, rename = "type")] + pub kind: Option, + #[serde(default)] + pub method_arn: Option, + #[serde(with = "http_method", default = "default_http_method")] + pub http_method: Method, + #[serde(default)] + pub identity_source: Option, + #[serde(default)] + pub authorization_token: Option, + #[serde(default)] + pub resource: Option, #[serde(default)] pub version: Option, #[serde(default)] pub route_key: Option, - #[serde(default)] + #[serde(default, alias = "path")] pub raw_path: Option, #[serde(default)] pub raw_query_string: Option, @@ -319,6 +333,7 @@ pub struct ApiGatewayRequestIdentity { /// `ApiGatewayWebsocketProxyRequest` contains data coming from the API Gateway proxy #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] pub struct ApiGatewayWebsocketProxyRequest where T1: DeserializeOwned, @@ -747,6 +762,10 @@ pub struct IamPolicyStatement { pub resource: Vec, } +fn default_http_method() -> Method { + Method::GET +} + #[cfg(test)] mod test { use super::*; @@ -901,6 +920,8 @@ mod test { let output: String = serde_json::to_string(&parsed).unwrap(); let reparsed: ApiGatewayV2httpRequest = serde_json::from_slice(output.as_bytes()).unwrap(); assert_eq!(parsed, reparsed); + assert_eq!("REQUEST", parsed.kind.unwrap()); + assert_eq!(Method::GET, parsed.http_method); } #[test] diff --git a/lambda-events/src/fixtures/example-apigw-request.json b/lambda-events/src/fixtures/example-apigw-request.json index 570f785b..d91e9609 100644 --- a/lambda-events/src/fixtures/example-apigw-request.json +++ b/lambda-events/src/fixtures/example-apigw-request.json @@ -1,55 +1,95 @@ { "resource": "/{proxy+}", - "path": "/hello/world", - "httpMethod": "POST", - "headers": { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate", - "cache-control": "no-cache", - "CloudFront-Forwarded-Proto": "https", - "CloudFront-Is-Desktop-Viewer": "true", - "CloudFront-Is-Mobile-Viewer": "false", - "CloudFront-Is-SmartTV-Viewer": "false", - "CloudFront-Is-Tablet-Viewer": "false", - "CloudFront-Viewer-Country": "US", - "Content-Type": "application/json", - "headerName": "headerValue", - "Host": "gy415nuibc.execute-api.us-east-1.amazonaws.com", - "Postman-Token": "9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f", - "User-Agent": "PostmanRuntime/2.4.5", - "Via": "1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)", - "X-Amz-Cf-Id": "pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A==", - "X-Forwarded-For": "54.240.196.186, 54.182.214.83", - "X-Forwarded-Port": "443", - "X-Forwarded-Proto": "https" - }, - "multiValueHeaders": { - "Accept": ["*/*"], - "Accept-Encoding": ["gzip, deflate"], - "cache-control": ["no-cache"], - "CloudFront-Forwarded-Proto": ["https"], - "CloudFront-Is-Desktop-Viewer": ["true"], - "CloudFront-Is-Mobile-Viewer": ["false"], - "CloudFront-Is-SmartTV-Viewer": ["false"], - "CloudFront-Is-Tablet-Viewer": ["false"], - "CloudFront-Viewer-Country": ["US"], - "Content-Type": ["application/json"], - "headerName": ["headerValue"], - "Host": ["gy415nuibc.execute-api.us-east-1.amazonaws.com"], - "Postman-Token": ["9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f"], - "User-Agent": ["PostmanRuntime/2.4.5"], - "Via": ["1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)"], - "X-Amz-Cf-Id": ["pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A=="], - "X-Forwarded-For": ["54.240.196.186, 54.182.214.83"], - "X-Forwarded-Port": ["443"], - "X-Forwarded-Proto": ["https"] - }, + "path": "/hello/world", + "httpMethod": "POST", + "headers": { + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "cache-control": "no-cache", + "CloudFront-Forwarded-Proto": "https", + "CloudFront-Is-Desktop-Viewer": "true", + "CloudFront-Is-Mobile-Viewer": "false", + "CloudFront-Is-SmartTV-Viewer": "false", + "CloudFront-Is-Tablet-Viewer": "false", + "CloudFront-Viewer-Country": "US", + "Content-Type": "application/json", + "headerName": "headerValue", + "Host": "gy415nuibc.execute-api.us-east-1.amazonaws.com", + "Postman-Token": "9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f", + "User-Agent": "PostmanRuntime/2.4.5", + "Via": "1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)", + "X-Amz-Cf-Id": "pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A==", + "X-Forwarded-For": "54.240.196.186, 54.182.214.83", + "X-Forwarded-Port": "443", + "X-Forwarded-Proto": "https" + }, + "multiValueHeaders": { + "Accept": [ + "*/*" + ], + "Accept-Encoding": [ + "gzip, deflate" + ], + "cache-control": [ + "no-cache" + ], + "CloudFront-Forwarded-Proto": [ + "https" + ], + "CloudFront-Is-Desktop-Viewer": [ + "true" + ], + "CloudFront-Is-Mobile-Viewer": [ + "false" + ], + "CloudFront-Is-SmartTV-Viewer": [ + "false" + ], + "CloudFront-Is-Tablet-Viewer": [ + "false" + ], + "CloudFront-Viewer-Country": [ + "US" + ], + "Content-Type": [ + "application/json" + ], + "headerName": [ + "headerValue" + ], + "Host": [ + "gy415nuibc.execute-api.us-east-1.amazonaws.com" + ], + "Postman-Token": [ + "9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f" + ], + "User-Agent": [ + "PostmanRuntime/2.4.5" + ], + "Via": [ + "1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)" + ], + "X-Amz-Cf-Id": [ + "pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A==" + ], + "X-Forwarded-For": [ + "54.240.196.186, 54.182.214.83" + ], + "X-Forwarded-Port": [ + "443" + ], + "X-Forwarded-Proto": [ + "https" + ] + }, "queryStringParameters": { "name": "me" - }, - "multiValueQueryStringParameters": { - "name": ["me"] - }, + }, + "multiValueQueryStringParameters": { + "name": [ + "me" + ] + }, "pathParameters": { "proxy": "hello/world" }, @@ -70,9 +110,9 @@ "accountId": "theAccountId", "cognitoIdentityId": "theCognitoIdentityId", "caller": "theCaller", - "apiKey": "theApiKey", - "apiKeyId": "theApiKeyId", - "accessKey": "ANEXAMPLEOFACCESSKEY", + "apiKey": "theApiKey", + "apiKeyId": "theApiKeyId", + "accessKey": "ANEXAMPLEOFACCESSKEY", "sourceIp": "192.168.196.186", "cognitoAuthenticationType": "theCognitoAuthenticationType", "cognitoAuthenticationProvider": "theCognitoAuthenticationProvider", @@ -92,5 +132,4 @@ "apiId": "gy415nuibc" }, "body": "{\r\n\t\"a\": 1\r\n}" -} - +} \ No newline at end of file diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index edc68650..be111092 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -40,7 +40,7 @@ percent-encoding = "2.2" [dependencies.aws_lambda_events] path = "../lambda-events" -version = "0.10.0" +version = "0.11.0" default-features = false features = ["alb", "apigw"] diff --git a/lambda-http/src/deserializer.rs b/lambda-http/src/deserializer.rs new file mode 100644 index 00000000..1771ea7b --- /dev/null +++ b/lambda-http/src/deserializer.rs @@ -0,0 +1,117 @@ +use crate::request::LambdaRequest; +use aws_lambda_events::{ + alb::AlbTargetGroupRequest, + apigw::{ApiGatewayProxyRequest, ApiGatewayV2httpRequest, ApiGatewayWebsocketProxyRequest}, +}; +use serde::{de::Error, Deserialize}; + +const ERROR_CONTEXT: &str = "this function expects a JSON payload from Amazon API Gateway, Amazon Elastic Load Balancer, or AWS Lambda Function URLs, but the data doesn't match any of those services' events"; + +impl<'de> Deserialize<'de> for LambdaRequest { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let content = match serde::__private::de::Content::deserialize(deserializer) { + Ok(content) => content, + Err(err) => return Err(err), + }; + #[cfg(feature = "apigw_rest")] + if let Ok(res) = + ApiGatewayProxyRequest::deserialize(serde::__private::de::ContentRefDeserializer::::new(&content)) + { + return Ok(LambdaRequest::ApiGatewayV1(res)); + } + #[cfg(feature = "apigw_http")] + if let Ok(res) = ApiGatewayV2httpRequest::deserialize( + serde::__private::de::ContentRefDeserializer::::new(&content), + ) { + return Ok(LambdaRequest::ApiGatewayV2(res)); + } + #[cfg(feature = "alb")] + if let Ok(res) = + AlbTargetGroupRequest::deserialize(serde::__private::de::ContentRefDeserializer::::new(&content)) + { + return Ok(LambdaRequest::Alb(res)); + } + #[cfg(feature = "apigw_websockets")] + if let Ok(res) = ApiGatewayWebsocketProxyRequest::deserialize(serde::__private::de::ContentRefDeserializer::< + D::Error, + >::new(&content)) + { + return Ok(LambdaRequest::WebSocket(res)); + } + + Err(Error::custom(ERROR_CONTEXT)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deserialize_apigw_rest() { + let data = include_bytes!("../../lambda-events/src/fixtures/example-apigw-request.json"); + + let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw rest data"); + match req { + LambdaRequest::ApiGatewayV1(req) => { + assert_eq!("12345678912", req.request_context.account_id.unwrap()); + } + other => panic!("unexpected request variant: {:?}", other), + } + } + + #[test] + fn test_deserialize_apigw_http() { + let data = include_bytes!("../../lambda-events/src/fixtures/example-apigw-v2-request-iam.json"); + + let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw http data"); + match req { + LambdaRequest::ApiGatewayV2(req) => { + assert_eq!("123456789012", req.request_context.account_id.unwrap()); + } + other => panic!("unexpected request variant: {:?}", other), + } + } + + #[test] + fn test_deserialize_alb() { + let data = include_bytes!( + "../../lambda-events/src/fixtures/example-alb-lambda-target-request-multivalue-headers.json" + ); + + let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze alb rest data"); + match req { + LambdaRequest::Alb(req) => { + assert_eq!( + "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh", + req.request_context.elb.target_group_arn.unwrap() + ); + } + other => panic!("unexpected request variant: {:?}", other), + } + } + + #[test] + fn test_deserialize_apigw_websocket() { + let data = + include_bytes!("../../lambda-events/src/fixtures/example-apigw-websocket-request-without-method.json"); + + let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw websocket data"); + match req { + LambdaRequest::WebSocket(req) => { + assert_eq!("CONNECT", req.request_context.event_type.unwrap()); + } + other => panic!("unexpected request variant: {:?}", other), + } + } + + #[test] + fn test_deserialize_error() { + let err = serde_json::from_str::("{\"command\": \"hi\"}").unwrap_err(); + + assert_eq!(ERROR_CONTEXT, err.to_string()); + } +} diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index 37c167a0..bc9e753d 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -70,6 +70,7 @@ pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service}; use request::RequestFuture; use response::ResponseFuture; +mod deserializer; pub mod ext; pub mod request; mod response; diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index 5ed3effe..ea418595 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -20,8 +20,10 @@ use aws_lambda_events::apigw::{ApiGatewayWebsocketProxyRequest, ApiGatewayWebsoc use aws_lambda_events::{encodings::Body, query_map::QueryMap}; use http::header::HeaderName; use http::{HeaderMap, HeaderValue}; + use serde::{Deserialize, Serialize}; use serde_json::error::Error as JsonError; + use std::future::Future; use std::pin::Pin; use std::{env, io::Read, mem}; @@ -33,8 +35,7 @@ use url::Url; /// This is not intended to be a type consumed by crate users directly. The order /// of the variants are notable. Serde will try to deserialize in this order. #[doc(hidden)] -#[derive(Deserialize, Debug)] -#[serde(untagged)] +#[derive(Debug)] pub enum LambdaRequest { #[cfg(feature = "apigw_rest")] ApiGatewayV1(ApiGatewayProxyRequest),