Skip to content

Commit cf72bb0

Browse files
authored
Refactor Lambda response streaming. (awslabs#696)
* Refactor Lambda response streaming. Remove the separate streaming.rs from lambda-runtime crate. Merge into the `run` method. Added FunctionResponse enum to capture both buffered response and streaming response. Added IntoFunctionResponse trait to convert `Serialize` response into FunctionResponse::BufferedResponse, and convert `Stream` response into FunctionResponse::StreamingResponse. Existing handler functions should continue to work. Improved error handling in response streaming. Return trailers to report errors instead of panic. * Add comments for reporting midstream errors using error trailers * Remove "pub" from internal run method
1 parent e2d51ad commit cf72bb0

File tree

9 files changed

+266
-308
lines changed

9 files changed

+266
-308
lines changed

examples/basic-streaming-response/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
2. Build the function with `cargo lambda build --release`
77
3. Deploy the function to AWS Lambda with `cargo lambda deploy --enable-function-url --iam-role YOUR_ROLE`
88
4. Enable Lambda streaming response on Lambda console: change the function url's invoke mode to `RESPONSE_STREAM`
9-
5. Verify the function works: `curl <function-url>`. The results should be streamed back with 0.5 second pause between each word.
9+
5. Verify the function works: `curl -v -N <function-url>`. The results should be streamed back with 0.5 second pause between each word.
1010

1111
## Build for ARM 64
1212

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use hyper::{body::Body, Response};
2-
use lambda_runtime::{service_fn, Error, LambdaEvent};
1+
use hyper::body::Body;
2+
use lambda_runtime::{service_fn, Error, LambdaEvent, StreamResponse};
33
use serde_json::Value;
44
use std::{thread, time::Duration};
55

6-
async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
6+
async fn func(_event: LambdaEvent<Value>) -> Result<StreamResponse<Body>, Error> {
77
let messages = vec!["Hello", "world", "from", "Lambda!"];
88

99
let (mut tx, rx) = Body::channel();
@@ -15,12 +15,10 @@ async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
1515
}
1616
});
1717

18-
let resp = Response::builder()
19-
.header("content-type", "text/html")
20-
.header("CustomHeader", "outerspace")
21-
.body(rx)?;
22-
23-
Ok(resp)
18+
Ok(StreamResponse {
19+
metadata_prelude: Default::default(),
20+
stream: rx,
21+
})
2422
}
2523

2624
#[tokio::main]
@@ -34,6 +32,6 @@ async fn main() -> Result<(), Error> {
3432
.without_time()
3533
.init();
3634

37-
lambda_runtime::run_with_streaming_response(service_fn(func)).await?;
35+
lambda_runtime::run(service_fn(func)).await?;
3836
Ok(())
3937
}

lambda-http/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ lambda_runtime = { path = "../lambda-runtime", version = "0.8" }
3333
serde = { version = "1.0", features = ["derive"] }
3434
serde_json = "1.0"
3535
serde_urlencoded = "0.7"
36+
tokio-stream = "0.1.2"
3637
mime = "0.3"
3738
encoding_rs = "0.8"
3839
url = "2.2"

lambda-http/src/streaming.rs

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
use crate::http::header::SET_COOKIE;
12
use crate::tower::ServiceBuilder;
23
use crate::Request;
34
use crate::{request::LambdaRequest, RequestExt};
45
pub use aws_lambda_events::encodings::Body as LambdaEventBody;
56
use bytes::Bytes;
67
pub use http::{self, Response};
78
use http_body::Body;
8-
use lambda_runtime::LambdaEvent;
9-
pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service};
9+
pub use lambda_runtime::{
10+
self, service_fn, tower, tower::ServiceExt, Error, FunctionResponse, LambdaEvent, MetadataPrelude, Service,
11+
StreamResponse,
12+
};
1013
use std::fmt::{Debug, Display};
14+
use std::pin::Pin;
15+
use std::task::{Context, Poll};
16+
use tokio_stream::Stream;
1117

1218
/// Starts the Lambda Rust runtime and stream response back [Configure Lambda
1319
/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html).
@@ -28,7 +34,60 @@ where
2834
let event: Request = req.payload.into();
2935
event.with_lambda_context(req.context)
3036
})
31-
.service(handler);
37+
.service(handler)
38+
.map_response(|res| {
39+
let (parts, body) = res.into_parts();
3240

33-
lambda_runtime::run_with_streaming_response(svc).await
41+
let mut prelude_headers = parts.headers;
42+
43+
let cookies = prelude_headers.get_all(SET_COOKIE);
44+
let cookies = cookies
45+
.iter()
46+
.map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
47+
.collect::<Vec<String>>();
48+
49+
prelude_headers.remove(SET_COOKIE);
50+
51+
let metadata_prelude = MetadataPrelude {
52+
headers: prelude_headers,
53+
status_code: parts.status,
54+
cookies,
55+
};
56+
57+
StreamResponse {
58+
metadata_prelude,
59+
stream: BodyStream { body },
60+
}
61+
});
62+
63+
lambda_runtime::run(svc).await
64+
}
65+
66+
pub struct BodyStream<B> {
67+
pub(crate) body: B,
68+
}
69+
70+
impl<B> BodyStream<B>
71+
where
72+
B: Body + Unpin + Send + 'static,
73+
B::Data: Into<Bytes> + Send,
74+
B::Error: Into<Error> + Send + Debug,
75+
{
76+
fn project(self: Pin<&mut Self>) -> Pin<&mut B> {
77+
unsafe { self.map_unchecked_mut(|s| &mut s.body) }
78+
}
79+
}
80+
81+
impl<B> Stream for BodyStream<B>
82+
where
83+
B: Body + Unpin + Send + 'static,
84+
B::Data: Into<Bytes> + Send,
85+
B::Error: Into<Error> + Send + Debug,
86+
{
87+
type Item = Result<B::Data, B::Error>;
88+
89+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
90+
let body = self.project();
91+
body.poll_data(cx)
92+
}
3493
}

