13
13
//===----------------------------------------------------------------------===//
14
14
15
15
#if DEBUG
16
+ import DequeModule
16
17
import Dispatch
17
18
import Logging
18
- import NIOConcurrencyHelpers
19
19
import NIOCore
20
20
import NIOHTTP1
21
21
import NIOPosix
@@ -47,24 +47,15 @@ extension Lambda {
47
47
/// - note: This API is designed strictly for local testing and is behind a DEBUG flag
48
48
static func withLocalServer(
49
49
invocationEndpoint: String ? = nil ,
50
- _ body: @escaping ( ) async throws -> Void
50
+ _ body: sending @escaping ( ) async throws -> Void
51
51
) async throws {
52
+ var logger = Logger ( label: " LocalServer " )
53
+ logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
52
54
53
- // launch the local server and wait for it to be started before running the body
54
- try await withThrowingTaskGroup ( of: Void . self) { group in
55
- // this call will return when the server calls continuation.resume()
56
- try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < Void , any Error > ) in
57
- group. addTask {
58
- do {
59
- try await LambdaHttpServer ( invocationEndpoint: invocationEndpoint) . start (
60
- continuation: continuation
61
- )
62
- } catch {
63
- continuation. resume ( throwing: error)
64
- }
65
- }
66
- }
67
- // now that server is started, run the Lambda function itself
55
+ try await LambdaHTTPServer . withLocalServer (
56
+ invocationEndpoint: invocationEndpoint,
57
+ logger: logger
58
+ ) {
68
59
try await body ( )
69
60
}
70
61
}
@@ -84,34 +75,46 @@ extension Lambda {
84
75
/// 1. POST /invoke - the client posts the event to the lambda function
85
76
///
86
77
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
87
- private struct LambdaHttpServer {
88
- private let logger : Logger
89
- private let group : EventLoopGroup
90
- private let host : String
91
- private let port : Int
78
+ private struct LambdaHTTPServer {
92
79
private let invocationEndpoint : String
93
80
94
81
private let invocationPool = Pool < LocalServerInvocation > ( )
95
82
private let responsePool = Pool < LocalServerResponse > ( )
96
83
97
- init ( invocationEndpoint: String ? ) {
98
- var logger = Logger ( label: " LocalServer " )
99
- logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
100
- self . logger = logger
101
- self . group = MultiThreadedEventLoopGroup . singleton
102
- self . host = " 127.0.0.1 "
103
- self . port = 7000
84
+ private init (
85
+ invocationEndpoint: String ?
86
+ ) {
104
87
self . invocationEndpoint = invocationEndpoint ?? " /invoke "
105
88
}
106
89
107
- func start( continuation: CheckedContinuation < Void , any Error > ) async throws {
108
- let channel = try await ServerBootstrap ( group: self . group)
90
+ private enum TaskResult < Result: Sendable > : Sendable {
91
+ case closureResult( Swift . Result < Result , any Error > )
92
+ case serverReturned( Swift . Result < Void , any Error > )
93
+ }
94
+
95
+ struct UnsafeTransferBox < Value> : @unchecked Sendable {
96
+ let value : Value
97
+
98
+ init ( value: sending Value) {
99
+ self . value = value
100
+ }
101
+ }
102
+
103
+ static func withLocalServer< Result: Sendable > (
104
+ invocationEndpoint: String ? ,
105
+ host: String = " 127.0.0.1 " ,
106
+ port: Int = 7000 ,
107
+ eventLoopGroup: MultiThreadedEventLoopGroup = . singleton,
108
+ logger: Logger ,
109
+ _ closure: sending @escaping ( ) async throws -> Result
110
+ ) async throws -> Result {
111
+ let channel = try await ServerBootstrap ( group: eventLoopGroup)
109
112
. serverChannelOption ( . backlog, value: 256 )
110
113
. serverChannelOption ( . socketOption( . so_reuseaddr) , value: 1 )
111
114
. childChannelOption ( . maxMessagesPerRead, value: 1 )
112
115
. bind (
113
- host: self . host,
114
- port: self . port
116
+ host: host,
117
+ port: port
115
118
) { channel in
116
119
channel. eventLoop. makeCompletedFuture {
117
120
@@ -129,8 +132,6 @@ private struct LambdaHttpServer {
129
132
}
130
133
}
131
134
132
- // notify the caller that the server is started
133
- continuation. resume ( )
134
135
logger. info (
135
136
" Server started and listening " ,
136
137
metadata: [
@@ -139,30 +140,87 @@ private struct LambdaHttpServer {
139
140
]
140
141
)
141
142
142
- // We are handling each incoming connection in a separate child task. It is important
143
- // to use a discarding task group here which automatically discards finished child tasks.
144
- // A normal task group retains all child tasks and their outputs in memory until they are
145
- // consumed by iterating the group or by exiting the group. Since, we are never consuming
146
- // the results of the group we need the group to automatically discard them; otherwise, this
147
- // would result in a memory leak over time.
148
- try await withThrowingDiscardingTaskGroup { group in
149
- try await channel. executeThenClose { inbound in
150
- for try await connectionChannel in inbound {
151
-
152
- group. addTask {
153
- logger. trace ( " Handling a new connection " )
154
- await self . handleConnection ( channel: connectionChannel)
155
- logger. trace ( " Done handling the connection " )
143
+ let server = LambdaHTTPServer ( invocationEndpoint: invocationEndpoint)
144
+
145
+ // Sadly the Swift compiler does not understand that the passed in closure will only be
146
+ // invoked once. Because of this we need an unsafe transfer box here. Buuuh!
147
+ let closureBox = UnsafeTransferBox ( value: closure)
148
+ let result = await withTaskGroup ( of: TaskResult< Result> . self , returning: Swift . Result < Result , any Error > . self) {
149
+ group in
150
+ group. addTask {
151
+ let c = closureBox. value
152
+ do {
153
+ let result = try await c ( )
154
+ return . closureResult( . success( result) )
155
+ } catch {
156
+ return . closureResult( . failure( error) )
157
+ }
158
+ }
159
+
160
+ group. addTask {
161
+ do {
162
+ // We are handling each incoming connection in a separate child task. It is important
163
+ // to use a discarding task group here which automatically discards finished child tasks.
164
+ // A normal task group retains all child tasks and their outputs in memory until they are
165
+ // consumed by iterating the group or by exiting the group. Since, we are never consuming
166
+ // the results of the group we need the group to automatically discard them; otherwise, this
167
+ // would result in a memory leak over time.
168
+ try await withThrowingDiscardingTaskGroup { taskGroup in
169
+ try await channel. executeThenClose { inbound in
170
+ for try await connectionChannel in inbound {
171
+
172
+ taskGroup. addTask {
173
+ logger. trace ( " Handling a new connection " )
174
+ await server. handleConnection ( channel: connectionChannel, logger: logger)
175
+ logger. trace ( " Done handling the connection " )
176
+ }
177
+ }
178
+ }
156
179
}
180
+ return . serverReturned( . success( ( ) ) )
181
+ } catch {
182
+ return . serverReturned( . failure( error) )
183
+ }
184
+ }
185
+
186
+ // Now that the local HTTP server and LambdaHandler tasks are started, wait for the
187
+ // first of the two that will terminate.
188
+ // When the first task terminates, cancel the group and collect the result of the
189
+ // second task.
190
+
191
+ // collect and return the result of the LambdaHandler
192
+ let serverOrHandlerResult1 = await group. next ( ) !
193
+ group. cancelAll ( )
194
+
195
+ switch serverOrHandlerResult1 {
196
+ case . closureResult( let result) :
197
+ return result
198
+
199
+ case . serverReturned( let result) :
200
+ logger. error (
201
+ " Server shutdown before closure completed " ,
202
+ metadata: [
203
+ " error " : " \( result. maybeError != nil ? " \( result. maybeError!) " : " none " ) "
204
+ ]
205
+ )
206
+ switch await group. next ( ) ! {
207
+ case . closureResult( let result) :
208
+ return result
209
+
210
+ case . serverReturned:
211
+ fatalError ( " Only one task is a server, and only one can return `serverReturned` " )
157
212
}
158
213
}
159
214
}
215
+
160
216
logger. info ( " Server shutting down " )
217
+ return try result. get ( )
161
218
}
162
219
163
220
/// This method handles individual TCP connections
164
221
private func handleConnection(
165
- channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart >
222
+ channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart > ,
223
+ logger: Logger
166
224
) async {
167
225
168
226
var requestHead : HTTPRequestHead !
@@ -186,12 +244,14 @@ private struct LambdaHttpServer {
186
244
// process the request
187
245
let response = try await self . processRequest (
188
246
head: requestHead,
189
- body: requestBody
247
+ body: requestBody,
248
+ logger: logger
190
249
)
191
250
// send the responses
192
251
try await self . sendResponse (
193
252
response: response,
194
- outbound: outbound
253
+ outbound: outbound,
254
+ logger: logger
195
255
)
196
256
197
257
requestHead = nil
@@ -214,15 +274,19 @@ private struct LambdaHttpServer {
214
274
/// - body: the HTTP request body
215
275
/// - Throws:
216
276
/// - Returns: the response to send back to the client or the Lambda function
217
- private func processRequest( head: HTTPRequestHead , body: ByteBuffer ? ) async throws -> LocalServerResponse {
277
+ private func processRequest(
278
+ head: HTTPRequestHead ,
279
+ body: ByteBuffer ? ,
280
+ logger: Logger
281
+ ) async throws -> LocalServerResponse {
218
282
219
283
if let body {
220
- self . logger. trace (
284
+ logger. trace (
221
285
" Processing request " ,
222
286
metadata: [ " URI " : " \( head. method) \( head. uri) " , " Body " : " \( String ( buffer: body) ) " ]
223
287
)
224
288
} else {
225
- self . logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
289
+ logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
226
290
}
227
291
228
292
switch ( head. method, head. uri) {
@@ -237,7 +301,9 @@ private struct LambdaHttpServer {
237
301
}
238
302
// we always accept the /invoke request and push them to the pool
239
303
let requestId = " \( DispatchTime . now ( ) . uptimeNanoseconds) "
240
- logger. trace ( " /invoke received invocation " , metadata: [ " requestId " : " \( requestId) " ] )
304
+ var logger = logger
305
+ logger [ metadataKey: " requestID " ] = " \( requestId) "
306
+ logger. trace ( " /invoke received invocation " )
241
307
await self . invocationPool. push ( LocalServerInvocation ( requestId: requestId, request: body) )
242
308
243
309
// wait for the lambda function to process the request
@@ -273,9 +339,9 @@ private struct LambdaHttpServer {
273
339
case ( . GET, let url) where url. hasSuffix ( Consts . getNextInvocationURLSuffix) :
274
340
275
341
// pop the tasks from the queue
276
- self . logger. trace ( " /next waiting for /invoke " )
342
+ logger. trace ( " /next waiting for /invoke " )
277
343
for try await invocation in self . invocationPool {
278
- self . logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
344
+ logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
279
345
// this call also stores the invocation requestId into the response
280
346
return invocation. makeResponse ( status: . accepted)
281
347
}
@@ -322,12 +388,13 @@ private struct LambdaHttpServer {
322
388
323
389
private func sendResponse(
324
390
response: LocalServerResponse ,
325
- outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart >
391
+ outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart > ,
392
+ logger: Logger
326
393
) async throws {
327
394
var headers = HTTPHeaders ( response. headers ?? [ ] )
328
395
headers. add ( name: " Content-Length " , value: " \( response. body? . readableBytes ?? 0 ) " )
329
396
330
- self . logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
397
+ logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
331
398
try await outbound. write (
332
399
HTTPServerResponsePart . head (
333
400
HTTPResponseHead (
@@ -350,44 +417,59 @@ private struct LambdaHttpServer {
350
417
private final class Pool < T> : AsyncSequence , AsyncIteratorProtocol , Sendable where T: Sendable {
351
418
typealias Element = T
352
419
353
- private let _buffer = Mutex < CircularBuffer < T > > ( . init( ) )
354
- private let _continuation = Mutex < CheckedContinuation < T , any Error > ? > ( nil )
355
-
356
- /// retrieve the first element from the buffer
357
- public func popFirst( ) async -> T ? {
358
- self . _buffer. withLock { $0. popFirst ( ) }
420
+ enum State : ~ Copyable {
421
+ case buffer( Deque < T > )
422
+ case continuation( CheckedContinuation < T , any Error > ? )
359
423
}
360
424
425
+ private let lock = Mutex < State > ( . buffer( [ ] ) )
426
+
361
427
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
362
428
public func push( _ invocation: T ) async {
363
429
// if the iterator is waiting for an element, give it to it
364
430
// otherwise, enqueue the element
365
- if let continuation = self . _continuation. withLock ( { $0 } ) {
366
- self . _continuation. withLock { $0 = nil }
367
- continuation. resume ( returning: invocation)
368
- } else {
369
- self . _buffer. withLock { $0. append ( invocation) }
431
+ let maybeContinuation = self . lock. withLock { state -> CheckedContinuation < T , any Error > ? in
432
+ switch consume state {
433
+ case . continuation( let continuation) :
434
+ state = . buffer( [ ] )
435
+ return continuation
436
+
437
+ case . buffer( var buffer) :
438
+ buffer. append ( invocation)
439
+ state = . buffer( buffer)
440
+ return nil
441
+ }
370
442
}
443
+
444
+ maybeContinuation? . resume ( returning: invocation)
371
445
}
372
446
373
447
func next( ) async throws -> T ? {
374
-
375
448
// exit the async for loop if the task is cancelled
376
449
guard !Task. isCancelled else {
377
450
return nil
378
451
}
379
452
380
- if let element = await self . popFirst ( ) {
381
- return element
382
- } else {
383
- // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
384
- // wait for an element to be enqueued
385
- return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
386
- // store the continuation for later, when an element is enqueued
387
- self . _continuation. withLock {
388
- $0 = continuation
453
+ return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
454
+ let nextAction = self . lock. withLock { state -> T ? in
455
+ switch consume state {
456
+ case . buffer( var buffer) :
457
+ if let first = buffer. popFirst ( ) {
458
+ state = . buffer( buffer)
459
+ return first
460
+ } else {
461
+ state = . continuation( continuation)
462
+ return nil
463
+ }
464
+
465
+ case . continuation:
466
+ fatalError ( " Concurrent invocations to next(). This is illegal. " )
389
467
}
390
468
}
469
+
470
+ guard let nextAction else { return }
471
+
472
+ continuation. resume ( returning: nextAction)
391
473
}
392
474
}
393
475
@@ -432,3 +514,14 @@ private struct LambdaHttpServer {
432
514
}
433
515
}
434
516
#endif
517
+
518
+ extension Result {
519
+ var maybeError : Failure ? {
520
+ switch self {
521
+ case . success:
522
+ return nil
523
+ case . failure( let error) :
524
+ return error
525
+ }
526
+ }
527
+ }
0 commit comments