Skip to content

Commit 6081b3c

Browse files
committed
Fixing soundness
1 parent 98df9f9 commit 6081b3c

File tree

3 files changed

+108
-87
lines changed

3 files changed

+108
-87
lines changed

Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeClient.swift

Lines changed: 87 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
import Logging
1516
import NIOCore
1617
import NIOHTTP1
1718
import NIOPosix
18-
import Logging
1919
import _NIOBase64
2020

2121
final actor NewLambdaRuntimeClient: LambdaRuntimeClientProtocol {
@@ -36,15 +36,15 @@ final actor NewLambdaRuntimeClient: LambdaRuntimeClientProtocol {
3636
func write(_ buffer: NIOCore.ByteBuffer) async throws {
3737
try await self.runtimeClient.write(buffer)
3838
}
39-
39+
4040
func finish() async throws {
4141
try await self.runtimeClient.finish()
4242
}
43-
43+
4444
func writeAndFinish(_ buffer: NIOCore.ByteBuffer) async throws {
4545
try await self.runtimeClient.writeAndFinish(buffer)
4646
}
47-
47+
4848
func reportError(_ error: any Error) async throws {
4949
try await self.runtimeClient.reportError(error)
5050
}
@@ -120,7 +120,8 @@ final actor NewLambdaRuntimeClient: LambdaRuntimeClientProtocol {
120120
case .connecting(var array):
121121
// Since we do get sequential invocations this case normally should never be hit.
122122
// We'll support it anyway.
123-
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<LambdaChannelHandler, any Error>) in
123+
return try await withCheckedThrowingContinuation {
124+
(continuation: CheckedContinuation<LambdaChannelHandler, any Error>) in
124125
array.append(continuation)
125126
self.connectionState = .connecting(array)
126127
}
@@ -211,16 +212,17 @@ private final class LambdaChannelHandler {
211212
func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation {
212213
switch self.state {
213214
case .connected(let context, .idle):
214-
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Invocation, any Error>) in
215+
return try await withCheckedThrowingContinuation {
216+
(continuation: CheckedContinuation<Invocation, any Error>) in
215217
self.state = .connected(context, .waitingForNextInvocation(continuation))
216218
self.sendNextRequest(context: context)
217219
}
218220

219221
case .connected(_, .closing),
220-
.connected(_, .sendingResponse),
221-
.connected(_, .sentResponse),
222-
.connected(_, .waitingForNextInvocation),
223-
.connected(_, .waitingForResponse):
222+
.connected(_, .sendingResponse),
223+
.connected(_, .sentResponse),
224+
.connected(_, .waitingForNextInvocation),
225+
.connected(_, .waitingForResponse):
224226
fatalError()
225227

226228
case .disconnected:
@@ -232,7 +234,7 @@ private final class LambdaChannelHandler {
232234
func reportError(isolation: isolated (any Actor)? = #isolation, _ error: any Error) async throws {
233235
switch self.state {
234236
case .connected(_, .idle(.none)),
235-
.connected(_, .waitingForNextInvocation):
237+
.connected(_, .waitingForNextInvocation):
236238
fatalError("Invalid state: \(self.state)")
237239

238240
case .connected(let context, .waitingForResponse(let requestID)):
@@ -248,14 +250,17 @@ private final class LambdaChannelHandler {
248250
}
249251

250252
case .connected(_, .idle(previousRequestID: .some(let requestID))),
251-
.connected(_, .sentResponse(let requestID, _)):
253+
.connected(_, .sentResponse(let requestID, _)):
252254
// The final response has already been sent. The only way to report the unhandled error
253255
// now is to log it. Normally this library never logs higher than debug, we make an
254256
// exception here, as there is no other way of reporting the error otherwise.
255-
self.logger.error("Unhandled error after stream has finished", metadata: [
256-
"lambda_request_id": "\(requestID)",
257-
"lambda_error": "\(String(describing: error))"
258-
])
257+
self.logger.error(
258+
"Unhandled error after stream has finished",
259+
metadata: [
260+
"lambda_request_id": "\(requestID)",
261+
"lambda_error": "\(String(describing: error))",
262+
]
263+
)
259264

260265
case .disconnected, .connected(_, .closing):
261266
// TODO: throw error here
@@ -266,7 +271,7 @@ private final class LambdaChannelHandler {
266271
func writeResponseBodyPart(isolation: isolated (any Actor)? = #isolation, _ byteBuffer: ByteBuffer) async throws {
267272
switch self.state {
268273
case .connected(_, .idle(.none)),
269-
.connected(_, .waitingForNextInvocation):
274+
.connected(_, .waitingForNextInvocation):
270275
fatalError("Invalid state: \(self.state)")
271276

272277
case .connected(let context, .waitingForResponse(let requestID)):
@@ -277,11 +282,10 @@ private final class LambdaChannelHandler {
277282
try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: nil, context: context)
278283

279284
case .connected(_, .idle(previousRequestID: .some(let requestID))),
280-
.connected(_, .sentResponse(let requestID, _)):
285+
.connected(_, .sentResponse(let requestID, _)):
281286
// TODO: throw error here – user tries to write after the stream has been finished
282287
fatalError()
283288

284-
285289
case .disconnected, .connected(_, .closing):
286290
// TODO: throw error here
287291
fatalError()
@@ -291,7 +295,7 @@ private final class LambdaChannelHandler {
291295
func finishResponseRequest(isolation: isolated (any Actor)? = #isolation, finalData: ByteBuffer?) async throws {
292296
switch self.state {
293297
case .connected(_, .idle(.none)),
294-
.connected(_, .waitingForNextInvocation):
298+
.connected(_, .waitingForNextInvocation):
295299
fatalError("Invalid state: \(self.state)")
296300

297301
case .connected(let context, .waitingForResponse(let requestID)):
@@ -307,11 +311,10 @@ private final class LambdaChannelHandler {
307311
}
308312

309313
case .connected(_, .idle(previousRequestID: .some(let requestID))),
310-
.connected(_, .sentResponse(let requestID, _)):
314+
.connected(_, .sentResponse(let requestID, _)):
311315
// TODO: throw error here – user tries to write after the stream has been finished
312316
fatalError()
313317

314-
315318
case .disconnected, .connected(_, .closing):
316319
// TODO: throw error here
317320
fatalError()
@@ -355,14 +358,15 @@ private final class LambdaChannelHandler {
355358
// TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length
356359
let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix
357360

358-
let headers: HTTPHeaders = if byteBuffer?.readableBytes ?? 0 < 6_000_000 {
359-
[
360-
"user-agent": "Swift-Lambda/Unknown",
361-
"content-length": "\(byteBuffer?.readableBytes ?? 0)",
362-
]
363-
} else {
364-
LambdaRuntimeClient.streamingHeaders
365-
}
361+
let headers: HTTPHeaders =
362+
if byteBuffer?.readableBytes ?? 0 < 6_000_000 {
363+
[
364+
"user-agent": "Swift-Lambda/Unknown",
365+
"content-length": "\(byteBuffer?.readableBytes ?? 0)",
366+
]
367+
} else {
368+
LambdaRuntimeClient.streamingHeaders
369+
}
366370

367371
let httpRequest = HTTPRequestHead(
368372
version: .http1_1,
@@ -383,7 +387,12 @@ private final class LambdaChannelHandler {
383387
}
384388

385389
private func sendNextRequest(context: ChannelHandlerContext) {
386-
let httpRequest = HTTPRequestHead(version: .http1_1, method: .GET, uri: self.nextInvocationPath, headers: LambdaRuntimeClient.defaultHeaders)
390+
let httpRequest = HTTPRequestHead(
391+
version: .http1_1,
392+
method: .GET,
393+
uri: self.nextInvocationPath,
394+
headers: LambdaRuntimeClient.defaultHeaders
395+
)
387396

388397
context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil)
389398
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
@@ -478,40 +487,40 @@ extension LambdaChannelHandler: ChannelInboundHandler {
478487
break
479488
}
480489

481-
// // As defined in RFC 7230 Section 6.3:
482-
// // HTTP/1.1 defaults to the use of "persistent connections", allowing
483-
// // multiple requests and responses to be carried over a single
484-
// // connection. The "close" connection option is used to signal that a
485-
// // connection will not persist after the current request/response. HTTP
486-
// // implementations SHOULD support persistent connections.
487-
// //
488-
// // That's why we only assume the connection shall be closed if we receive
489-
// // a "connection = close" header.
490-
// let serverCloseConnection =
491-
// response.head.headers["connection"].contains(where: { $0.lowercased() == "close" })
492-
//
493-
// let closeConnection = serverCloseConnection || response.head.version != .http1_1
494-
//
495-
// if closeConnection {
496-
// // If we were succeeding the request promise here directly and closing the connection
497-
// // after succeeding the promise we may run into a race condition:
498-
// //
499-
// // The lambda runtime will ask for the next work item directly after a succeeded post
500-
// // response request. The desire for the next work item might be faster than the attempt
501-
// // to close the connection. This will lead to a situation where we try to the connection
502-
// // but the next request has already been scheduled on the connection that we want to
503-
// // close. For this reason we postpone succeeding the promise until the connection has
504-
// // been closed. This codepath will only be hit in the very, very unlikely event of the
505-
// // Lambda control plane demanding to close connection. (It's more or less only
506-
// // implemented to support http1.1 correctly.) This behavior is ensured with the test
507-
// // `LambdaTest.testNoKeepAliveServer`.
508-
// self.state = .waitForConnectionClose(httpResponse, promise)
509-
// _ = context.channel.close()
510-
// return
511-
// } else {
512-
// self.state = .idle
513-
// promise.succeed(httpResponse)
514-
// }
490+
// // As defined in RFC 7230 Section 6.3:
491+
// // HTTP/1.1 defaults to the use of "persistent connections", allowing
492+
// // multiple requests and responses to be carried over a single
493+
// // connection. The "close" connection option is used to signal that a
494+
// // connection will not persist after the current request/response. HTTP
495+
// // implementations SHOULD support persistent connections.
496+
// //
497+
// // That's why we only assume the connection shall be closed if we receive
498+
// // a "connection = close" header.
499+
// let serverCloseConnection =
500+
// response.head.headers["connection"].contains(where: { $0.lowercased() == "close" })
501+
//
502+
// let closeConnection = serverCloseConnection || response.head.version != .http1_1
503+
//
504+
// if closeConnection {
505+
// // If we were succeeding the request promise here directly and closing the connection
506+
// // after succeeding the promise we may run into a race condition:
507+
// //
508+
// // The lambda runtime will ask for the next work item directly after a succeeded post
509+
// // response request. The desire for the next work item might be faster than the attempt
510+
// // to close the connection. This will lead to a situation where we try to the connection
511+
// // but the next request has already been scheduled on the connection that we want to
512+
// // close. For this reason we postpone succeeding the promise until the connection has
513+
// // been closed. This codepath will only be hit in the very, very unlikely event of the
514+
// // Lambda control plane demanding to close connection. (It's more or less only
515+
// // implemented to support http1.1 correctly.) This behavior is ensured with the test
516+
// // `LambdaTest.testNoKeepAliveServer`.
517+
// self.state = .waitForConnectionClose(httpResponse, promise)
518+
// _ = context.channel.close()
519+
// return
520+
// } else {
521+
// self.state = .idle
522+
// promise.succeed(httpResponse)
523+
// }
515524
}
516525

517526
func errorCaught(context: ChannelHandlerContext, error: Error) {
@@ -524,19 +533,19 @@ extension LambdaChannelHandler: ChannelInboundHandler {
524533
// fail any pending responses with last error or assume peer disconnected
525534
context.fireChannelInactive()
526535

527-
// switch self.state {
528-
// case .idle:
529-
// break
530-
//
531-
// case .running(let promise, let timeout):
532-
// self.state = .idle
533-
// timeout?.cancel()
534-
// promise.fail(self.lastError ?? HTTPClient.Errors.connectionResetByPeer)
535-
//
536-
// case .waitForConnectionClose(let response, let promise):
537-
// self.state = .idle
538-
// promise.succeed(response)
539-
// }
536+
// switch self.state {
537+
// case .idle:
538+
// break
539+
//
540+
// case .running(let promise, let timeout):
541+
// self.state = .idle
542+
// timeout?.cancel()
543+
// promise.fail(self.lastError ?? HTTPClient.Errors.connectionResetByPeer)
544+
//
545+
// case .waitForConnectionClose(let response, let promise):
546+
// self.state = .idle
547+
// promise.succeed(response)
548+
// }
540549
}
541550
}
542551

Sources/AWSLambdaRuntimeCore/NewLambdaRuntimeError.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ struct NewLambdaRuntimeError: Error {
1818
case finishAfterFinishHasBeenSent
1919
case lostConnectionToControlPlane
2020
case unexpectedStatusCodeForRequest
21-
21+
2222
}
2323

2424
var code: Code
2525

26-
2726
}

Tests/AWSLambdaRuntimeCoreTests/NewLambdaRuntimeClientTests.swift

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
//===----------------------------------------------------------------------===//
12
//
2-
// NewLambdaRuntimeClientTests.swift
3-
// swift-aws-lambda-runtime
3+
// This source file is part of the SwiftAWSLambdaRuntime open source project
44
//
5-
// Created by Fabian Fett on 28.08.24.
5+
// Copyright (c) 2024 Apple Inc. and the SwiftAWSLambdaRuntime project authors
6+
// Licensed under Apache License v2.0
67
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
714

8-
import Testing
15+
import Logging
916
import NIOCore
1017
import NIOPosix
11-
import Logging
18+
import Testing
19+
1220
import struct Foundation.UUID
21+
1322
@testable import AWSLambdaRuntimeCore
1423

1524
@Suite
@@ -52,7 +61,8 @@ struct NewLambdaRuntimeClientTests {
5261
let configuration = NewLambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: 7000)
5362

5463
try await NewLambdaRuntimeClient.withRuntimeClient(
55-
configuration: configuration, eventLoop: eventLoopGroup.next(),
64+
configuration: configuration,
65+
eventLoop: eventLoopGroup.next(),
5666
logger: self.logger
5767
) { runtimeClient in
5868
do {
@@ -77,7 +87,10 @@ struct NewLambdaRuntimeClientTests {
7787
}
7888
}
7989

80-
func withMockServer<Result>(behaviour: some LambdaServerBehavior, _ body: (MockLambdaServer, MultiThreadedEventLoopGroup) async throws -> Result) async throws -> Result {
90+
func withMockServer<Result>(
91+
behaviour: some LambdaServerBehavior,
92+
_ body: (MockLambdaServer, MultiThreadedEventLoopGroup) async throws -> Result
93+
) async throws -> Result {
8194
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
8295
let server = MockLambdaServer(behavior: behaviour)
8396
_ = try await server.start().get()

0 commit comments

Comments
 (0)