lambda-runtime/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,5 @@ tokio-stream = "0.1.2"
4343
lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" }
4444
serde_path_to_error = "0.1.11"
4545
http-serde = "1.1.3"
46+
base64 = "0.20.0"
47+
http-body = "0.4"

lambda-runtime/src/lib.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//! Create a type that conforms to the [`tower::Service`] trait. This type can
88
//! then be passed to the the `lambda_runtime::run` function, which launches
99
//! and runs the Lambda runtime.
10+
use bytes::Bytes;
1011
use futures::FutureExt;
1112
use hyper::{
1213
client::{connect::Connection, HttpConnector},
@@ -20,6 +21,7 @@ use std::{
2021
env,
2122
fmt::{self, Debug, Display},
2223
future::Future,
24+
marker::PhantomData,
2325
panic,
2426
};
2527
use tokio::io::{AsyncRead, AsyncWrite};
@@ -35,11 +37,8 @@ mod simulated;
3537
/// Types available to a Lambda function.
3638
mod types;
3739

38-
mod streaming;
39-
pub use streaming::run_with_streaming_response;
40-
4140
use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
42-
pub use types::{Context, LambdaEvent};
41+
pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse};
4342

4443
/// Error type that lambdas may result in
4544
pub type Error = lambda_runtime_api_client::Error;
@@ -97,17 +96,21 @@ where
9796
C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
9897
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
9998
{
100-
async fn run<F, A, B>(
99+
async fn run<F, A, R, B, S, D, E>(
101100
&self,
102101
incoming: impl Stream<Item = Result<http::Response<hyper::Body>, Error>> + Send,
103102
mut handler: F,
104103
) -> Result<(), Error>
105104
where
106105
F: Service<LambdaEvent<A>>,
107-
F::Future: Future<Output = Result<B, F::Error>>,
106+
F::Future: Future<Output = Result<R, F::Error>>,
108107
F::Error: fmt::Debug + fmt::Display,
109108
A: for<'de> Deserialize<'de>,
109+
R: IntoFunctionResponse<B, S>,
110110
B: Serialize,
111+
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
112+
D: Into<Bytes> + Send,
113+
E: Into<Error> + Send + Debug,
111114
{
112115
let client = &self.client;
113116
tokio::pin!(incoming);
@@ -177,6 +180,8 @@ where
177180
EventCompletionRequest {
178181
request_id,
179182
body: response,
183+
_unused_b: PhantomData,
184+
_unused_s: PhantomData,
180185
}
181186
.into_req()
182187
}
@@ -243,13 +248,17 @@ where
243248
/// Ok(event.payload)
244249
/// }
245250
/// ```
246-
pub async fn run<A, B, F>(handler: F) -> Result<(), Error>
251+
pub async fn run<A, F, R, B, S, D, E>(handler: F) -> Result<(), Error>
247252
where
248253
F: Service<LambdaEvent<A>>,
249-
F::Future: Future<Output = Result<B, F::Error>>,
254+
F::Future: Future<Output = Result<R, F::Error>>,
250255
F::Error: fmt::Debug + fmt::Display,
251256
A: for<'de> Deserialize<'de>,
257+
R: IntoFunctionResponse<B, S>,
252258
B: Serialize,
259+
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
260+
D: Into<Bytes> + Send,
261+
E: Into<Error> + Send + Debug,
253262
{
254263
trace!("Loading config from env");
255264
let config = Config::from_env()?;
@@ -293,7 +302,7 @@ mod endpoint_tests {
293302
use lambda_runtime_api_client::Client;
294303
use serde_json::json;
295304
use simulated::DuplexStreamWrapper;
296-
use std::{convert::TryFrom, env};
305+
use std::{convert::TryFrom, env, marker::PhantomData};
297306
use tokio::{
298307
io::{self, AsyncRead, AsyncWrite},
299308
select,
@@ -430,6 +439,8 @@ mod endpoint_tests {
430439
let req = EventCompletionRequest {
431440
request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
432441
body: "done",
442+
_unused_b: PhantomData::<&str>,
443+
_unused_s: PhantomData::<Body>,
433444
};
434445
let req = req.into_req()?;
435446

lambda-runtime/src/requests.rs

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
use crate::{types::Diagnostic, Error};
1+
use crate::types::ToStreamErrorTrailer;
2+
use crate::{types::Diagnostic, Error, FunctionResponse, IntoFunctionResponse};
3+
use bytes::Bytes;
4+
use http::header::CONTENT_TYPE;
25
use http::{Method, Request, Response, Uri};
36
use hyper::Body;
47
use lambda_runtime_api_client::build_request;
58
use serde::Serialize;
9+
use std::fmt::Debug;
10+
use std::marker::PhantomData;
611
use std::str::FromStr;
12+
use tokio_stream::{Stream, StreamExt};
713

814
pub(crate) trait IntoRequest {
915
fn into_req(self) -> Result<Request<Body>, Error>;
@@ -65,23 +71,87 @@ fn test_next_event_request() {
6571
}
6672

6773
// /runtime/invocation/{AwsRequestId}/response
68-
pub(crate) struct EventCompletionRequest<'a, T> {
74+
pub(crate) struct EventCompletionRequest<'a, R, B, S, D, E>
75+
where
76+
R: IntoFunctionResponse<B, S>,
77+
B: Serialize,
78+
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
79+
D: Into<Bytes> + Send,
80+
E: Into<Error> + Send + Debug,
81+
{
6982
pub(crate) request_id: &'a str,
70-
pub(crate) body: T,
83+
pub(crate) body: R,
84+
pub(crate) _unused_b: PhantomData<B>,
85+
pub(crate) _unused_s: PhantomData<S>,
7186
}
7287

73-
impl<'a, T> IntoRequest for EventCompletionRequest<'a, T>
88+
impl<'a, R, B, S, D, E> IntoRequest for EventCompletionRequest<'a, R, B, S, D, E>
7489
where
75-
T: for<'serialize> Serialize,
90+
R: IntoFunctionResponse<B, S>,
91+
B: Serialize,
92+
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
93+
D: Into<Bytes> + Send,
94+
E: Into<Error> + Send + Debug,
7695
{
7796
fn into_req(self) -> Result<Request<Body>, Error> {
78-
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
79-
let uri = Uri::from_str(&uri)?;
80-
let body = serde_json::to_vec(&self.body)?;
81-
let body = Body::from(body);
97+
match self.body.into_response() {
98+
FunctionResponse::BufferedResponse(body) => {
99+
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
100+
let uri = Uri::from_str(&uri)?;
82101

83-
let req = build_request().method(Method::POST).uri(uri).body(body)?;
84-
Ok(req)
102+
let body = serde_json::to_vec(&body)?;
103+
let body = Body::from(body);
104+
105+
let req = build_request().method(Method::POST).uri(uri).body(body)?;
106+
Ok(req)
107+
}
108+
FunctionResponse::StreamingResponse(mut response) => {
109+
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
110+
let uri = Uri::from_str(&uri)?;
111+
112+
let mut builder = build_request().method(Method::POST).uri(uri);
113+
let req_headers = builder.headers_mut().unwrap();
114+
115+
req_headers.insert("Transfer-Encoding", "chunked".parse()?);
116+
req_headers.insert("Lambda-Runtime-Function-Response-Mode", "streaming".parse()?);
117+
// Report midstream errors using error trailers.
118+
// See the details in Lambda Developer Doc: https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming
119+
req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Type".parse()?);
120+
req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Body".parse()?);
121+
req_headers.insert(
122+
"Content-Type",
123+
"application/vnd.awslambda.http-integration-response".parse()?,
124+
);
125+
126+
// default Content-Type
127+
let preloud_headers = &mut response.metadata_prelude.headers;
128+
preloud_headers
129+
.entry(CONTENT_TYPE)
130+
.or_insert("application/octet-stream".parse()?);
131+
132+
let metadata_prelude = serde_json::to_string(&response.metadata_prelude)?;
133+
134+
tracing::trace!(?metadata_prelude);
135+
136+
let (mut tx, rx) = Body::channel();
137+
138+
tokio::spawn(async move {
139+
tx.send_data(metadata_prelude.into()).await.unwrap();
140+
tx.send_data("\u{0}".repeat(8).into()).await.unwrap();
141+
142+
while let Some(chunk) = response.stream.next().await {
143+
let chunk = match chunk {
144+
Ok(chunk) => chunk.into(),
145+
Err(err) => err.into().to_tailer().into(),
146+
};
147+
tx.send_data(chunk).await.unwrap();
148+
}
149+
});
150+
151+
let req = builder.body(rx)?;
152+
Ok(req)
153+
}
154+
}
85155
}
86156
}
87157

@@ -90,6 +160,8 @@ fn test_event_completion_request() {
90160
let req = EventCompletionRequest {
91161
request_id: "id",
92162
body: "hello, world!",
163+
_unused_b: PhantomData::<&str>,
164+
_unused_s: PhantomData::<Body>,
93165
};
94166
let req = req.into_req().unwrap();
95167
let expected = Uri::from_static("/2018-06-01/runtime/invocation/id/response");

0 commit comments

Comments
 (0)