diff --git a/Package.swift b/Package.swift index d2c92fdc..96068884 100644 --- a/Package.swift +++ b/Package.swift @@ -17,7 +17,7 @@ let package = Package( .library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"), ], targets: [ @@ -89,11 +89,11 @@ let package = Package( .executableTarget( name: "MockServer", dependencies: [ + .product(name: "Logging", package: "swift-log"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - ], - swiftSettings: [.swiftLanguageMode(.v5)] + ] ), ] ) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift new file mode 100644 index 00000000..0849e325 --- /dev/null +++ b/Sources/MockServer/MockHTTPServer.swift @@ -0,0 +1,298 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix +import Synchronization + +// for UUID and Date +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +@main +struct HttpServer { + /// The server's host. (default: 127.0.0.1) + private let host: String + /// The server's port. (default: 7000) + private let port: Int + /// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton) + private let eventLoopGroup: MultiThreadedEventLoopGroup + /// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string) + private let mode: Mode + /// the number of connections this server must accept before shutting down (default: 1) + private let maxInvocations: Int + /// the logger (control verbosity with LOG_LEVEL environment variable) + private let logger: Logger + + static func main() async throws { + var log = Logger(label: "MockServer") + log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info + + let server = HttpServer( + host: env("HOST") ?? "127.0.0.1", + port: env("PORT").flatMap(Int.init) ?? 7000, + eventLoopGroup: .singleton, + mode: env("MODE").flatMap(Mode.init) ?? .string, + maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1, + logger: log + ) + try await server.run() + } + + /// This method starts the server and handles one unique incoming connections + /// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response. + private func run() async throws { + let channel = try await ServerBootstrap(group: self.eventLoopGroup) + .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(.maxMessagesPerRead, value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: true + ) + + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self + ) + ) + } + } + + logger.info( + "Server started and listening", + metadata: [ + "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.channel.localAddress?.port ?? 0)", + "maxInvocations": "\(self.maxInvocations)", + ] + ) + + // This counter is used to track the number of incoming connections. + // This mock servers accepts n TCP connection then shutdowns + let connectionCounter = SharedCounter(maxValue: self.maxInvocations) + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + let counter = connectionCounter.current() + logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"]) + + group.addTask { + await self.handleConnection(channel: connectionChannel) + logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"]) + } + + if connectionCounter.increment() { + logger.info( + "Maximum number of connections reached, shutting down after current connection", + metadata: ["maxConnections": "\(self.maxInvocations)"] + ) + break // this causes the server to shutdown after handling the connection + } + } + } + } + logger.info("Server shutting down") + } + + /// This method handles a single connection by responsing hard coded value to a Lambda function request. + /// It handles two requests: one for the next invocation and one for the response. + /// when the maximum number of requests is reached, it closes the connection. + private func handleConnection( + channel: NIOAsyncChannel + ) async { + + var requestHead: HTTPRequestHead! + var requestBody: ByteBuffer? + + // each Lambda invocation results in TWO HTTP requests (next and response) + let requestCount = SharedCounter(maxValue: 2) + + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + let requestNumber = requestCount.current() + logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"]) + + if case .head(let head) = inboundData { + logger.trace("Received request head", metadata: ["head": "\(head)"]) + requestHead = head + } + if case .body(let body) = inboundData { + logger.trace("Received request body", metadata: ["body": "\(body)"]) + requestBody = body + } + if case .end(let end) = inboundData { + logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"]) + + precondition(requestHead != nil, "Received .end without .head") + let (responseStatus, responseHeaders, responseBody) = self.processRequest( + requestHead: requestHead, + requestBody: requestBody + ) + + try await self.sendResponse( + responseStatus: responseStatus, + responseHeaders: responseHeaders, + responseBody: responseBody, + outbound: outbound + ) + + requestHead = nil + + if requestCount.increment() { + logger.info( + "Maximum number of requests reached, closing this connection", + metadata: ["maxRequest": "2"] + ) + break // this finishes handiling request on this connection + } + } + } + } + } catch { + logger.error("Hit error: \(error)") + } + } + /// This function process the requests and return an hard-coded response (string or JSON depending on the mode). + /// We ignore the requestBody. + private func processRequest( + requestHead: HTTPRequestHead, + requestBody: ByteBuffer? + ) -> (HTTPResponseStatus, [(String, String)], String) { + var responseStatus: HTTPResponseStatus = .ok + var responseBody: String = "" + var responseHeaders: [(String, String)] = [] + + logger.trace( + "Processing request", + metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"] + ) + + if requestHead.uri.hasSuffix("/next") { + responseStatus = .accepted + + let requestId = UUID().uuidString + switch self.mode { + case .string: + responseBody = "\"Seb\"" // must be a valid JSON document + case .json: + responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }" + } + let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) + responseHeaders = [ + (AmazonHeaders.requestID, requestId), + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), + (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), + (AmazonHeaders.deadline, String(deadline)), + ] + } else if requestHead.uri.hasSuffix("/response") { + responseStatus = .accepted + } else if requestHead.uri.hasSuffix("/error") { + responseStatus = .ok + } else { + responseStatus = .notFound + } + logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") + return (responseStatus, responseHeaders, responseBody) + } + + private func sendResponse( + responseStatus: HTTPResponseStatus, + responseHeaders: [(String, String)], + responseBody: String, + outbound: NIOAsyncChannelOutboundWriter + ) async throws { + var headers = HTTPHeaders(responseHeaders) + headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + headers.add(name: "KeepAlive", value: "timeout=1, max=2") + + logger.trace("Writing response head") + try await outbound.write( + HTTPServerResponsePart.head( + HTTPResponseHead( + version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests + status: responseStatus, + headers: headers + ) + ) + ) + logger.trace("Writing response body") + try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody)))) + logger.trace("Writing response end") + try await outbound.write(HTTPServerResponsePart.end(nil)) + } + + private enum Mode: String { + case string + case json + } + + private static func env(_ name: String) -> String? { + guard let value = getenv(name) else { + return nil + } + return String(cString: value) + } + + private enum AmazonHeaders { + static let requestID = "Lambda-Runtime-Aws-Request-Id" + static let traceID = "Lambda-Runtime-Trace-Id" + static let clientContext = "X-Amz-Client-Context" + static let cognitoIdentity = "X-Amz-Cognito-Identity" + static let deadline = "Lambda-Runtime-Deadline-Ms" + static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + } + + private final class SharedCounter: Sendable { + private let counterMutex = Mutex(0) + private let maxValue: Int + + init(maxValue: Int) { + self.maxValue = maxValue + } + func current() -> Int { + counterMutex.withLock { $0 } + } + func increment() -> Bool { + counterMutex.withLock { + $0 += 1 + return $0 >= maxValue + } + } + } +} diff --git a/Sources/MockServer/main.swift b/Sources/MockServer/main.swift deleted file mode 100644 index 1b8466f9..00000000 --- a/Sources/MockServer/main.swift +++ /dev/null @@ -1,177 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftAWSLambdaRuntime open source project -// -// Copyright (c) 2017-2018 Apple Inc. and the SwiftAWSLambdaRuntime project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Dispatch -import NIOCore -import NIOHTTP1 -import NIOPosix - -#if canImport(FoundationEssentials) -import FoundationEssentials -#else -import Foundation -#endif - -struct MockServer { - private let group: EventLoopGroup - private let host: String - private let port: Int - private let mode: Mode - - public init() { - self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) - self.host = env("HOST") ?? "127.0.0.1" - self.port = env("PORT").flatMap(Int.init) ?? 7000 - self.mode = env("MODE").flatMap(Mode.init) ?? .string - } - - func start() throws { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler(HTTPHandler(mode: self.mode)) - } - } - try bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard let localAddress = channel.localAddress else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) - } - print("\(self) started and listening on \(localAddress)") - return channel.eventLoop.makeSucceededFuture(()) - }.wait() - } -} - -final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let mode: Mode - - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() - - public init(mode: Mode) { - self.mode = mode - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) - - switch requestPart { - case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer - } else { - request.body!.writeBuffer(&buffer) - } - self.pending.prepend(request) - case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) - } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - var responseStatus: HTTPResponseStatus - var responseBody: String? - var responseHeaders: [(String, String)]? - - if request.head.uri.hasSuffix("/next") { - let requestId = UUID().uuidString - responseStatus = .ok - switch self.mode { - case .string: - responseBody = requestId - case .json: - responseBody = "{ \"body\": \"\(requestId)\" }" - } - let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) - responseHeaders = [ - (AmazonHeaders.requestID, requestId), - (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), - (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), - (AmazonHeaders.deadline, String(deadline)), - ] - } else if request.head.uri.hasSuffix("/response") { - responseStatus = .accepted - } else { - responseStatus = .notFound - } - self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) - } - - func writeResponse( - context: ChannelHandlerContext, - status: HTTPResponseStatus, - headers: [(String, String)]? = nil, - body: String? = nil - ) { - var headers = HTTPHeaders(headers ?? []) - headers.add(name: "content-length", value: "\(body?.utf8.count ?? 0)") - let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: status, headers: headers) - - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - print("\(self) write error \(error)") - } - - if let b = body { - var buffer = context.channel.allocator.buffer(capacity: b.utf8.count) - buffer.writeString(b) - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - print("\(self) write error \(error)") - } - } - - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - print("\(self) write error \(error)") - } - } - } -} - -enum ServerError: Error { - case notReady - case cantBind -} - -enum AmazonHeaders { - static let requestID = "Lambda-Runtime-Aws-Request-Id" - static let traceID = "Lambda-Runtime-Trace-Id" - static let clientContext = "X-Amz-Client-Context" - static let cognitoIdentity = "X-Amz-Cognito-Identity" - static let deadline = "Lambda-Runtime-Deadline-Ms" - static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" -} - -enum Mode: String { - case string - case json -} - -func env(_ name: String) -> String? { - guard let value = getenv(name) else { - return nil - } - return String(cString: value) -} - -// main -let server = MockServer() -try! server.start() -dispatchMain()