Skip to content

Commit 5a21326

Browse files
committed
chore: refactor speaker & handshaker into actors
1 parent b41d364 commit 5a21326

File tree

4 files changed

+131
-96
lines changed

4 files changed

+131
-96
lines changed

Coder Desktop/Coder Desktop.xcodeproj/project.pbxproj

+4-4
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@
660660
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop";
661661
PRODUCT_NAME = "$(TARGET_NAME)";
662662
SWIFT_EMIT_LOC_STRINGS = YES;
663-
SWIFT_VERSION = 5.0;
663+
SWIFT_VERSION = 6.0;
664664
};
665665
name = Debug;
666666
};
@@ -690,7 +690,7 @@
690690
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop";
691691
PRODUCT_NAME = "$(TARGET_NAME)";
692692
SWIFT_EMIT_LOC_STRINGS = YES;
693-
SWIFT_VERSION = 5.0;
693+
SWIFT_VERSION = 6.0;
694694
};
695695
name = Release;
696696
};
@@ -835,7 +835,7 @@
835835
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop.ProtoTests";
836836
PRODUCT_NAME = "$(TARGET_NAME)";
837837
SWIFT_EMIT_LOC_STRINGS = NO;
838-
SWIFT_VERSION = 5.0;
838+
SWIFT_VERSION = 6.0;
839839
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/Coder Desktop.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/Coder Desktop";
840840
};
841841
name = Debug;
@@ -853,7 +853,7 @@
853853
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop.ProtoTests";
854854
PRODUCT_NAME = "$(TARGET_NAME)";
855855
SWIFT_EMIT_LOC_STRINGS = NO;
856-
SWIFT_VERSION = 5.0;
856+
SWIFT_VERSION = 6.0;
857857
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/Coder Desktop.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/Coder Desktop";
858858
};
859859
name = Release;

