Skip to content

Commit fbc4a84

Browse files
committed
chore: refactor speaker & handshaker into actors
1 parent ae65c20 commit fbc4a84

File tree

3 files changed

+124
-108
lines changed

3 files changed

+124
-108
lines changed

Coder Desktop/Proto/Receiver.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ actor Receiver<RecvMsg: Message> {
5757

5858
/// Starts reading protocol messages from the `DispatchIO` channel and returns them as an `AsyncStream` of messages.
5959
/// On read or decoding error, it logs and closes the stream.
60-
func messages() throws -> AsyncStream<RecvMsg> {
60+
func messages() throws(ReceiveError) -> AsyncStream<RecvMsg> {
6161
if running {
6262
throw ReceiveError.alreadyRunning
6363
}

Coder Desktop/Proto/Speaker.swift

+89-40
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ let newLine = 0x0A
66
let headerPreamble = "codervpn"
77

88
/// A message that has the `rpc` property for recording participation in a unary RPC.
9-
protocol RPCMessage {
9+
protocol RPCMessage: Sendable {
1010
var rpc: Vpn_RPC { get set }
1111
/// Returns true if `rpc` has been explicitly set.
1212
var hasRpc: Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
4949
}
5050
}
5151

52-
/// An abstract base class for implementations that need to communicate using the VPN protocol.
53-
class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
52+
/// An actor that communicates using the VPN protocol
53+
actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5454
private let logger = Logger(subsystem: "com.coder.Coder-Desktop", category: "proto")
5555
private let writeFD: FileHandle
5656
private let readFD: FileHandle
@@ -59,6 +59,8 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5959
private let sender: Sender<SendMsg>
6060
private let receiver: Receiver<RecvMsg>
6161
private let secretary = RPCSecretary<RecvMsg>()
62+
private var messageBuffer: MessageBuffer = .init()
63+
private var readLoopTask: Task<Void, any Error>?
6264
let role: ProtoRole
6365

6466
/// Creates an instance that communicates over the provided file handles.
@@ -93,41 +95,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
9395
try _ = await hndsh.handshake()
9496
}
9597

96-
/// Reads and handles protocol messages.
97-
func readLoop() async throws {
98-
for try await msg in try await receiver.messages() {
99-
guard msg.hasRpc else {
100-
handleMessage(msg)
101-
continue
102-
}
103-
guard msg.rpc.msgID == 0 else {
104-
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender)
105-
handleRPC(req)
106-
continue
107-
}
108-
guard msg.rpc.responseTo == 0 else {
109-
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
110-
do throws(RPCError) {
111-
try await self.secretary.route(reply: msg)
112-
} catch {
113-
logger.error(
114-
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
98+
public func start() {
99+
guard readLoopTask == nil else {
100+
logger.error("speaker is already running")
101+
return
102+
}
103+
readLoopTask = Task {
104+
do throws(ReceiveError) {
105+
for try await msg in try await self.receiver.messages() {
106+
guard msg.hasRpc else {
107+
await messageBuffer.push(.message(msg))
108+
continue
109+
}
110+
guard msg.rpc.msgID == 0 else {
111+
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: self.sender)
112+
await messageBuffer.push(.RPC(req))
113+
continue
114+
}
115+
guard msg.rpc.responseTo == 0 else {
116+
self.logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
117+
do throws(RPCError) {
118+
try await self.secretary.route(reply: msg)
119+
} catch {
120+
self.logger.error(
121+
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
122+
}
123+
continue
124+
}
115125
}
116-
continue
126+
} catch {
127+
self.logger.error("failed to receive messages: \(error)")
117128
}
118129
}
119130
}
120131

121-
/// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122-
func handleMessage(_ msg: RecvMsg) {
123-
// just log
124-
logger.debug("got non-RPC message \(msg.textFormatString())")
125-
}
126-
127-
/// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128-
func handleRPC(_ req: RPCRequest<SendMsg, RecvMsg>) {
129-
// just log
130-
logger.debug("got RPC message \(req.msg.textFormatString())")
132+
func wait() async throws {
133+
guard let task = readLoopTask else {
134+
return
135+
}
136+
try await task.value
131137
}
132138

133139
/// Send a unary RPC message and handle the response
@@ -166,10 +172,51 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166172
logger.error("failed to close read file handle: \(error)")
167173
}
168174
}
175+
176+
enum IncomingMessage {
177+
case message(RecvMsg)
178+
case RPC(RPCRequest<SendMsg, RecvMsg>)
179+
}
180+
181+
private actor MessageBuffer {
182+
private var messages: [IncomingMessage] = []
183+
private var continuations: [CheckedContinuation<IncomingMessage?, Never>] = []
184+
185+
func push(_ message: IncomingMessage?) {
186+
if let continuation = continuations.first {
187+
continuations.removeFirst()
188+
continuation.resume(returning: message)
189+
} else if let message = message {
190+
messages.append(message)
191+
}
192+
}
193+
194+
func next() async -> IncomingMessage? {
195+
if let message = messages.first {
196+
messages.removeFirst()
197+
return message
198+
}
199+
return await withCheckedContinuation { continuation in
200+
continuations.append(continuation)
201+
}
202+
}
203+
}
169204
}
170205

171-
/// A class that performs the initial VPN protocol handshake and version negotiation.
172-
class Handshaker: @unchecked Sendable {
206+
extension Speaker: AsyncSequence, AsyncIteratorProtocol {
207+
typealias Element = IncomingMessage
208+
209+
public nonisolated func makeAsyncIterator() -> Speaker<SendMsg, RecvMsg> {
210+
self
211+
}
212+
213+
func next() async throws -> IncomingMessage? {
214+
return await messageBuffer.next()
215+
}
216+
}
217+
218+
/// An actor performs the initial VPN protocol handshake and version negotiation.
219+
actor Handshaker {
173220
private let writeFD: FileHandle
174221
private let dispatch: DispatchIO
175222
private var theirData: Data = .init()
@@ -193,17 +240,19 @@ class Handshaker: @unchecked Sendable {
193240
func handshake() async throws -> ProtoVersion {
194241
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
195242
// waiting to write with nobody reading.
196-
async let theirs = try withCheckedThrowingContinuation { cont in
197-
continuation = cont
198-
// send in a nil read to kick us off
199-
handleRead(false, nil, 0)
243+
let readTask = Task {
244+
try await withCheckedThrowingContinuation { cont in
245+
self.continuation = cont
246+
// send in a nil read to kick us off
247+
self.handleRead(false, nil, 0)
248+
}
200249
}
201250

202251
let vStr = versions.map { $0.description }.joined(separator: ",")
203252
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n")
204253
try writeFD.write(contentsOf: ours.data(using: .utf8)!)
205254

206-
let theirData = try await theirs
255+
let theirData = try await readTask.value
207256
guard let theirsString = String(bytes: theirData, encoding: .utf8) else {
208257
throw HandshakeError.invalidHeader("<unparsable: \(theirData)")
209258
}

Coder Desktop/ProtoTests/SpeakerTests.swift

+34-67
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,19 @@
22
import Foundation
33
import Testing
44

5-
/// A concrete, test class for the abstract Speaker, which overrides the handlers to send things to
6-
/// continuations we set in the test.
7-
class TestTunnel: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>, @unchecked Sendable {
8-
private var msgHandler: CheckedContinuation<Vpn_ManagerMessage, Error>?
9-
override func handleMessage(_ msg: Vpn_ManagerMessage) {
10-
msgHandler?.resume(returning: msg)
11-
}
12-
13-
/// Runs the given closure asynchronously and returns the next non-RPC message received.
14-
func expectMessage(with closure:
15-
@escaping @Sendable () async -> Void) async throws -> Vpn_ManagerMessage
16-
{
17-
return try await withCheckedThrowingContinuation { continuation in
18-
msgHandler = continuation
19-
Task {
20-
await closure()
21-
}
22-
}
23-
}
24-
25-
private var rpcHandler: CheckedContinuation<RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>, Error>?
26-
override func handleRPC(_ req: RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>) {
27-
rpcHandler?.resume(returning: req)
28-
}
29-
30-
/// Runs the given closure asynchronously and return the next non-RPC message received
31-
func expectRPC(with closure:
32-
@escaping @Sendable () async -> Void) async throws ->
33-
RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>
34-
{
35-
return try await withCheckedThrowingContinuation { continuation in
36-
rpcHandler = continuation
37-
Task {
38-
await closure()
39-
}
40-
}
41-
}
42-
}
43-
445
@Suite(.timeLimit(.minutes(1)))
456
struct SpeakerTests: Sendable {
467
let pipeMT = Pipe()
478
let pipeTM = Pipe()
48-
let uut: TestTunnel
9+
let uut: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>
4910
let sender: Sender<Vpn_ManagerMessage>
5011
let dispatch: DispatchIO
5112
let receiver: Receiver<Vpn_TunnelMessage>
5213
let handshaker: Handshaker
5314

5415
init() {
5516
let queue = DispatchQueue.global(qos: .utility)
56-
uut = TestTunnel(
17+
uut = Speaker(
5718
writeFD: pipeTM.fileHandleForWriting,
5819
readFD: pipeMT.fileHandleForReading
5920
)
@@ -79,39 +40,45 @@ struct SpeakerTests: Sendable {
7940
}
8041

8142
@Test func handleSingleMessage() async throws {
82-
async let readDone: () = try uut.readLoop()
43+
await uut.start()
8344

84-
let got = try await uut.expectMessage {
85-
var s = Vpn_ManagerMessage()
86-
s.start = Vpn_StartRequest()
87-
await #expect(throws: Never.self) {
88-
try await sender.send(s)
89-
}
45+
var s = Vpn_ManagerMessage()
46+
s.start = Vpn_StartRequest()
47+
await #expect(throws: Never.self) {
48+
try await sender.send(s)
9049
}
91-
#expect(got.msg == .start(Vpn_StartRequest()))
50+
let got = try #require(await uut.next())
51+
guard case let .message(msg) = got else {
52+
Issue.record("Received unexpected message from speaker")
53+
return
54+
}
55+
#expect(msg.msg == .start(Vpn_StartRequest()))
9256
try await sender.close()
93-
try await readDone
57+
try await uut.wait()
9458
}
9559

9660
@Test func handleRPC() async throws {
97-
async let readDone: () = try uut.readLoop()
61+
await uut.start()
9862

99-
let got = try await uut.expectRPC {
100-
var s = Vpn_ManagerMessage()
101-
s.start = Vpn_StartRequest()
102-
s.rpc = Vpn_RPC()
103-
s.rpc.msgID = 33
104-
await #expect(throws: Never.self) {
105-
try await sender.send(s)
106-
}
63+
var s = Vpn_ManagerMessage()
64+
s.start = Vpn_StartRequest()
65+
s.rpc = Vpn_RPC()
66+
s.rpc.msgID = 33
67+
await #expect(throws: Never.self) {
68+
try await sender.send(s)
69+
}
70+
let got = try #require(await uut.next())
71+
guard case let .RPC(req) = got else {
72+
Issue.record("Received unexpected message from speaker")
73+
return
10774
}
108-
#expect(got.msg.msg == .start(Vpn_StartRequest()))
109-
#expect(got.msg.rpc.msgID == 33)
75+
#expect(req.msg.msg == .start(Vpn_StartRequest()))
76+
#expect(req.msg.rpc.msgID == 33)
11077
var reply = Vpn_TunnelMessage()
11178
reply.start = Vpn_StartResponse()
11279
reply.rpc.responseTo = 33
113-
try await got.sendReply(reply)
114-
uut.closeWrite()
80+
try await req.sendReply(reply)
81+
await uut.closeWrite()
11582

11683
var count = 0
11784
await #expect(throws: Never.self) {
@@ -122,11 +89,11 @@ struct SpeakerTests: Sendable {
12289
#expect(count == 1)
12390
}
12491
try await sender.close()
125-
try await readDone
92+
try await uut.wait()
12693
}
12794

12895
@Test func sendRPCs() async throws {
129-
async let readDone: () = try uut.readLoop()
96+
await uut.start()
13097

13198
async let managerDone = Task {
13299
var count = 0
@@ -148,9 +115,9 @@ struct SpeakerTests: Sendable {
148115
let got = try await uut.unaryRPC(req)
149116
#expect(got.networkSettings.errorMessage == "test \(i)")
150117
}
151-
uut.closeWrite()
118+
await uut.closeWrite()
152119
_ = await managerDone
153120
try await sender.close()
154-
try await readDone
121+
try await uut.wait()
155122
}
156123
}

0 commit comments

Comments
 (0)