diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index c3ec425e..2e1bcdfa 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -21,6 +21,7 @@ apigw_rest = [] apigw_http = [] apigw_websockets = [] alb = [] +pass_through = [] [dependencies] base64 = { workspace = true } @@ -37,7 +38,7 @@ mime = "0.3" percent-encoding = "2.2" pin-project-lite = { workspace = true } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } serde_urlencoded = "0.7" tokio-stream = "0.1.2" url = "2.2" diff --git a/lambda-http/src/deserializer.rs b/lambda-http/src/deserializer.rs index 7756629c..2584c0ad 100644 --- a/lambda-http/src/deserializer.rs +++ b/lambda-http/src/deserializer.rs @@ -1,43 +1,48 @@ use crate::request::LambdaRequest; +#[cfg(feature = "alb")] +use aws_lambda_events::alb::AlbTargetGroupRequest; +#[cfg(feature = "apigw_rest")] +use aws_lambda_events::apigw::ApiGatewayProxyRequest; +#[cfg(feature = "apigw_http")] +use aws_lambda_events::apigw::ApiGatewayV2httpRequest; +#[cfg(feature = "apigw_websockets")] +use aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest; use serde::{de::Error, Deserialize}; +use serde_json::value::RawValue; const ERROR_CONTEXT: &str = "this function expects a JSON payload from Amazon API Gateway, Amazon Elastic Load Balancer, or AWS Lambda Function URLs, but the data doesn't match any of those services' events"; +#[cfg(feature = "pass_through")] +const PASS_THROUGH_ENABLED: bool = true; + impl<'de> Deserialize<'de> for LambdaRequest { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - let content = match serde::__private::de::Content::deserialize(deserializer) { - Ok(content) => content, - Err(err) => return Err(err), - }; + let raw_value: Box = Box::deserialize(deserializer)?; + let data = raw_value.get(); + #[cfg(feature = "apigw_rest")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayProxyRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::ApiGatewayV1(res)); } #[cfg(feature = "apigw_http")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayV2httpRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::ApiGatewayV2(res)); } #[cfg(feature = "alb")] - if let Ok(res) = - aws_lambda_events::alb::AlbTargetGroupRequest::deserialize(serde::__private::de::ContentRefDeserializer::< - D::Error, - >::new(&content)) - { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::Alb(res)); } #[cfg(feature = "apigw_websockets")] - if let Ok(res) = aws_lambda_events::apigw::ApiGatewayWebsocketProxyRequest::deserialize( - serde::__private::de::ContentRefDeserializer::::new(&content), - ) { + if let Ok(res) = serde_json::from_str::(data) { return Ok(LambdaRequest::WebSocket(res)); } + #[cfg(feature = "pass_through")] + if PASS_THROUGH_ENABLED { + return Ok(LambdaRequest::PassThrough(data.to_string())); + } Err(Error::custom(ERROR_CONTEXT)) } @@ -104,11 +109,4 @@ mod tests { other => panic!("unexpected request variant: {:?}", other), } } - - #[test] - fn test_deserialize_error() { - let err = serde_json::from_str::("{\"body\": {}}").unwrap_err(); - - assert_eq!(ERROR_CONTEXT, err.to_string()); - } } diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index ad86e5a5..c98f5c17 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -51,6 +51,8 @@ pub enum LambdaRequest { Alb(AlbTargetGroupRequest), #[cfg(feature = "apigw_websockets")] WebSocket(ApiGatewayWebsocketProxyRequest), + #[cfg(feature = "pass_through")] + PassThrough(String), } impl LambdaRequest { @@ -67,6 +69,8 @@ impl LambdaRequest { LambdaRequest::Alb { .. } => RequestOrigin::Alb, #[cfg(feature = "apigw_websockets")] LambdaRequest::WebSocket { .. } => RequestOrigin::WebSocket, + #[cfg(feature = "pass_through")] + LambdaRequest::PassThrough { .. } => RequestOrigin::PassThrough, #[cfg(not(any( feature = "apigw_rest", feature = "apigw_http", @@ -97,6 +101,9 @@ pub enum RequestOrigin { /// API Gateway WebSocket #[cfg(feature = "apigw_websockets")] WebSocket, + /// PassThrough request origin + #[cfg(feature = "pass_through")] + PassThrough, } #[cfg(feature = "apigw_http")] @@ -338,6 +345,26 @@ fn into_websocket_request(ag: ApiGatewayWebsocketProxyRequest) -> http::Request< req } +#[cfg(feature = "pass_through")] +fn into_pass_through_request(data: String) -> http::Request { + let mut builder = http::Request::builder(); + + let headers = builder.headers_mut().unwrap(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + + update_xray_trace_id_header(headers); + + let raw_path = "/events"; + + builder + .method(http::Method::POST) + .uri(raw_path) + .extension(RawHttpPath(raw_path.to_string())) + .extension(RequestContext::PassThrough) + .body(Body::from(data)) + .expect("failed to build request") +} + #[cfg(any(feature = "apigw_rest", feature = "apigw_http", feature = "apigw_websockets"))] fn apigw_path_with_stage(stage: &Option, path: &str) -> String { if env::var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH").is_ok() { @@ -375,6 +402,9 @@ pub enum RequestContext { /// WebSocket request context #[cfg(feature = "apigw_websockets")] WebSocket(ApiGatewayWebsocketProxyRequestContext), + /// Custom request context + #[cfg(feature = "pass_through")] + PassThrough, } /// Converts LambdaRequest types into `http::Request` types @@ -389,6 +419,8 @@ impl From for http::Request { LambdaRequest::Alb(alb) => into_alb_request(alb), #[cfg(feature = "apigw_websockets")] LambdaRequest::WebSocket(ag) => into_websocket_request(ag), + #[cfg(feature = "pass_through")] + LambdaRequest::PassThrough(data) => into_pass_through_request(data), } } } diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index d26ef838..cc721d46 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -46,6 +46,8 @@ pub enum LambdaResponse { ApiGatewayV2(ApiGatewayV2httpResponse), #[cfg(feature = "alb")] Alb(AlbTargetGroupResponse), + #[cfg(feature = "pass_through")] + PassThrough(serde_json::Value), } /// Transformation from http type to internal type @@ -114,6 +116,15 @@ impl LambdaResponse { headers: headers.clone(), multi_value_headers: headers, }), + #[cfg(feature = "pass_through")] + RequestOrigin::PassThrough => { + match body { + // text body must be a valid json string + Some(Body::Text(body)) => {LambdaResponse::PassThrough(serde_json::from_str(&body).unwrap_or_default())}, + // binary body and other cases return Value::Null + _ => LambdaResponse::PassThrough(serde_json::Value::Null), + } + } #[cfg(not(any( feature = "apigw_rest", feature = "apigw_http",