From 4e5d3db2216d90eca5b7ace9d4363d869c99b87c Mon Sep 17 00:00:00 2001 From: Harold Sun Date: Sun, 14 Jan 2024 17:03:03 +0800 Subject: [PATCH 1/3] Add pass-through support for non-http triggers --- lambda-http/Cargo.toml | 3 ++- lambda-http/src/deserializer.rs | 41 ++++++++++++++++++--------------- lambda-http/src/request.rs | 34 +++++++++++++++++++++++++++ lambda-http/src/response.rs | 13 +++++++++++ 4 files changed, 72 insertions(+), 19 deletions(-) diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index c3ec425e..2e1bcdfa 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -21,6 +21,7 @@ apigw_rest = [] apigw_http = [] apigw_websockets = [] alb = [] +pass_through = [] [dependencies] base64 = { workspace = true } @@ -37,7 +38,7 @@ mime = "0.3" percent-encoding = "2.2" pin-project-lite = { workspace = true } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } serde_urlencoded = "0.7" tokio-stream = "0.1.2" url = "2.2" diff --git a/lambda-http/src/deserializer.rs b/lambda-http/src/deserializer.rs index 7756629c..3f5f5d34 100644 --- a/lambda-http/src/deserializer.rs +++ b/lambda-http/src/deserializer.rs @@ -1,43 +1,48 @@ use crate::request::LambdaRequest; +#[cfg(feature = "alb")] +use aws_lambda_events::alb::AlbTargetGroupRequest; +#[cfg(feature = "apigw_rest")] +use aws_lambda_events::apigw::ApiGatewayProxyRequest; +#[cfg(feature = "apigw_http")] +use aws_lambda_events::apigw::ApiGatewayV2httpRequest; +#[cfg(feature = "apigw_websockets")] +use aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest; use serde::{de::Error, Deserialize}; +use serde_json::value::RawValue; const ERROR_CONTEXT: &str = "this function expects a JSON payload from Amazon API Gateway, Amazon Elastic Load Balancer, or AWS Lambda Function URLs, but the data doesn't match any of those services' events"; +#[cfg(feature = "pass_through")] +const PASS_THROUGH_ENABLED: bool = true; + impl<'de> Deserialize<'de> for LambdaRequest { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - let content = match serde::__private::de::Content::deserialize(deserializer) { - Ok(content) => content, - Err(err) => return Err(err), - }; + let raw_value: Box = Box::deserialize(deserializer)?; + let data = raw_value.get(); + #[cfg(feature = "apigw_rest")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayProxyRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::ApiGatewayV1(res)); } #[cfg(feature = "apigw_http")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayV2httpRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::ApiGatewayV2(res)); } #[cfg(feature = "alb")] - if let Ok(res) = - aws_lambda_events::alb::AlbTargetGroupRequest::deserialize(serde::__private::de::ContentRefDeserializer::< - D::Error, - >::new(&content)) - { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::Alb(res)); } #[cfg(feature = "apigw_websockets")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::WebSocket(res)); } + #[cfg(feature = "pass_through")] + if PASS_THROUGH_ENABLED == true { + return Ok(LambdaRequest::PassThrough(data.to_string())); + } Err(Error::custom(ERROR_CONTEXT)) } diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index ad86e5a5..f61061da 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -51,6 +51,8 @@ pub enum LambdaRequest { Alb(AlbTargetGroupRequest), #[cfg(feature = "apigw_websockets")] WebSocket(ApiGatewayWebsocketProxyRequest), + #[cfg(feature = "pass_through")] + PassThrough(String), } impl LambdaRequest { @@ -67,6 +69,8 @@ impl LambdaRequest { LambdaRequest::Alb { .. } => RequestOrigin::Alb, #[cfg(feature = "apigw_websockets")] LambdaRequest::WebSocket { .. } => RequestOrigin::WebSocket, + #[cfg(feature = "pass_through")] + LambdaRequest::PassThrough { .. } => RequestOrigin::PassThrough, #[cfg(not(any( feature = "apigw_rest", feature = "apigw_http", @@ -97,6 +101,9 @@ pub enum RequestOrigin { /// API Gateway WebSocket #[cfg(feature = "apigw_websockets")] WebSocket, + /// PassThrough request origin + #[cfg(feature = "pass_through")] + PassThrough, } #[cfg(feature = "apigw_http")] @@ -338,6 +345,28 @@ fn into_websocket_request(ag: ApiGatewayWebsocketProxyRequest) -> http::Request< req } +#[cfg(feature = "pass_through")] +fn into_pass_through_request(data: String) -> http::Request { + let mut builder = http::Request::builder(); + + let mut headers = builder.headers_mut().unwrap(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + + update_xray_trace_id_header(&mut headers); + + let raw_path = "/events"; + + let req = builder + .method(http::Method::POST) + .uri(raw_path) + .extension(RawHttpPath(raw_path.to_string())) + .extension(RequestContext::PassThrough) + .body(Body::from(data)) + .expect("failed to build request"); + + req +} + #[cfg(any(feature = "apigw_rest", feature = "apigw_http", feature = "apigw_websockets"))] fn apigw_path_with_stage(stage: &Option, path: &str) -> String { if env::var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH").is_ok() { @@ -375,6 +404,9 @@ pub enum RequestContext { /// WebSocket request context #[cfg(feature = "apigw_websockets")] WebSocket(ApiGatewayWebsocketProxyRequestContext), + /// Custom request context + #[cfg(feature = "pass_through")] + PassThrough, } /// Converts LambdaRequest types into `http::Request` types @@ -389,6 +421,8 @@ impl From for http::Request { LambdaRequest::Alb(alb) => into_alb_request(alb), #[cfg(feature = "apigw_websockets")] LambdaRequest::WebSocket(ag) => into_websocket_request(ag), + #[cfg(feature = "pass_through")] + LambdaRequest::PassThrough(data) => into_pass_through_request(data), } } } diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index d26ef838..56c1445d 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -46,6 +46,8 @@ pub enum LambdaResponse { ApiGatewayV2(ApiGatewayV2httpResponse), #[cfg(feature = "alb")] Alb(AlbTargetGroupResponse), + #[cfg(feature = "pass_through")] + PassThrough(serde_json::Value), } /// Transformation from http type to internal type @@ -114,6 +116,17 @@ impl LambdaResponse { headers: headers.clone(), multi_value_headers: headers, }), + #[cfg(feature = "pass_through")] + RequestOrigin::PassThrough => { + let resp = match body { + // text body must be a valid json string + Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())}, + // binary body and other cases return Value::Null + _ => LambdaResponse::PassThrough(serde_json::Value::Null), + }; + + resp + } #[cfg(not(any( feature = "apigw_rest", feature = "apigw_http", From ccda3bea406ec14c848372e59b1bbe92ec6118bf Mon Sep 17 00:00:00 2001 From: Harold Sun Date: Sun, 14 Jan 2024 17:23:43 +0800 Subject: [PATCH 2/3] Fix linting warming --- lambda-http/src/deserializer.rs | 2 +- lambda-http/src/request.rs | 10 ++++------ lambda-http/src/response.rs | 6 ++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/lambda-http/src/deserializer.rs b/lambda-http/src/deserializer.rs index 3f5f5d34..6be79dd3 100644 --- a/lambda-http/src/deserializer.rs +++ b/lambda-http/src/deserializer.rs @@ -40,7 +40,7 @@ impl<'de> Deserialize<'de> for LambdaRequest { return Ok(LambdaRequest::WebSocket(res)); } #[cfg(feature = "pass_through")] - if PASS_THROUGH_ENABLED == true { + if PASS_THROUGH_ENABLED { return Ok(LambdaRequest::PassThrough(data.to_string())); } diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index f61061da..c98f5c17 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -349,22 +349,20 @@ fn into_websocket_request(ag: ApiGatewayWebsocketProxyRequest) -> http::Request< fn into_pass_through_request(data: String) -> http::Request { let mut builder = http::Request::builder(); - let mut headers = builder.headers_mut().unwrap(); + let headers = builder.headers_mut().unwrap(); headers.insert("Content-Type", "application/json".parse().unwrap()); - update_xray_trace_id_header(&mut headers); + update_xray_trace_id_header(headers); let raw_path = "/events"; - let req = builder + builder .method(http::Method::POST) .uri(raw_path) .extension(RawHttpPath(raw_path.to_string())) .extension(RequestContext::PassThrough) .body(Body::from(data)) - .expect("failed to build request"); - - req + .expect("failed to build request") } #[cfg(any(feature = "apigw_rest", feature = "apigw_http", feature = "apigw_websockets"))] diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index 56c1445d..cc721d46 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -118,14 +118,12 @@ impl LambdaResponse { }), #[cfg(feature = "pass_through")] RequestOrigin::PassThrough => { - let resp = match body { + match body { // text body must be a valid json string Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())}, // binary body and other cases return Value::Null _ => LambdaResponse::PassThrough(serde_json::Value::Null), - }; - - resp + } } #[cfg(not(any( feature = "apigw_rest", From b1b703a20bae3ec6fc0b51345635dc6eedc79a93 Mon Sep 17 00:00:00 2001 From: Harold Sun Date: Sun, 14 Jan 2024 17:31:27 +0800 Subject: [PATCH 3/3] Remove deserialize_error test This test won't fail when `pass-through` feature is enabled. --- lambda-http/src/deserializer.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lambda-http/src/deserializer.rs b/lambda-http/src/deserializer.rs index 6be79dd3..2584c0ad 100644 --- a/lambda-http/src/deserializer.rs +++ b/lambda-http/src/deserializer.rs @@ -109,11 +109,4 @@ mod tests { other => panic!("unexpected request variant: {:?}", other), } } - - #[test] - fn test_deserialize_error() { - let err = serde_json::from_str::("{\"body\": {}}").unwrap_err(); - - assert_eq!(ERROR_CONTEXT, err.to_string()); - } }