Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 385a85f

Browse files
authoredFeb 27, 2025··
Merge branch 'main' into cancel-next-invocation
2 parents 36171e1 + 5de00c9 commit 385a85f

File tree

2 files changed

+176
-82
lines changed

2 files changed

+176
-82
lines changed
 

‎Package.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ let package = Package(
1919
dependencies: [
2020
.package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"),
2121
.package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
22+
.package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4"),
2223
],
2324
targets: [
2425
.target(
@@ -31,10 +32,10 @@ let package = Package(
3132
.target(
3233
name: "AWSLambdaRuntimeCore",
3334
dependencies: [
35+
.product(name: "DequeModule", package: "swift-collections"),
3436
.product(name: "Logging", package: "swift-log"),
3537
.product(name: "NIOHTTP1", package: "swift-nio"),
3638
.product(name: "NIOCore", package: "swift-nio"),
37-
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
3839
.product(name: "NIOPosix", package: "swift-nio"),
3940
]
4041
),

‎Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift

Lines changed: 174 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#if DEBUG
16+
import DequeModule
1617
import Dispatch
1718
import Logging
18-
import NIOConcurrencyHelpers
1919
import NIOCore
2020
import NIOHTTP1
2121
import NIOPosix
@@ -47,24 +47,15 @@ extension Lambda {
4747
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
4848
static func withLocalServer(
4949
invocationEndpoint: String? = nil,
50-
_ body: @escaping () async throws -> Void
50+
_ body: sending @escaping () async throws -> Void
5151
) async throws {
52+
var logger = Logger(label: "LocalServer")
53+
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
5254

53-
// launch the local server and wait for it to be started before running the body
54-
try await withThrowingTaskGroup(of: Void.self) { group in
55-
// this call will return when the server calls continuation.resume()
56-
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
57-
group.addTask {
58-
do {
59-
try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(
60-
continuation: continuation
61-
)
62-
} catch {
63-
continuation.resume(throwing: error)
64-
}
65-
}
66-
}
67-
// now that server is started, run the Lambda function itself
55+
try await LambdaHTTPServer.withLocalServer(
56+
invocationEndpoint: invocationEndpoint,
57+
logger: logger
58+
) {
6859
try await body()
6960
}
7061
}
@@ -84,34 +75,46 @@ extension Lambda {
8475
/// 1. POST /invoke - the client posts the event to the lambda function
8576
///
8677
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
87-
private struct LambdaHttpServer {
88-
private let logger: Logger
89-
private let group: EventLoopGroup
90-
private let host: String
91-
private let port: Int
78+
private struct LambdaHTTPServer {
9279
private let invocationEndpoint: String
9380

9481
private let invocationPool = Pool<LocalServerInvocation>()
9582
private let responsePool = Pool<LocalServerResponse>()
9683

97-
init(invocationEndpoint: String?) {
98-
var logger = Logger(label: "LocalServer")
99-
logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
100-
self.logger = logger
101-
self.group = MultiThreadedEventLoopGroup.singleton
102-
self.host = "127.0.0.1"
103-
self.port = 7000
84+
private init(
85+
invocationEndpoint: String?
86+
) {
10487
self.invocationEndpoint = invocationEndpoint ?? "/invoke"
10588
}
10689

107-
func start(continuation: CheckedContinuation<Void, any Error>) async throws {
108-
let channel = try await ServerBootstrap(group: self.group)
90+
private enum TaskResult<Result: Sendable>: Sendable {
91+
case closureResult(Swift.Result<Result, any Error>)
92+
case serverReturned(Swift.Result<Void, any Error>)
93+
}
94+
95+
struct UnsafeTransferBox<Value>: @unchecked Sendable {
96+
let value: Value
97+
98+
init(value: sending Value) {
99+
self.value = value
100+
}
101+
}
102+
103+
static func withLocalServer<Result: Sendable>(
104+
invocationEndpoint: String?,
105+
host: String = "127.0.0.1",
106+
port: Int = 7000,
107+
eventLoopGroup: MultiThreadedEventLoopGroup = .singleton,
108+
logger: Logger,
109+
_ closure: sending @escaping () async throws -> Result
110+
) async throws -> Result {
111+
let channel = try await ServerBootstrap(group: eventLoopGroup)
109112
.serverChannelOption(.backlog, value: 256)
110113
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
111114
.childChannelOption(.maxMessagesPerRead, value: 1)
112115
.bind(
113-
host: self.host,
114-
port: self.port
116+
host: host,
117+
port: port
115118
) { channel in
116119
channel.eventLoop.makeCompletedFuture {
117120

@@ -129,8 +132,6 @@ private struct LambdaHttpServer {
129132
}
130133
}
131134

132-
// notify the caller that the server is started
133-
continuation.resume()
134135
logger.info(
135136
"Server started and listening",
136137
metadata: [
@@ -139,30 +140,87 @@ private struct LambdaHttpServer {
139140
]
140141
)
141142

142-
// We are handling each incoming connection in a separate child task. It is important
143-
// to use a discarding task group here which automatically discards finished child tasks.
144-
// A normal task group retains all child tasks and their outputs in memory until they are
145-
// consumed by iterating the group or by exiting the group. Since, we are never consuming
146-
// the results of the group we need the group to automatically discard them; otherwise, this
147-
// would result in a memory leak over time.
148-
try await withThrowingDiscardingTaskGroup { group in
149-
try await channel.executeThenClose { inbound in
150-
for try await connectionChannel in inbound {
151-
152-
group.addTask {
153-
logger.trace("Handling a new connection")
154-
await self.handleConnection(channel: connectionChannel)
155-
logger.trace("Done handling the connection")
143+
let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint)
144+
145+
// Sadly the Swift compiler does not understand that the passed in closure will only be
146+
// invoked once. Because of this we need an unsafe transfer box here. Buuuh!
147+
let closureBox = UnsafeTransferBox(value: closure)
148+
let result = await withTaskGroup(of: TaskResult<Result>.self, returning: Swift.Result<Result, any Error>.self) {
149+
group in
150+
group.addTask {
151+
let c = closureBox.value
152+
do {
153+
let result = try await c()
154+
return .closureResult(.success(result))
155+
} catch {
156+
return .closureResult(.failure(error))
157+
}
158+
}
159+
160+
group.addTask {
161+
do {
162+
// We are handling each incoming connection in a separate child task. It is important
163+
// to use a discarding task group here which automatically discards finished child tasks.
164+
// A normal task group retains all child tasks and their outputs in memory until they are
165+
// consumed by iterating the group or by exiting the group. Since, we are never consuming
166+
// the results of the group we need the group to automatically discard them; otherwise, this
167+
// would result in a memory leak over time.
168+
try await withThrowingDiscardingTaskGroup { taskGroup in
169+
try await channel.executeThenClose { inbound in
170+
for try await connectionChannel in inbound {
171+
172+
taskGroup.addTask {
173+
logger.trace("Handling a new connection")
174+
await server.handleConnection(channel: connectionChannel, logger: logger)
175+
logger.trace("Done handling the connection")
176+
}
177+
}
178+
}
156179
}
180+
return .serverReturned(.success(()))
181+
} catch {
182+
return .serverReturned(.failure(error))
183+
}
184+
}
185+
186+
// Now that the local HTTP server and LambdaHandler tasks are started, wait for the
187+
// first of the two that will terminate.
188+
// When the first task terminates, cancel the group and collect the result of the
189+
// second task.
190+
191+
// collect and return the result of the LambdaHandler
192+
let serverOrHandlerResult1 = await group.next()!
193+
group.cancelAll()
194+
195+
switch serverOrHandlerResult1 {
196+
case .closureResult(let result):
197+
return result
198+
199+
case .serverReturned(let result):
200+
logger.error(
201+
"Server shutdown before closure completed",
202+
metadata: [
203+
"error": "\(result.maybeError != nil ? "\(result.maybeError!)" : "none")"
204+
]
205+
)
206+
switch await group.next()! {
207+
case .closureResult(let result):
208+
return result
209+
210+
case .serverReturned:
211+
fatalError("Only one task is a server, and only one can return `serverReturned`")
157212
}
158213
}
159214
}
215+
160216
logger.info("Server shutting down")
217+
return try result.get()
161218
}
162219

163220
/// This method handles individual TCP connections
164221
private func handleConnection(
165-
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
222+
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>,
223+
logger: Logger
166224
) async {
167225

168226
var requestHead: HTTPRequestHead!
@@ -186,12 +244,14 @@ private struct LambdaHttpServer {
186244
// process the request
187245
let response = try await self.processRequest(
188246
head: requestHead,
189-
body: requestBody
247+
body: requestBody,
248+
logger: logger
190249
)
191250
// send the responses
192251
try await self.sendResponse(
193252
response: response,
194-
outbound: outbound
253+
outbound: outbound,
254+
logger: logger
195255
)
196256

197257
requestHead = nil
@@ -214,15 +274,19 @@ private struct LambdaHttpServer {
214274
/// - body: the HTTP request body
215275
/// - Throws:
216276
/// - Returns: the response to send back to the client or the Lambda function
217-
private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse {
277+
private func processRequest(
278+
head: HTTPRequestHead,
279+
body: ByteBuffer?,
280+
logger: Logger
281+
) async throws -> LocalServerResponse {
218282

219283
if let body {
220-
self.logger.trace(
284+
logger.trace(
221285
"Processing request",
222286
metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"]
223287
)
224288
} else {
225-
self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
289+
logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
226290
}
227291

228292
switch (head.method, head.uri) {
@@ -237,7 +301,9 @@ private struct LambdaHttpServer {
237301
}
238302
// we always accept the /invoke request and push them to the pool
239303
let requestId = "\(DispatchTime.now().uptimeNanoseconds)"
240-
logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"])
304+
var logger = logger
305+
logger[metadataKey: "requestID"] = "\(requestId)"
306+
logger.trace("/invoke received invocation")
241307
await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
242308

243309
// wait for the lambda function to process the request
@@ -273,9 +339,9 @@ private struct LambdaHttpServer {
273339
case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix):
274340

275341
// pop the tasks from the queue
276-
self.logger.trace("/next waiting for /invoke")
342+
logger.trace("/next waiting for /invoke")
277343
for try await invocation in self.invocationPool {
278-
self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
344+
logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
279345
// this call also stores the invocation requestId into the response
280346
return invocation.makeResponse(status: .accepted)
281347
}
@@ -322,12 +388,13 @@ private struct LambdaHttpServer {
322388

323389
private func sendResponse(
324390
response: LocalServerResponse,
325-
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
391+
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
392+
logger: Logger
326393
) async throws {
327394
var headers = HTTPHeaders(response.headers ?? [])
328395
headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")
329396

330-
self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
397+
logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
331398
try await outbound.write(
332399
HTTPServerResponsePart.head(
333400
HTTPResponseHead(
@@ -350,44 +417,59 @@ private struct LambdaHttpServer {
350417
private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
351418
typealias Element = T
352419

353-
private let _buffer = Mutex<CircularBuffer<T>>(.init())
354-
private let _continuation = Mutex<CheckedContinuation<T, any Error>?>(nil)
355-
356-
/// retrieve the first element from the buffer
357-
public func popFirst() async -> T? {
358-
self._buffer.withLock { $0.popFirst() }
420+
enum State: ~Copyable {
421+
case buffer(Deque<T>)
422+
case continuation(CheckedContinuation<T, any Error>?)
359423
}
360424

425+
private let lock = Mutex<State>(.buffer([]))
426+
361427
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
362428
public func push(_ invocation: T) async {
363429
// if the iterator is waiting for an element, give it to it
364430
// otherwise, enqueue the element
365-
if let continuation = self._continuation.withLock({ $0 }) {
366-
self._continuation.withLock { $0 = nil }
367-
continuation.resume(returning: invocation)
368-
} else {
369-
self._buffer.withLock { $0.append(invocation) }
431+
let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
432+
switch consume state {
433+
case .continuation(let continuation):
434+
state = .buffer([])
435+
return continuation
436+
437+
case .buffer(var buffer):
438+
buffer.append(invocation)
439+
state = .buffer(buffer)
440+
return nil
441+
}
370442
}
443+
444+
maybeContinuation?.resume(returning: invocation)
371445
}
372446

373447
func next() async throws -> T? {
374-
375448
// exit the async for loop if the task is cancelled
376449
guard !Task.isCancelled else {
377450
return nil
378451
}
379452

380-
if let element = await self.popFirst() {
381-
return element
382-
} else {
383-
// we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
384-
// wait for an element to be enqueued
385-
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
386-
// store the continuation for later, when an element is enqueued
387-
self._continuation.withLock {
388-
$0 = continuation
453+
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
454+
let nextAction = self.lock.withLock { state -> T? in
455+
switch consume state {
456+
case .buffer(var buffer):
457+
if let first = buffer.popFirst() {
458+
state = .buffer(buffer)
459+
return first
460+
} else {
461+
state = .continuation(continuation)
462+
return nil
463+
}
464+
465+
case .continuation:
466+
fatalError("Concurrent invocations to next(). This is illegal.")
389467
}
390468
}
469+
470+
guard let nextAction else { return }
471+
472+
continuation.resume(returning: nextAction)
391473
}
392474
}
393475

@@ -432,3 +514,14 @@ private struct LambdaHttpServer {
432514
}
433515
}
434516
#endif
517+
518+
extension Result {
519+
var maybeError: Failure? {
520+
switch self {
521+
case .success:
522+
return nil
523+
case .failure(let error):
524+
return error
525+
}
526+
}
527+
}

0 commit comments

Comments
 (0)
Please sign in to comment.