From 0c66d4019d6fb23e246d1637d7919b5a059d11a3 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 12:47:54 +0100 Subject: [PATCH 1/9] update mock server for swft 6 compliance --- Package.swift | 6 +- Sources/MockServer/MockHTTPServer.swift | 285 ++++++++++++++++++++++++ Sources/MockServer/main.swift | 177 --------------- 3 files changed, 288 insertions(+), 180 deletions(-) create mode 100644 Sources/MockServer/MockHTTPServer.swift delete mode 100644 Sources/MockServer/main.swift diff --git a/Package.swift b/Package.swift index d2c92fdc..96068884 100644 --- a/Package.swift +++ b/Package.swift @@ -17,7 +17,7 @@ let package = Package( .library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"), ], targets: [ @@ -89,11 +89,11 @@ let package = Package( .executableTarget( name: "MockServer", dependencies: [ + .product(name: "Logging", package: "swift-log"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - ], - swiftSettings: [.swiftLanguageMode(.v5)] + ] ), ] ) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift new file mode 100644 index 00000000..a730de11 --- /dev/null +++ b/Sources/MockServer/MockHTTPServer.swift @@ -0,0 +1,285 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix + +// for UUID and Date +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +@main +public class MockHttpServer { + + public static func main() throws { + let server = MockHttpServer() + try server.start() + } + + private func start() throws { + let host = env("HOST") ?? "127.0.0.1" + let port = env("PORT").flatMap(Int.init) ?? 7000 + let mode = env("MODE").flatMap(Mode.init) ?? .string + var log = Logger(label: "MockServer") + log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info + let logger = log + + let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) + // Specify backlog and enable SO_REUSEADDR for the server itself + // .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + // .childChannelOption(.maxMessagesPerRead, value: 1) + + // Set the handlers that are applied to the accepted Channels + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { + channel.pipeline.addHandler(HTTPHandler(mode: mode, logger: logger)) + } + } + + let channel = try socketBootstrap.bind(host: host, port: port).wait() + logger.debug("Server started and listening on \(host):\(port)") + + // This will never return as we don't close the ServerChannel + try channel.closeFuture.wait() + } +} + +private final class HTTPHandler: ChannelInboundHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private enum State { + case idle + case waitingForRequestBody + case sendingResponse + + mutating func requestReceived() { + precondition(self == .idle, "Invalid state for request received: \(self)") + self = .waitingForRequestBody + } + + mutating func requestComplete() { + precondition( + self == .waitingForRequestBody, + "Invalid state for request complete: \(self)" + ) + self = .sendingResponse + } + + mutating func responseComplete() { + precondition(self == .sendingResponse, "Invalid state for response complete: \(self)") + self = .idle + } + } + + private let logger: Logger + private let mode: Mode + + private var buffer: ByteBuffer! = nil + private var state: HTTPHandler.State = .idle + private var keepAlive = false + + private var requestHead: HTTPRequestHead? + private var requestBodyBytes: Int = 0 + + init(mode: Mode, logger: Logger) { + self.mode = mode + self.logger = logger + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let reqPart = Self.unwrapInboundIn(data) + handle(context: context, request: reqPart) + } + + func channelReadComplete(context: ChannelHandlerContext) { + context.flush() + self.buffer.clear() + } + + func handlerAdded(context: ChannelHandlerContext) { + self.buffer = context.channel.allocator.buffer(capacity: 0) + } + + private func handle(context: ChannelHandlerContext, request: HTTPServerRequestPart) { + switch request { + case .head(let request): + logger.trace("Received request .head") + self.requestHead = request + self.requestBodyBytes = 0 + self.keepAlive = request.isKeepAlive + self.state.requestReceived() + case .body(buffer: var buf): + logger.trace("Received request .body") + self.requestBodyBytes += buf.readableBytes + self.buffer.writeBuffer(&buf) + case .end: + logger.trace("Received request .end") + self.state.requestComplete() + + precondition(requestHead != nil, "Received .end without .head") + let (responseStatus, responseHeaders, responseBody) = self.processRequest( + requestHead: self.requestHead!, + requestBody: self.buffer + ) + + self.buffer.clear() + self.buffer.writeString(responseBody) + + var headers = HTTPHeaders(responseHeaders) + headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + + // write the response + context.write( + Self.wrapOutboundOut( + .head( + httpResponseHead( + request: self.requestHead!, + status: responseStatus, + headers: headers + ) + ) + ), + promise: nil + ) + context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) + self.completeResponse(context, trailers: nil, promise: nil) + } + } + + private func processRequest( + requestHead: HTTPRequestHead, + requestBody: ByteBuffer + ) -> (HTTPResponseStatus, [(String, String)], String) { + var responseStatus: HTTPResponseStatus = .ok + var responseBody: String = "" + var responseHeaders: [(String, String)] = [] + + logger.trace("Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")") + + if requestHead.uri.hasSuffix("/next") { + logger.trace("URI /next") + + responseStatus = .accepted + + let requestId = UUID().uuidString + switch self.mode { + case .string: + responseBody = "\"\(requestId)\"" // must be a valid JSON string + case .json: + responseBody = "{ \"body\": \"\(requestId)\" }" + } + let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) + responseHeaders = [ + // ("Connection", "close"), + (AmazonHeaders.requestID, requestId), + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), + (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), + (AmazonHeaders.deadline, String(deadline)), + ] + } else if requestHead.uri.hasSuffix("/response") { + logger.trace("URI /response") + responseStatus = .accepted + } else if requestHead.uri.hasSuffix("/error") { + logger.trace("URI /error") + responseStatus = .ok + } else { + logger.trace("Unknown URI : \(requestHead)") + responseStatus = .notFound + } + logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") + return (responseStatus, responseHeaders, responseBody) + } + + private func completeResponse( + _ context: ChannelHandlerContext, + trailers: HTTPHeaders?, + promise: EventLoopPromise? + ) { + self.state.responseComplete() + + let eventLoop = context.eventLoop + let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) + + let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise()) + if !self.keepAlive { + promise!.futureResult.whenComplete { (_: Result) in + let context = loopBoundContext.value + context.close(promise: nil) + } + } + + context.writeAndFlush(Self.wrapOutboundOut(.end(trailers)), promise: promise) + } + + private func httpResponseHead( + request: HTTPRequestHead, + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders() + ) -> HTTPResponseHead { + var head = HTTPResponseHead(version: request.version, status: status, headers: headers) + let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { + $0.lowercased() + } + + if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") { + // the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers + + switch (request.isKeepAlive, request.version.major, request.version.minor) { + case (true, 1, 0): + // HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that + head.headers.add(name: "Connection", value: "keep-alive") + case (false, 1, let n) where n >= 1: + // HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that + head.headers.add(name: "Connection", value: "close") + default: + // we should match the default or are dealing with some HTTP that we don't support, let's leave as is + () + } + } + return head + } + + private enum ServerError: Error { + case notReady + case cantBind + } + + private enum AmazonHeaders { + static let requestID = "Lambda-Runtime-Aws-Request-Id" + static let traceID = "Lambda-Runtime-Trace-Id" + static let clientContext = "X-Amz-Client-Context" + static let cognitoIdentity = "X-Amz-Cognito-Identity" + static let deadline = "Lambda-Runtime-Deadline-Ms" + static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + } +} + +private enum Mode: String { + case string + case json +} + +private func env(_ name: String) -> String? { + guard let value = getenv(name) else { + return nil + } + return String(cString: value) +} diff --git a/Sources/MockServer/main.swift b/Sources/MockServer/main.swift deleted file mode 100644 index 1b8466f9..00000000 --- a/Sources/MockServer/main.swift +++ /dev/null @@ -1,177 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftAWSLambdaRuntime open source project -// -// Copyright (c) 2017-2018 Apple Inc. and the SwiftAWSLambdaRuntime project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Dispatch -import NIOCore -import NIOHTTP1 -import NIOPosix - -#if canImport(FoundationEssentials) -import FoundationEssentials -#else -import Foundation -#endif - -struct MockServer { - private let group: EventLoopGroup - private let host: String - private let port: Int - private let mode: Mode - - public init() { - self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) - self.host = env("HOST") ?? "127.0.0.1" - self.port = env("PORT").flatMap(Int.init) ?? 7000 - self.mode = env("MODE").flatMap(Mode.init) ?? .string - } - - func start() throws { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler(HTTPHandler(mode: self.mode)) - } - } - try bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard let localAddress = channel.localAddress else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) - } - print("\(self) started and listening on \(localAddress)") - return channel.eventLoop.makeSucceededFuture(()) - }.wait() - } -} - -final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let mode: Mode - - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() - - public init(mode: Mode) { - self.mode = mode - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) - - switch requestPart { - case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer - } else { - request.body!.writeBuffer(&buffer) - } - self.pending.prepend(request) - case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) - } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - var responseStatus: HTTPResponseStatus - var responseBody: String? - var responseHeaders: [(String, String)]? - - if request.head.uri.hasSuffix("/next") { - let requestId = UUID().uuidString - responseStatus = .ok - switch self.mode { - case .string: - responseBody = requestId - case .json: - responseBody = "{ \"body\": \"\(requestId)\" }" - } - let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) - responseHeaders = [ - (AmazonHeaders.requestID, requestId), - (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), - (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), - (AmazonHeaders.deadline, String(deadline)), - ] - } else if request.head.uri.hasSuffix("/response") { - responseStatus = .accepted - } else { - responseStatus = .notFound - } - self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) - } - - func writeResponse( - context: ChannelHandlerContext, - status: HTTPResponseStatus, - headers: [(String, String)]? = nil, - body: String? = nil - ) { - var headers = HTTPHeaders(headers ?? []) - headers.add(name: "content-length", value: "\(body?.utf8.count ?? 0)") - let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: status, headers: headers) - - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - print("\(self) write error \(error)") - } - - if let b = body { - var buffer = context.channel.allocator.buffer(capacity: b.utf8.count) - buffer.writeString(b) - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - print("\(self) write error \(error)") - } - } - - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - print("\(self) write error \(error)") - } - } - } -} - -enum ServerError: Error { - case notReady - case cantBind -} - -enum AmazonHeaders { - static let requestID = "Lambda-Runtime-Aws-Request-Id" - static let traceID = "Lambda-Runtime-Trace-Id" - static let clientContext = "X-Amz-Client-Context" - static let cognitoIdentity = "X-Amz-Cognito-Identity" - static let deadline = "Lambda-Runtime-Deadline-Ms" - static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" -} - -enum Mode: String { - case string - case json -} - -func env(_ name: String) -> String? { - guard let value = getenv(name) else { - return nil - } - return String(cString: value) -} - -// main -let server = MockServer() -try! server.start() -dispatchMain() From 888ed77ab5bae76acec6c2a55899f309d2210bbc Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:36:31 +0100 Subject: [PATCH 2/9] apply swift format --- Sources/MockServer/MockHTTPServer.swift | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index a730de11..0de58d0c 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -27,7 +27,7 @@ import Foundation @main public class MockHttpServer { - public static func main() throws { + public static func main() throws { let server = MockHttpServer() try server.start() } @@ -172,7 +172,9 @@ private final class HTTPHandler: ChannelInboundHandler { var responseBody: String = "" var responseHeaders: [(String, String)] = [] - logger.trace("Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")") + logger.trace( + "Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")" + ) if requestHead.uri.hasSuffix("/next") { logger.trace("URI /next") @@ -182,7 +184,7 @@ private final class HTTPHandler: ChannelInboundHandler { let requestId = UUID().uuidString switch self.mode { case .string: - responseBody = "\"\(requestId)\"" // must be a valid JSON string + responseBody = "\"\(requestId)\"" // must be a valid JSON string case .json: responseBody = "{ \"body\": \"\(requestId)\" }" } @@ -194,7 +196,7 @@ private final class HTTPHandler: ChannelInboundHandler { (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), (AmazonHeaders.deadline, String(deadline)), ] - } else if requestHead.uri.hasSuffix("/response") { + } else if requestHead.uri.hasSuffix("/response") { logger.trace("URI /response") responseStatus = .accepted } else if requestHead.uri.hasSuffix("/error") { From e7b7e6ccfee7abdf26c237a109f8355e5e6f8a2c Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:50:45 +0100 Subject: [PATCH 3/9] simplify ByteBuffer to String --- Sources/MockServer/MockHTTPServer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 0de58d0c..8e0e56fd 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -173,7 +173,7 @@ private final class HTTPHandler: ChannelInboundHandler { var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")" + "Processing request for : \(requestHead) - \(String(requestBody))" ) if requestHead.uri.hasSuffix("/next") { From 61ab17bafd8532dd5e3731905059e049a095a065 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:56:51 +0100 Subject: [PATCH 4/9] fix --- Sources/MockServer/MockHTTPServer.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 8e0e56fd..63fe3e72 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -42,9 +42,9 @@ public class MockHttpServer { let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) // Specify backlog and enable SO_REUSEADDR for the server itself - // .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) - // .childChannelOption(.maxMessagesPerRead, value: 1) + .childChannelOption(.maxMessagesPerRead, value: 1) // Set the handlers that are applied to the accepted Channels .childChannelInitializer { channel in @@ -173,7 +173,7 @@ private final class HTTPHandler: ChannelInboundHandler { var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(String(requestBody))" + "Processing request for : \(requestHead) - \(String(buffer: requestBody))" ) if requestHead.uri.hasSuffix("/next") { From bc3a34d301f40bffce4c0ff501d863d0b73b82d8 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 19:18:37 +0100 Subject: [PATCH 5/9] remove unused code --- Sources/MockServer/MockHTTPServer.swift | 30 ------------------------- 1 file changed, 30 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 63fe3e72..a8a1663d 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -65,35 +65,10 @@ private final class HTTPHandler: ChannelInboundHandler { public typealias InboundIn = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart - private enum State { - case idle - case waitingForRequestBody - case sendingResponse - - mutating func requestReceived() { - precondition(self == .idle, "Invalid state for request received: \(self)") - self = .waitingForRequestBody - } - - mutating func requestComplete() { - precondition( - self == .waitingForRequestBody, - "Invalid state for request complete: \(self)" - ) - self = .sendingResponse - } - - mutating func responseComplete() { - precondition(self == .sendingResponse, "Invalid state for response complete: \(self)") - self = .idle - } - } - private let logger: Logger private let mode: Mode private var buffer: ByteBuffer! = nil - private var state: HTTPHandler.State = .idle private var keepAlive = false private var requestHead: HTTPRequestHead? @@ -125,14 +100,12 @@ private final class HTTPHandler: ChannelInboundHandler { self.requestHead = request self.requestBodyBytes = 0 self.keepAlive = request.isKeepAlive - self.state.requestReceived() case .body(buffer: var buf): logger.trace("Received request .body") self.requestBodyBytes += buf.readableBytes self.buffer.writeBuffer(&buf) case .end: logger.trace("Received request .end") - self.state.requestComplete() precondition(requestHead != nil, "Received .end without .head") let (responseStatus, responseHeaders, responseBody) = self.processRequest( @@ -190,7 +163,6 @@ private final class HTTPHandler: ChannelInboundHandler { } let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) responseHeaders = [ - // ("Connection", "close"), (AmazonHeaders.requestID, requestId), (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), @@ -215,8 +187,6 @@ private final class HTTPHandler: ChannelInboundHandler { trailers: HTTPHeaders?, promise: EventLoopPromise? ) { - self.state.responseComplete() - let eventLoop = context.eventLoop let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) From ccdb45a2ac5c17526407099dcc8bbf86618f377a Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Tue, 14 Jan 2025 06:39:13 +0100 Subject: [PATCH 6/9] adjust payload to new examples --- Sources/MockServer/MockHTTPServer.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index a8a1663d..55468e67 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -157,9 +157,9 @@ private final class HTTPHandler: ChannelInboundHandler { let requestId = UUID().uuidString switch self.mode { case .string: - responseBody = "\"\(requestId)\"" // must be a valid JSON string + responseBody = "\"Seb\"" // must be a valid JSON document case .json: - responseBody = "{ \"body\": \"\(requestId)\" }" + responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }" } let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) responseHeaders = [ From eb1608b55b90a5f280d44e23f99e4ed0f618a3fc Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 10:21:27 +0100 Subject: [PATCH 7/9] [wip] use NIOAsyncChannel --- Sources/MockServer/MockHTTPServer.swift | 330 +++++++++++++----------- 1 file changed, 180 insertions(+), 150 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 55468e67..34a923d1 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -16,6 +16,7 @@ import Logging import NIOCore import NIOHTTP1 import NIOPosix +import Synchronization // for UUID and Date #if canImport(FoundationEssentials) @@ -25,133 +26,169 @@ import Foundation #endif @main -public class MockHttpServer { - - public static func main() throws { - let server = MockHttpServer() - try server.start() - } +struct HttpServer { + /// The server's host. (default: 127.0.0.1) + private let host: String + /// The server's port. (default: 7000) + private let port: Int + /// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton) + private let eventLoopGroup: MultiThreadedEventLoopGroup + /// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string) + private let mode: Mode + /// the number of connections this server must accept before shutting down (default: 1) + private let maxInvocations: Int + /// the logger (control verbosity with LOG_LEVEL environment variable) + private let logger: Logger - private func start() throws { - let host = env("HOST") ?? "127.0.0.1" - let port = env("PORT").flatMap(Int.init) ?? 7000 - let mode = env("MODE").flatMap(Mode.init) ?? .string + static func main() async throws { var log = Logger(label: "MockServer") log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info - let logger = log - let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) - // Specify backlog and enable SO_REUSEADDR for the server itself + let server = HttpServer( + host: env("HOST") ?? "127.0.0.1", + port: env("PORT").flatMap(Int.init) ?? 7000, + eventLoopGroup: .singleton, + mode: env("MODE").flatMap(Mode.init) ?? .string, + maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1, + logger: log + ) + try await server.run() + } + + /// This method starts the server and handles incoming connections. + private func run() async throws { + let channel = try await ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) .childChannelOption(.maxMessagesPerRead, value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: true + ) - // Set the handlers that are applied to the accepted Channels - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { - channel.pipeline.addHandler(HTTPHandler(mode: mode, logger: logger)) + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self + ) + ) } } - let channel = try socketBootstrap.bind(host: host, port: port).wait() - logger.debug("Server started and listening on \(host):\(port)") - - // This will never return as we don't close the ServerChannel - try channel.closeFuture.wait() - } -} - -private final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let logger: Logger - private let mode: Mode - - private var buffer: ByteBuffer! = nil - private var keepAlive = false - - private var requestHead: HTTPRequestHead? - private var requestBodyBytes: Int = 0 - - init(mode: Mode, logger: Logger) { - self.mode = mode - self.logger = logger - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let reqPart = Self.unwrapInboundIn(data) - handle(context: context, request: reqPart) - } - - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - self.buffer.clear() - } + logger.info( + "Server started and listening", + metadata: [ + "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.channel.localAddress?.port ?? 0)", + ] + ) - func handlerAdded(context: ChannelHandlerContext) { - self.buffer = context.channel.allocator.buffer(capacity: 0) + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + logger.trace("Handling new connection") + logger.info( + "This mock server accepts only one connection, it will shutdown the server after handling the current connection." + ) + group.addTask { + await self.handleConnection(channel: connectionChannel) + logger.trace("Done handling connection") + } + break + } + } + } + logger.info("Server shutting down") } - private func handle(context: ChannelHandlerContext, request: HTTPServerRequestPart) { - switch request { - case .head(let request): - logger.trace("Received request .head") - self.requestHead = request - self.requestBodyBytes = 0 - self.keepAlive = request.isKeepAlive - case .body(buffer: var buf): - logger.trace("Received request .body") - self.requestBodyBytes += buf.readableBytes - self.buffer.writeBuffer(&buf) - case .end: - logger.trace("Received request .end") - - precondition(requestHead != nil, "Received .end without .head") - let (responseStatus, responseHeaders, responseBody) = self.processRequest( - requestHead: self.requestHead!, - requestBody: self.buffer - ) + /// This method handles a single connection by echoing back all inbound data. + private func handleConnection( + channel: NIOAsyncChannel + ) async { + + var requestHead: HTTPRequestHead! + var requestBody: ByteBuffer? + + // each Lambda invocation results in TWO HTTP requests (next and response) + let requestCount = RequestCounter(maxRequest: self.maxInvocations * 2) + + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + let requestNumber = requestCount.current() + logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"]) + + if case .head(let head) = inboundData { + logger.trace("Received request head", metadata: ["head": "\(head)"]) + requestHead = head + } + if case .body(let body) = inboundData { + logger.trace("Received request body", metadata: ["body": "\(body)"]) + requestBody = body + } + if case .end(let end) = inboundData { + logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"]) + + precondition(requestHead != nil, "Received .end without .head") + let (responseStatus, responseHeaders, responseBody) = self.processRequest( + requestHead: requestHead, + requestBody: requestBody + ) - self.buffer.clear() - self.buffer.writeString(responseBody) + try await self.sendResponse( + responseStatus: responseStatus, + responseHeaders: responseHeaders, + responseBody: responseBody, + outbound: outbound + ) - var headers = HTTPHeaders(responseHeaders) - headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + requestHead = nil - // write the response - context.write( - Self.wrapOutboundOut( - .head( - httpResponseHead( - request: self.requestHead!, - status: responseStatus, - headers: headers - ) - ) - ), - promise: nil - ) - context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) - self.completeResponse(context, trailers: nil, promise: nil) + if requestCount.increment() { + logger.info( + "Maximum number of invocations reached, closing this connection", + metadata: ["maxInvocations": "\(self.maxInvocations)"] + ) + break + } + } + } + } + } catch { + logger.error("Hit error: \(error)") } } - + /// This function process the requests and return an hard-coded response (string or JSON depending on the mode). + /// We ignore the requestBody. private func processRequest( requestHead: HTTPRequestHead, - requestBody: ByteBuffer + requestBody: ByteBuffer? ) -> (HTTPResponseStatus, [(String, String)], String) { var responseStatus: HTTPResponseStatus = .ok var responseBody: String = "" var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(String(buffer: requestBody))" + "Processing request", + metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"] ) if requestHead.uri.hasSuffix("/next") { - logger.trace("URI /next") - responseStatus = .accepted let requestId = UUID().uuidString @@ -169,64 +206,51 @@ private final class HTTPHandler: ChannelInboundHandler { (AmazonHeaders.deadline, String(deadline)), ] } else if requestHead.uri.hasSuffix("/response") { - logger.trace("URI /response") responseStatus = .accepted } else if requestHead.uri.hasSuffix("/error") { - logger.trace("URI /error") responseStatus = .ok } else { - logger.trace("Unknown URI : \(requestHead)") responseStatus = .notFound } logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") return (responseStatus, responseHeaders, responseBody) } - private func completeResponse( - _ context: ChannelHandlerContext, - trailers: HTTPHeaders?, - promise: EventLoopPromise? - ) { - let eventLoop = context.eventLoop - let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) - - let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise()) - if !self.keepAlive { - promise!.futureResult.whenComplete { (_: Result) in - let context = loopBoundContext.value - context.close(promise: nil) - } - } - - context.writeAndFlush(Self.wrapOutboundOut(.end(trailers)), promise: promise) + private func sendResponse( + responseStatus: HTTPResponseStatus, + responseHeaders: [(String, String)], + responseBody: String, + outbound: NIOAsyncChannelOutboundWriter + ) async throws { + var headers = HTTPHeaders(responseHeaders) + headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + + logger.trace("Writing response head") + try await outbound.write( + HTTPServerResponsePart.head( + HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: responseStatus, + headers: headers + ) + ) + ) + logger.trace("Writing response body") + try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody)))) + logger.trace("Writing response end") + try await outbound.write(HTTPServerResponsePart.end(nil)) } - private func httpResponseHead( - request: HTTPRequestHead, - status: HTTPResponseStatus, - headers: HTTPHeaders = HTTPHeaders() - ) -> HTTPResponseHead { - var head = HTTPResponseHead(version: request.version, status: status, headers: headers) - let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { - $0.lowercased() - } - - if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") { - // the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers + private enum Mode: String { + case string + case json + } - switch (request.isKeepAlive, request.version.major, request.version.minor) { - case (true, 1, 0): - // HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that - head.headers.add(name: "Connection", value: "keep-alive") - case (false, 1, let n) where n >= 1: - // HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that - head.headers.add(name: "Connection", value: "close") - default: - // we should match the default or are dealing with some HTTP that we don't support, let's leave as is - () - } + private static func env(_ name: String) -> String? { + guard let value = getenv(name) else { + return nil } - return head + return String(cString: value) } private enum ServerError: Error { @@ -242,16 +266,22 @@ private final class HTTPHandler: ChannelInboundHandler { static let deadline = "Lambda-Runtime-Deadline-Ms" static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" } -} -private enum Mode: String { - case string - case json -} + private final class RequestCounter: Sendable { + private let counterMutex = Mutex(0) + private let maxRequest: Int -private func env(_ name: String) -> String? { - guard let value = getenv(name) else { - return nil + init(maxRequest: Int) { + self.maxRequest = maxRequest + } + func current() -> Int { + counterMutex.withLock { $0 } + } + func increment() -> Bool { + counterMutex.withLock { + $0 += 1 + return $0 >= maxRequest + } + } } - return String(cString: value) } From 32dace6948727731871ca52d39f51a65e153c471 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 11:38:59 +0100 Subject: [PATCH 8/9] manage max number of connections and max number of request per connection --- Sources/MockServer/MockHTTPServer.swift | 52 ++++++++++++++++--------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 34a923d1..b8f18998 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -55,7 +55,8 @@ struct HttpServer { try await server.run() } - /// This method starts the server and handles incoming connections. + /// This method starts the server and handles one unique incoming connections + /// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response. private func run() async throws { let channel = try await ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.backlog, value: 256) @@ -86,9 +87,14 @@ struct HttpServer { metadata: [ "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", "port": "\(channel.channel.localAddress?.port ?? 0)", + "maxInvocations": "\(self.maxInvocations)", ] ) + // This counter is used to track the number of incoming connections. + // This mock servers accepts n TCP connection then shutdowns + let connectionCounter = SharedCounter(maxValue: self.maxInvocations) + // We are handling each incoming connection in a separate child task. It is important // to use a discarding task group here which automatically discards finished child tasks. // A normal task group retains all child tasks and their outputs in memory until they are @@ -98,22 +104,31 @@ struct HttpServer { try await withThrowingDiscardingTaskGroup { group in try await channel.executeThenClose { inbound in for try await connectionChannel in inbound { - logger.trace("Handling new connection") - logger.info( - "This mock server accepts only one connection, it will shutdown the server after handling the current connection." - ) + + let counter = connectionCounter.current() + logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"]) + group.addTask { await self.handleConnection(channel: connectionChannel) - logger.trace("Done handling connection") + logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"]) + } + + if connectionCounter.increment() { + logger.info( + "Maximum number of connections reached, shutting down after current connection", + metadata: ["maxConnections": "\(self.maxInvocations)"] + ) + break // this causes the server to shutdown after handling the connection } - break } } } logger.info("Server shutting down") } - /// This method handles a single connection by echoing back all inbound data. + /// This method handles a single connection by responsing hard coded value to a Lambda function request. + /// It handles two requests: one for the next invocation and one for the response. + /// when the maximum number of requests is reached, it closes the connection. private func handleConnection( channel: NIOAsyncChannel ) async { @@ -122,7 +137,7 @@ struct HttpServer { var requestBody: ByteBuffer? // each Lambda invocation results in TWO HTTP requests (next and response) - let requestCount = RequestCounter(maxRequest: self.maxInvocations * 2) + let requestCount = SharedCounter(maxValue: 2) // Note that this method is non-throwing and we are catching any error. // We do this since we don't want to tear down the whole server when a single connection @@ -161,10 +176,10 @@ struct HttpServer { if requestCount.increment() { logger.info( - "Maximum number of invocations reached, closing this connection", - metadata: ["maxInvocations": "\(self.maxInvocations)"] + "Maximum number of requests reached, closing this connection", + metadata: ["maxRequest": "2"] ) - break + break // this finishes handiling request on this connection } } } @@ -224,12 +239,13 @@ struct HttpServer { ) async throws { var headers = HTTPHeaders(responseHeaders) headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + headers.add(name: "KeepAlive", value: "timeout=1, max=2") logger.trace("Writing response head") try await outbound.write( HTTPServerResponsePart.head( HTTPResponseHead( - version: .init(major: 1, minor: 1), + version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests status: responseStatus, headers: headers ) @@ -267,12 +283,12 @@ struct HttpServer { static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" } - private final class RequestCounter: Sendable { + private final class SharedCounter: Sendable { private let counterMutex = Mutex(0) - private let maxRequest: Int + private let maxValue: Int - init(maxRequest: Int) { - self.maxRequest = maxRequest + init(maxValue: Int) { + self.maxValue = maxValue } func current() -> Int { counterMutex.withLock { $0 } @@ -280,7 +296,7 @@ struct HttpServer { func increment() -> Bool { counterMutex.withLock { $0 += 1 - return $0 >= maxRequest + return $0 >= maxValue } } } From ca742d96abe7b6b6733c6dced045619dc37e2e25 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Tue, 21 Jan 2025 09:12:10 +0100 Subject: [PATCH 9/9] remove unused code --- Sources/MockServer/MockHTTPServer.swift | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index b8f18998..0849e325 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -269,11 +269,6 @@ struct HttpServer { return String(cString: value) } - private enum ServerError: Error { - case notReady - case cantBind - } - private enum AmazonHeaders { static let requestID = "Lambda-Runtime-Aws-Request-Id" static let traceID = "Lambda-Runtime-Trace-Id"