diff --git a/Package.swift b/Package.swift index 63d83b7b..25cbdd4f 100644 --- a/Package.swift +++ b/Package.swift @@ -23,7 +23,7 @@ let package = Package( .library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", .upToNextMajor(from: "2.67.0")), + .package(url: "https://github.com/apple/swift-nio.git", .upToNextMajor(from: "2.72.0")), .package(url: "https://github.com/apple/swift-log.git", .upToNextMajor(from: "1.5.4")), .package(url: "https://github.com/apple/swift-docc-plugin.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-testing.git", branch: "swift-DEVELOPMENT-SNAPSHOT-2024-08-29-a"), diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index 36f88229..19b56c1f 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -142,8 +142,14 @@ extension LambdaRuntimeClient { static let defaultHeaders = HTTPHeaders([("user-agent", "Swift-Lambda/Unknown")]) /// These headers must be sent along an invocation or initialization error report - static let errorHeaders = HTTPHeaders([ - ("user-agent", "Swift-Lambda/Unknown"), - ("lambda-runtime-function-error-type", "Unhandled"), - ]) + static let errorHeaders: HTTPHeaders = [ + "user-agent": "Swift-Lambda/Unknown", + "lambda-runtime-function-error-type": "Unhandled", + ] + + /// These headers must be sent along an invocation or initialization error report + static let streamingHeaders: HTTPHeaders = [ + "user-agent": "Swift-Lambda/Unknown", + "transfer-encoding": "streaming", + ] } diff --git a/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeClient.swift new file mode 100644 index 00000000..71839163 --- /dev/null +++ b/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeClient.swift @@ -0,0 +1,801 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix + +final actor NewLambdaRuntimeClient: LambdaRuntimeClientProtocol { + nonisolated let unownedExecutor: UnownedSerialExecutor + + struct Configuration { + var ip: String + var port: Int + } + + struct Writer: LambdaRuntimeClientResponseStreamWriter { + private var runtimeClient: NewLambdaRuntimeClient + + fileprivate init(runtimeClient: NewLambdaRuntimeClient) { + self.runtimeClient = runtimeClient + } + + func write(_ buffer: NIOCore.ByteBuffer) async throws { + try await self.runtimeClient.write(buffer) + } + + func finish() async throws { + try await self.runtimeClient.writeAndFinish(nil) + } + + func writeAndFinish(_ buffer: NIOCore.ByteBuffer) async throws { + try await self.runtimeClient.writeAndFinish(buffer) + } + + func reportError(_ error: any Error) async throws { + try await self.runtimeClient.reportError(error) + } + } + + private enum ConnectionState { + case disconnected + case connecting([CheckedContinuation, any Error>]) + case connected(Channel, LambdaChannelHandler) + } + + enum LambdaState { + /// this is the "normal" state. Transitions to `waitingForNextInvocation` + case idle(previousRequestID: String?) + /// this is the state while we wait for an invocation. A next call is running. + /// Transitions to `waitingForResponse` + case waitingForNextInvocation + /// The invocation was forwarded to the handler and we wait for a response. + /// Transitions to `sendingResponse` or `sentResponse`. + case waitingForResponse(requestID: String) + case sendingResponse(requestID: String) + case sentResponse(requestID: String) + } + + enum ClosingState { + case notClosing + case closing(CheckedContinuation) + case closed + } + + private let eventLoop: any EventLoop + private let logger: Logger + private let configuration: Configuration + + private var connectionState: ConnectionState = .disconnected + private var lambdaState: LambdaState = .idle(previousRequestID: nil) + private var closingState: ClosingState = .notClosing + + // connections that are currently being closed. In the `run` method we must await all of them + // being fully closed before we can return from it. + private var closingConnections: [any Channel] = [] + + static func withRuntimeClient( + configuration: Configuration, + eventLoop: any EventLoop, + logger: Logger, + _ body: (NewLambdaRuntimeClient) async throws -> Result + ) async throws -> Result { + let runtime = NewLambdaRuntimeClient(configuration: configuration, eventLoop: eventLoop, logger: logger) + let result: Swift.Result + do { + result = .success(try await body(runtime)) + } catch { + result = .failure(error) + } + + await runtime.close() + + //try? await runtime.close() + return try result.get() + } + + private init(configuration: Configuration, eventLoop: any EventLoop, logger: Logger) { + self.unownedExecutor = eventLoop.executor.asUnownedSerialExecutor() + self.configuration = configuration + self.eventLoop = eventLoop + self.logger = logger + } + + private func close() async { + self.logger.trace("Close lambda runtime client") + + guard case .notClosing = self.closingState else { + return + } + await withCheckedContinuation { continuation in + self.closingState = .closing(continuation) + + switch self.connectionState { + case .disconnected: + if self.closingConnections.isEmpty { + return continuation.resume() + } + + case .connecting(let continuations): + for continuation in continuations { + continuation.resume(throwing: NewLambdaRuntimeError(code: .closingRuntimeClient)) + } + self.connectionState = .connecting([]) + + case .connected(let channel, _): + channel.close(mode: .all, promise: nil) + } + } + } + + func nextInvocation() async throws -> (Invocation, Writer) { + switch self.lambdaState { + case .idle: + self.lambdaState = .waitingForNextInvocation + let handler = try await self.makeOrGetConnection() + let invocation = try await handler.nextInvocation() + guard case .waitingForNextInvocation = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID) + return (invocation, Writer(runtimeClient: self)) + + case .waitingForNextInvocation, + .waitingForResponse, + .sendingResponse, + .sentResponse: + fatalError("Invalid state: \(self.lambdaState)") + } + + } + + private func write(_ buffer: NIOCore.ByteBuffer) async throws { + switch self.lambdaState { + case .idle, .sentResponse: + throw NewLambdaRuntimeError(code: .writeAfterFinishHasBeenSent) + + case .waitingForNextInvocation: + fatalError("Invalid state: \(self.lambdaState)") + + case .waitingForResponse(let requestID): + self.lambdaState = .sendingResponse(requestID: requestID) + fallthrough + + case .sendingResponse(let requestID): + let handler = try await self.makeOrGetConnection() + guard case .sendingResponse(requestID) = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + return try await handler.writeResponseBodyPart(buffer, requestID: requestID) + } + } + + private func writeAndFinish(_ buffer: NIOCore.ByteBuffer?) async throws { + switch self.lambdaState { + case .idle, .sentResponse: + throw NewLambdaRuntimeError(code: .finishAfterFinishHasBeenSent) + + case .waitingForNextInvocation: + fatalError("Invalid state: \(self.lambdaState)") + + case .waitingForResponse(let requestID): + fallthrough + + case .sendingResponse(let requestID): + self.lambdaState = .sentResponse(requestID: requestID) + let handler = try await self.makeOrGetConnection() + guard case .sentResponse(requestID) = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + try await handler.finishResponseRequest(finalData: buffer, requestID: requestID) + guard case .sentResponse(requestID) = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + self.lambdaState = .idle(previousRequestID: requestID) + } + } + + private func reportError(_ error: any Error) async throws { + switch self.lambdaState { + case .idle, .waitingForNextInvocation, .sentResponse: + fatalError("Invalid state: \(self.lambdaState)") + + case .waitingForResponse(let requestID): + fallthrough + + case .sendingResponse(let requestID): + self.lambdaState = .sentResponse(requestID: requestID) + let handler = try await self.makeOrGetConnection() + guard case .sentResponse(requestID) = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + try await handler.reportError(error, requestID: requestID) + guard case .sentResponse(requestID) = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + self.lambdaState = .idle(previousRequestID: requestID) + } + } + + private func channelClosed(_ channel: any Channel) { + switch (self.connectionState, self.closingState) { + case (_, .closed): + fatalError("Invalid state: \(self.connectionState), \(self.closingState)") + + case (.disconnected, .notClosing): + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } + + case (.disconnected, .closing(let continuation)): + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } + + if self.closingConnections.isEmpty { + self.closingState = .closed + continuation.resume() + } + + case (.connecting(let array), .notClosing): + self.connectionState = .disconnected + for continuation in array { + continuation.resume(throwing: NewLambdaRuntimeError(code: .lostConnectionToControlPlane)) + } + + case (.connecting(let array), .closing(let continuation)): + self.connectionState = .disconnected + precondition(array.isEmpty, "If we are closing we should have failed all connection attempts already") + if self.closingConnections.isEmpty { + self.closingState = .closed + continuation.resume() + } + + case (.connected, .notClosing): + self.connectionState = .disconnected + + case (.connected, .closing(let continuation)): + self.connectionState = .disconnected + + if self.closingConnections.isEmpty { + self.closingState = .closed + continuation.resume() + } + } + } + + private func makeOrGetConnection() async throws -> LambdaChannelHandler { + switch self.connectionState { + case .disconnected: + self.connectionState = .connecting([]) + break + case .connecting(var array): + // Since we do get sequential invocations this case normally should never be hit. + // We'll support it anyway. + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation, any Error>) in + array.append(continuation) + self.connectionState = .connecting(array) + } + case .connected(_, let handler): + return handler + } + + let bootstrap = ClientBootstrap(group: self.eventLoop) + .channelInitializer { channel in + do { + try channel.pipeline.syncOperations.addHTTPClientHandlers() + // Lambda quotas... An invocation payload is maximal 6MB in size: + // https://docs.aws.amazon.com/lambda/latest/dg/gettingstarted-limits.html + try channel.pipeline.syncOperations.addHandler( + NIOHTTPClientResponseAggregator(maxContentLength: 6 * 1024 * 1024) + ) + try channel.pipeline.syncOperations.addHandler( + LambdaChannelHandler(delegate: self, logger: self.logger) + ) + return channel.eventLoop.makeSucceededFuture(()) + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } + .connectTimeout(.seconds(2)) + + do { + // connect directly via socket address to avoid happy eyeballs (perf) + let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port) + let channel = try await bootstrap.connect(to: address).get() + let handler = try channel.pipeline.syncOperations.handler( + type: LambdaChannelHandler.self + ) + self.logger.trace( + "Connection to control plane created", + metadata: [ + "lambda_port": "\(self.configuration.port)", + "lambda_ip": "\(self.configuration.ip)", + ] + ) + channel.closeFuture.whenComplete { result in + self.assumeIsolated { runtimeClient in + runtimeClient.channelClosed(channel) + } + } + + switch self.connectionState { + case .disconnected, .connected: + fatalError("Unexpected state: \(self.connectionState)") + + case .connecting(let array): + self.connectionState = .connected(channel, handler) + defer { + for continuation in array { + continuation.resume(returning: handler) + } + } + return handler + } + } catch { + switch self.connectionState { + case .disconnected, .connected: + fatalError("Unexpected state: \(self.connectionState)") + + case .connecting(let array): + self.connectionState = .disconnected + defer { + for continuation in array { + continuation.resume(throwing: error) + } + } + throw error + } + } + } +} + +extension NewLambdaRuntimeClient: LambdaChannelHandlerDelegate { + nonisolated func connectionErrorHappened(_ error: any Error, channel: any Channel) { + + } + + nonisolated func connectionWillClose(channel: any Channel) { + self.assumeIsolated { isolated in + switch isolated.connectionState { + case .disconnected: + // this case should never happen. But whatever + if channel.isActive { + isolated.closingConnections.append(channel) + } + + case .connecting(let continuations): + // this case should never happen. But whatever + if channel.isActive { + isolated.closingConnections.append(channel) + } + + for continuation in continuations { + continuation.resume(throwing: NewLambdaRuntimeError(code: .connectionToControlPlaneLost)) + } + + case .connected(let stateChannel, _): + guard channel === stateChannel else { + isolated.closingConnections.append(channel) + return + } + + isolated.connectionState = .disconnected + + } + } + + } +} + +private protocol LambdaChannelHandlerDelegate { + func connectionWillClose(channel: any Channel) + func connectionErrorHappened(_ error: any Error, channel: any Channel) +} + +private final class LambdaChannelHandler { + let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix + + enum State { + case disconnected + case connected(ChannelHandlerContext, LambdaState) + case closing + + enum LambdaState { + /// this is the "normal" state. Transitions to `waitingForNextInvocation` + case idle + /// this is the state while we wait for an invocation. A next call is running. + /// Transitions to `waitingForResponse` + case waitingForNextInvocation(CheckedContinuation) + /// The invocation was forwarded to the handler and we wait for a response. + /// Transitions to `sendingResponse` or `sentResponse`. + case waitingForResponse + case sendingResponse + case sentResponse(CheckedContinuation) + } + } + + private var state: State = .disconnected + private var lastError: Error? + private var reusableErrorBuffer: ByteBuffer? + private let logger: Logger + private let delegate: Delegate + + init(delegate: Delegate, logger: Logger) { + self.delegate = delegate + self.logger = logger + } + + func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation { + switch self.state { + case .connected(let context, .idle): + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + self.state = .connected(context, .waitingForNextInvocation(continuation)) + self.sendNextRequest(context: context) + } + + case .connected(_, .sendingResponse), + .connected(_, .sentResponse), + .connected(_, .waitingForNextInvocation), + .connected(_, .waitingForResponse), + .closing: + fatalError("Invalid state: \(self.state)") + + case .disconnected: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneLost) + } + } + + func reportError( + isolation: isolated (any Actor)? = #isolation, + _ error: any Error, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendReportErrorRequest(requestID: requestID, error: error, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseStreamingFailure(error: error, context: context) + } + + case .connected(_, .idle), + .connected(_, .sentResponse): + // The final response has already been sent. The only way to report the unhandled error + // now is to log it. Normally this library never logs higher than debug, we make an + // exception here, as there is no other way of reporting the error otherwise. + self.logger.error( + "Unhandled error after stream has finished", + metadata: [ + "lambda_request_id": "\(requestID)", + "lambda_error": "\(String(describing: error))", + ] + ) + + case .disconnected: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func writeResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + self.state = .connected(context, .sendingResponse) + try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: requestID, context: context) + + case .connected(let context, .sendingResponse): + try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: nil, context: context) + + case .connected(_, .idle), + .connected(_, .sentResponse): + throw NewLambdaRuntimeError(code: .writeAfterFinishHasBeenSent) + + case .disconnected: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func finishResponseRequest( + isolation: isolated (any Actor)? = #isolation, + finalData: ByteBuffer?, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .idle), + .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: requestID, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: nil, context: context) + } + + case .connected(_, .sentResponse): + throw NewLambdaRuntimeError(code: .finishAfterFinishHasBeenSent) + + case .disconnected: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw NewLambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + private func sendResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext + ) async throws { + + if let requestID = sendHeadWithRequestID { + // TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length + let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix + + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: LambdaRuntimeClient.streamingHeaders + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + let future = context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer)))) + context.flush() + try await future.get() + } + + private func sendResponseFinish( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer?, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext + ) { + if let requestID = sendHeadWithRequestID { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postResponseURLSuffix)" + + // If we have less than 6MB, we don't want to use the streaming API. If we have more + // than 6MB we must use the streaming mode. + let headers: HTTPHeaders = + if byteBuffer?.readableBytes ?? 0 < 6_000_000 { + [ + "user-agent": "Swift-Lambda/Unknown", + "content-length": "\(byteBuffer?.readableBytes ?? 0)", + ] + } else { + LambdaRuntimeClient.streamingHeaders + } + + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: headers + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + if let byteBuffer { + context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer))), promise: nil) + } + + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendNextRequest(context: ChannelHandlerContext) { + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: self.nextInvocationPath, + headers: LambdaRuntimeClient.defaultHeaders + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendReportErrorRequest(requestID: String, error: any Error, context: ChannelHandlerContext) { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postErrorURLSuffix)" + + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: LambdaRuntimeClient.errorHeaders + ) + + if self.reusableErrorBuffer == nil { + self.reusableErrorBuffer = context.channel.allocator.buffer(capacity: 1024) + } else { + self.reusableErrorBuffer!.clear() + } + + let errorResponse = ErrorResponse(errorType: Consts.functionError, errorMessage: "\(error)") + // TODO: Write this directly into our ByteBuffer + let bytes = errorResponse.toJSONBytes() + self.reusableErrorBuffer!.writeBytes(bytes) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(self.reusableErrorBuffer!))), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendResponseStreamingFailure(error: any Error, context: ChannelHandlerContext) { + // TODO: Use base64 here + let trailers: HTTPHeaders = [ + "Lambda-Runtime-Function-Error-Type": "Unhandled", + "Lambda-Runtime-Function-Error-Body": "Requires base64", + ] + + context.write(self.wrapOutboundOut(.end(trailers)), promise: nil) + context.flush() + } + + func cancelCurrentRequestAndCloseConnection() { + fatalError("Unimplemented") + } +} + +extension LambdaChannelHandler: ChannelInboundHandler { + typealias OutboundIn = Never + typealias InboundIn = NIOHTTPClientResponseFull + typealias OutboundOut = HTTPClientRequestPart + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.state = .connected(context, .idle) + } + } + + func channelActive(context: ChannelHandlerContext) { + switch self.state { + case .disconnected: + self.state = .connected(context, .idle) + case .connected: + break + case .closing: + fatalError("Invalid state: \(self.state)") + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let response = unwrapInboundIn(data) + + // handle response content + + switch self.state { + case .connected(let context, .waitingForNextInvocation(let continuation)): + do { + let metadata = try InvocationMetadata(headers: response.head.headers) + self.state = .connected(context, .waitingForResponse) + continuation.resume(returning: Invocation(metadata: metadata, event: response.body ?? ByteBuffer())) + } catch { + self.state = .closing + + self.delegate.connectionWillClose(channel: context.channel) + context.close(promise: nil) + continuation.resume( + throwing: NewLambdaRuntimeError(code: .invocationMissingMetadata, underlying: error) + ) + } + + case .connected(let context, .sentResponse(let continuation)): + if response.head.status == .accepted { + self.state = .connected(context, .idle) + continuation.resume() + } else { + self.state = .connected(context, .idle) + continuation.resume(throwing: NewLambdaRuntimeError(code: .unexpectedStatusCodeForRequest)) + } + + case .disconnected, .closing, .connected(_, _): + break + } + + // As defined in RFC 7230 Section 6.3: + // HTTP/1.1 defaults to the use of "persistent connections", allowing + // multiple requests and responses to be carried over a single + // connection. The "close" connection option is used to signal that a + // connection will not persist after the current request/response. HTTP + // implementations SHOULD support persistent connections. + // + // That's why we only assume the connection shall be closed if we receive + // a "connection = close" header. + let serverCloseConnection = + response.head.headers["connection"].contains(where: { $0.lowercased() == "close" }) + + let closeConnection = serverCloseConnection || response.head.version != .http1_1 + + if closeConnection { + // If we were succeeding the request promise here directly and closing the connection + // after succeeding the promise we may run into a race condition: + // + // The lambda runtime will ask for the next work item directly after a succeeded post + // response request. The desire for the next work item might be faster than the attempt + // to close the connection. This will lead to a situation where we try to the connection + // but the next request has already been scheduled on the connection that we want to + // close. For this reason we postpone succeeding the promise until the connection has + // been closed. This codepath will only be hit in the very, very unlikely event of the + // Lambda control plane demanding to close connection. (It's more or less only + // implemented to support http1.1 correctly.) This behavior is ensured with the test + // `LambdaTest.testNoKeepAliveServer`. + self.state = .closing + self.delegate.connectionWillClose(channel: context.channel) + context.close(promise: nil) + } + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.logger.trace( + "Channel error caught", + metadata: [ + "error": "\(error)" + ] + ) + // pending responses will fail with lastError in channelInactive since we are calling context.close + self.delegate.connectionErrorHappened(error, channel: context.channel) + + self.lastError = error + context.channel.close(promise: nil) + } + + func channelInactive(context: ChannelHandlerContext) { + // fail any pending responses with last error or assume peer disconnected + + // we don't need to forward channelInactive to the delegate, as the delegate observes the + // closeFuture + context.fireChannelInactive() + } +} + +private struct RequestCancelEvent {} diff --git a/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeError.swift b/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeError.swift new file mode 100644 index 00000000..b95a5587 --- /dev/null +++ b/Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeError.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +struct NewLambdaRuntimeError: Error { + enum Code { + case closingRuntimeClient + + case connectionToControlPlaneLost + case connectionToControlPlaneGoingAway + case invocationMissingMetadata + + case writeAfterFinishHasBeenSent + case finishAfterFinishHasBeenSent + case lostConnectionToControlPlane + case unexpectedStatusCodeForRequest + + } + + var code: Code + var underlying: (any Error)? + +} diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaTest.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaTest.swift index 9379f5ee..3eb15fab 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaTest.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaTest.swift @@ -222,7 +222,7 @@ class LambdaTest: XCTestCase { cognitoIdentity: nil, clientContext: nil, logger: Logger(label: "test"), - eventLoop: MultiThreadedEventLoopGroup(numberOfThreads: 1).next(), + eventLoop: NIOSingletons.posixEventLoopGroup.next(), allocator: ByteBufferAllocator() ) XCTAssertGreaterThan(context.deadline, .now()) @@ -250,7 +250,7 @@ class LambdaTest: XCTestCase { cognitoIdentity: nil, clientContext: nil, logger: Logger(label: "test"), - eventLoop: MultiThreadedEventLoopGroup(numberOfThreads: 1).next(), + eventLoop: NIOSingletons.posixEventLoopGroup.next(), allocator: ByteBufferAllocator() ) XCTAssertLessThanOrEqual(context.getRemainingTime(), .seconds(1)) diff --git a/Tests/AWSLambdaRuntimeCoreTests/MockLambdaServer.swift b/Tests/AWSLambdaRuntimeCoreTests/MockLambdaServer.swift index a0859218..1d56da69 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/MockLambdaServer.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/MockLambdaServer.swift @@ -53,7 +53,7 @@ final class MockLambdaServer { private var shutdown = false init(behavior: LambdaServerBehavior, host: String = "127.0.0.1", port: Int = 7000, keepAlive: Bool = true) { - self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.group = NIOSingletons.posixEventLoopGroup self.behavior = behavior self.host = host self.port = port diff --git a/Tests/AWSLambdaRuntimeCoreTests/NewLambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeCoreTests/NewLambdaRuntimeClientTests.swift new file mode 100644 index 00000000..023c13a0 --- /dev/null +++ b/Tests/AWSLambdaRuntimeCoreTests/NewLambdaRuntimeClientTests.swift @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOPosix +import Testing + +import struct Foundation.UUID + +@testable import AWSLambdaRuntimeCore + +@Suite +struct NewLambdaRuntimeClientTests { + + let logger = { + var logger = Logger(label: "NewLambdaClientRuntimeTest") + logger.logLevel = .trace + return logger + }() + + @Test + func testSimpleInvocations() async throws { + struct HappyBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + #expect(self.requestId == requestId) + #expect(self.event == response) + return .success(()) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: HappyBehavior()) { port in + let configuration = NewLambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await NewLambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + do { + let (invocation, writer) = try await runtimeClient.nextInvocation() + let expected = ByteBuffer(string: "hello") + #expect(invocation.event == expected) + try await writer.writeAndFinish(expected) + } + + do { + let (invocation, writer) = try await runtimeClient.nextInvocation() + let expected = ByteBuffer(string: "hello") + #expect(invocation.event == expected) + try await writer.write(ByteBuffer(string: "h")) + try await writer.write(ByteBuffer(string: "e")) + try await writer.write(ByteBuffer(string: "l")) + try await writer.write(ByteBuffer(string: "l")) + try await writer.write(ByteBuffer(string: "o")) + try await writer.finish() + } + } + } + } +} diff --git a/Tests/AWSLambdaRuntimeCoreTests/Utils.swift b/Tests/AWSLambdaRuntimeCoreTests/Utils.swift index 8bbd4730..41a2552f 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/Utils.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/Utils.swift @@ -62,7 +62,7 @@ func runLambda( behavior: LambdaServerBehavior, handlerProvider: @escaping (LambdaInitializationContext) async throws -> Handler ) throws { - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoopGroup = NIOSingletons.posixEventLoopGroup.next() try runLambda( behavior: behavior, handlerProvider: { context in @@ -79,8 +79,7 @@ func runLambda( behavior: LambdaServerBehavior, handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture ) throws { - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoopGroup = NIOSingletons.posixEventLoopGroup.next() let logger = Logger(label: "TestLogger") let server = MockLambdaServer(behavior: behavior, port: 0) let port = try server.start().wait()