Coder Desktop/Proto/Receiver.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ actor Receiver<RecvMsg: Message> {
2222
dispatch.read(offset: 0, length: 4, queue: queue) { done, data, error in
2323
guard error == 0 else {
2424
let errStrPtr = strerror(error)
25-
let errStr = String(validatingUTF8: errStrPtr!)!
25+
let errStr = String(validatingCString: errStrPtr!)!
2626
continuation.resume(throwing: ReceiveError.readError(errStr))
2727
return
2828
}
@@ -42,7 +42,7 @@ actor Receiver<RecvMsg: Message> {
4242
dispatch.read(offset: 0, length: Int(length), queue: queue) { done, data, error in
4343
guard error == 0 else {
4444
let errStrPtr = strerror(error)
45-
let errStr = String(validatingUTF8: errStrPtr!)!
45+
let errStr = String(validatingCString: errStrPtr!)!
4646
continuation.resume(throwing: ReceiveError.readError(errStr))
4747
return
4848
}
@@ -57,7 +57,7 @@ actor Receiver<RecvMsg: Message> {
5757

5858
/// Starts reading protocol messages from the `DispatchIO` channel and returns them as an `AsyncStream` of messages.
5959
/// On read or decoding error, it logs and closes the stream.
60-
func messages() throws -> AsyncStream<RecvMsg> {
60+
func messages() throws(ReceiveError) -> AsyncStream<RecvMsg> {
6161
if running {
6262
throw ReceiveError.alreadyRunning
6363
}

Coder Desktop/Proto/Speaker.swift

+90-41
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ let newLine = 0x0A
66
let headerPreamble = "codervpn"
77

88
/// A message that has the `rpc` property for recording participation in a unary RPC.
9-
protocol RPCMessage {
9+
protocol RPCMessage: Sendable {
1010
var rpc: Vpn_RPC { get set }
1111
/// Returns true if `rpc` has been explicitly set.
1212
var hasRpc: Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
4949
}
5050
}
5151

52-
/// An abstract base class for implementations that need to communicate using the VPN protocol.
53-
class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
52+
/// An actor that communicates using the VPN protocol
53+
actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5454
private let logger = Logger(subsystem: "com.coder.Coder-Desktop", category: "proto")
5555
private let writeFD: FileHandle
5656
private let readFD: FileHandle
@@ -59,6 +59,8 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5959
private let sender: Sender<SendMsg>
6060
private let receiver: Receiver<RecvMsg>
6161
private let secretary = RPCSecretary<RecvMsg>()
62+
private var messageBuffer: MessageBuffer = .init()
63+
private var readLoopTask: Task<Void, any Error>?
6264
let role: ProtoRole
6365

6466
/// Creates an instance that communicates over the provided file handles.
@@ -93,41 +95,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
9395
try _ = await hndsh.handshake()
9496
}
9597

96-
/// Reads and handles protocol messages.
97-
func readLoop() async throws {
98-
for try await msg in try await receiver.messages() {
99-
guard msg.hasRpc else {
100-
handleMessage(msg)
101-
continue
102-
}
103-
guard msg.rpc.msgID == 0 else {
104-
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender)
105-
handleRPC(req)
106-
continue
107-
}
108-
guard msg.rpc.responseTo == 0 else {
109-
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
110-
do throws(RPCError) {
111-
try await self.secretary.route(reply: msg)
112-
} catch {
113-
logger.error(
114-
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
98+
public func start() {
99+
guard readLoopTask == nil else {
100+
logger.error("speaker is already running")
101+
return
102+
}
103+
readLoopTask = Task {
104+
do throws(ReceiveError) {
105+
for try await msg in try await self.receiver.messages() {
106+
guard msg.hasRpc else {
107+
await messageBuffer.push(.message(msg))
108+
continue
109+
}
110+
guard msg.rpc.msgID == 0 else {
111+
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: self.sender)
112+
await messageBuffer.push(.RPC(req))
113+
continue
114+
}
115+
guard msg.rpc.responseTo == 0 else {
116+
self.logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
117+
do throws(RPCError) {
118+
try await self.secretary.route(reply: msg)
119+
} catch {
120+
self.logger.error(
121+
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
122+
}
123+
continue
124+
}
115125
}
116-
continue
126+
} catch {
127+
self.logger.error("failed to receive messages: \(error)")
117128
}
118129
}
119130
}
120131

121-
/// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122-
func handleMessage(_ msg: RecvMsg) {
123-
// just log
124-
logger.debug("got non-RPC message \(msg.textFormatString())")
125-
}
126-
127-
/// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128-
func handleRPC(_ req: RPCRequest<SendMsg, RecvMsg>) {
129-
// just log
130-
logger.debug("got RPC message \(req.msg.textFormatString())")
132+
func wait() async throws {
133+
guard let task = readLoopTask else {
134+
return
135+
}
136+
try await task.value
131137
}
132138

133139
/// Send a unary RPC message and handle the response
@@ -166,10 +172,51 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166172
logger.error("failed to close read file handle: \(error)")
167173
}
168174
}
175+
176+
enum IncomingMessage {
177+
case message(RecvMsg)
178+
case RPC(RPCRequest<SendMsg, RecvMsg>)
179+
}
180+
181+
private actor MessageBuffer {
182+
private var messages: [IncomingMessage] = []
183+
private var continuations: [CheckedContinuation<IncomingMessage?, Never>] = []
184+
185+
func push(_ message: IncomingMessage?) {
186+
if let continuation = continuations.first {
187+
continuations.removeFirst()
188+
continuation.resume(returning: message)
189+
} else if let message = message {
190+
messages.append(message)
191+
}
192+
}
193+
194+
func next() async -> IncomingMessage? {
195+
if let message = messages.first {
196+
messages.removeFirst()
197+
return message
198+
}
199+
return await withCheckedContinuation { continuation in
200+
continuations.append(continuation)
201+
}
202+
}
203+
}
169204
}
170205

171-
/// A class that performs the initial VPN protocol handshake and version negotiation.
172-
class Handshaker {
206+
extension Speaker: AsyncSequence, AsyncIteratorProtocol {
207+
typealias Element = IncomingMessage
208+
209+
public nonisolated func makeAsyncIterator() -> Speaker<SendMsg, RecvMsg> {
210+
self
211+
}
212+
213+
func next() async throws -> IncomingMessage? {
214+
return await messageBuffer.next()
215+
}
216+
}
217+
218+
/// An actor performs the initial VPN protocol handshake and version negotiation.
219+
actor Handshaker {
173220
private let writeFD: FileHandle
174221
private let dispatch: DispatchIO
175222
private var theirData: Data = .init()
@@ -193,17 +240,19 @@ class Handshaker {
193240
func handshake() async throws -> ProtoVersion {
194241
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
195242
// waiting to write with nobody reading.
196-
async let theirs = try withCheckedThrowingContinuation { cont in
197-
continuation = cont
198-
// send in a nil read to kick us off
199-
handleRead(false, nil, 0)
243+
let readTask = Task {
244+
try await withCheckedThrowingContinuation { cont in
245+
self.continuation = cont
246+
// send in a nil read to kick us off
247+
self.handleRead(false, nil, 0)
248+
}
200249
}
201250

202251
let vStr = versions.map { $0.description }.joined(separator: ",")
203252
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n")
204253
try writeFD.write(contentsOf: ours.data(using: .utf8)!)
205254

206-
let theirData = try await theirs
255+
let theirData = try await readTask.value
207256
guard let theirsString = String(bytes: theirData, encoding: .utf8) else {
208257
throw HandshakeError.invalidHeader("<unparsable: \(theirData)")
209258
}
@@ -219,7 +268,7 @@ class Handshaker {
219268
private func handleRead(_: Bool, _ data: DispatchData?, _ error: Int32) {
220269
guard error == 0 else {
221270
let errStrPtr = strerror(error)
222-
let errStr = String(validatingUTF8: errStrPtr!)!
271+
let errStr = String(validatingCString: errStrPtr!)!
223272
continuation?.resume(throwing: HandshakeError.readError(errStr))
224273
return
225274
}

Coder Desktop/ProtoTests/SpeakerTests.swift

+34-48
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,19 @@
22
import Foundation
33
import Testing
44

5-
/// A concrete, test class for the abstract Speaker, which overrides the handlers to send things to
6-
/// continuations we set in the test.
7-
class TestTunnel: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage> {
8-
var msgHandler: CheckedContinuation<Vpn_ManagerMessage, Error>?
9-
override func handleMessage(_ msg: Vpn_ManagerMessage) {
10-
msgHandler?.resume(returning: msg)
11-
}
12-
13-
var rpcHandler: CheckedContinuation<RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>, Error>?
14-
override func handleRPC(_ req: RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>) {
15-
rpcHandler?.resume(returning: req)
16-
}
17-
}
18-
195
@Suite(.timeLimit(.minutes(1)))
206
struct SpeakerTests {
217
let pipeMT = Pipe()
228
let pipeTM = Pipe()
23-
let uut: TestTunnel
9+
let uut: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>
2410
let sender: Sender<Vpn_ManagerMessage>
2511
let dispatch: DispatchIO
2612
let receiver: Receiver<Vpn_TunnelMessage>
2713
let handshaker: Handshaker
2814

2915
init() {
3016
let queue = DispatchQueue.global(qos: .utility)
31-
uut = TestTunnel(
17+
uut = Speaker(
3218
writeFD: pipeTM.fileHandleForWriting,
3319
readFD: pipeMT.fileHandleForReading
3420
)
@@ -54,45 +40,45 @@ struct SpeakerTests {
5440
}
5541

5642
@Test func handleSingleMessage() async throws {
57-
async let readDone: () = try uut.readLoop()
43+
await uut.start()
5844

59-
let got = try await withCheckedThrowingContinuation { continuation in
60-
uut.msgHandler = continuation
61-
Task {
62-
var s = Vpn_ManagerMessage()
63-
s.start = Vpn_StartRequest()
64-
await #expect(throws: Never.self) {
65-
try await sender.send(s)
66-
}
67-
}
45+
var s = Vpn_ManagerMessage()
46+
s.start = Vpn_StartRequest()
47+
await #expect(throws: Never.self) {
48+
try await sender.send(s)
49+
}
50+
let got = try #require(await uut.next())
51+
guard case let .message(msg) = got else {
52+
Issue.record("Received unexpected message from speaker")
53+
return
6854
}
69-
#expect(got.msg == .start(Vpn_StartRequest()))
55+
#expect(msg.msg == .start(Vpn_StartRequest()))
7056
try await sender.close()
71-
try await readDone
57+
try await uut.wait()
7258
}
7359

7460
@Test func handleRPC() async throws {
75-
async let readDone: () = try uut.readLoop()
61+
await uut.start()
7662

77-
let got = try await withCheckedThrowingContinuation { continuation in
78-
uut.rpcHandler = continuation
79-
Task {
80-
var s = Vpn_ManagerMessage()
81-
s.start = Vpn_StartRequest()
82-
s.rpc = Vpn_RPC()
83-
s.rpc.msgID = 33
84-
await #expect(throws: Never.self) {
85-
try await sender.send(s)
86-
}
87-
}
63+
var s = Vpn_ManagerMessage()
64+
s.start = Vpn_StartRequest()
65+
s.rpc = Vpn_RPC()
66+
s.rpc.msgID = 33
67+
await #expect(throws: Never.self) {
68+
try await sender.send(s)
69+
}
70+
let got = try #require(await uut.next())
71+
guard case let .RPC(req) = got else {
72+
Issue.record("Received unexpected message from speaker")
73+
return
8874
}
89-
#expect(got.msg.msg == .start(Vpn_StartRequest()))
90-
#expect(got.msg.rpc.msgID == 33)
75+
#expect(req.msg.msg == .start(Vpn_StartRequest()))
76+
#expect(req.msg.rpc.msgID == 33)
9177
var reply = Vpn_TunnelMessage()
9278
reply.start = Vpn_StartResponse()
9379
reply.rpc.responseTo = 33
94-
try await got.sendReply(reply)
95-
uut.closeWrite()
80+
try await req.sendReply(reply)
81+
await uut.closeWrite()
9682

9783
var count = 0
9884
await #expect(throws: Never.self) {
@@ -103,11 +89,11 @@ struct SpeakerTests {
10389
#expect(count == 1)
10490
}
10591
try await sender.close()
106-
try await readDone
92+
try await uut.wait()
10793
}
10894

10995
@Test func sendRPCs() async throws {
110-
async let readDone: () = try uut.readLoop()
96+
await uut.start()
11197

11298
async let managerDone = Task {
11399
var count = 0
@@ -129,9 +115,9 @@ struct SpeakerTests {
129115
let got = try await uut.unaryRPC(req)
130116
#expect(got.networkSettings.errorMessage == "test \(i)")
131117
}
132-
uut.closeWrite()
118+
await uut.closeWrite()
133119
_ = await managerDone
134120
try await sender.close()
135-
try await readDone
121+
try await uut.wait()
136122
}
137123
}

0 commit comments

Comments
 (0)