Skip to content

Commit e9f5c6f

Browse files
chore: refactor speaker & handshaker into actors (#15)
Instead of relying on class inheritance, the new Speaker can composed into whatever would like to speak the CoderVPN protocol, and messages can be handled by iterating over the speaker itself e.g: ```swift enum IncomingMessage { case message(RecvMsg) case RPC(RPCRequest<SendMsg, RecvMsg>) } ``` ```swift for try await msg in speaker { switch msg { case let .message(msg): // Handle message that doesn't require a response case let .RPC(req): // Handle incoming RPC } } ```
1 parent ae65c20 commit e9f5c6f

File tree

3 files changed

+81
-118
lines changed

3 files changed

+81
-118
lines changed

Coder Desktop/Proto/Receiver.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ actor Receiver<RecvMsg: Message> {
5757

5858
/// Starts reading protocol messages from the `DispatchIO` channel and returns them as an `AsyncStream` of messages.
5959
/// On read or decoding error, it logs and closes the stream.
60-
func messages() throws -> AsyncStream<RecvMsg> {
60+
func messages() throws(ReceiveError) -> AsyncStream<RecvMsg> {
6161
if running {
6262
throw ReceiveError.alreadyRunning
6363
}

Coder Desktop/Proto/Speaker.swift

+47-47
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ let newLine = 0x0A
66
let headerPreamble = "codervpn"
77

88
/// A message that has the `rpc` property for recording participation in a unary RPC.
9-
protocol RPCMessage {
9+
protocol RPCMessage: Sendable {
1010
var rpc: Vpn_RPC { get set }
1111
/// Returns true if `rpc` has been explicitly set.
1212
var hasRpc: Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
4949
}
5050
}
5151

52-
/// An abstract base class for implementations that need to communicate using the VPN protocol.
53-
class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
52+
/// An actor that communicates using the VPN protocol
53+
actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5454
private let logger = Logger(subsystem: "com.coder.Coder-Desktop", category: "proto")
5555
private let writeFD: FileHandle
5656
private let readFD: FileHandle
@@ -93,43 +93,6 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
9393
try _ = await hndsh.handshake()
9494
}
9595

96-
/// Reads and handles protocol messages.
97-
func readLoop() async throws {
98-
for try await msg in try await receiver.messages() {
99-
guard msg.hasRpc else {
100-
handleMessage(msg)
101-
continue
102-
}
103-
guard msg.rpc.msgID == 0 else {
104-
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender)
105-
handleRPC(req)
106-
continue
107-
}
108-
guard msg.rpc.responseTo == 0 else {
109-
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
110-
do throws(RPCError) {
111-
try await self.secretary.route(reply: msg)
112-
} catch {
113-
logger.error(
114-
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
115-
}
116-
continue
117-
}
118-
}
119-
}
120-
121-
/// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122-
func handleMessage(_ msg: RecvMsg) {
123-
// just log
124-
logger.debug("got non-RPC message \(msg.textFormatString())")
125-
}
126-
127-
/// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128-
func handleRPC(_ req: RPCRequest<SendMsg, RecvMsg>) {
129-
// just log
130-
logger.debug("got RPC message \(req.msg.textFormatString())")
131-
}
132-
13396
/// Send a unary RPC message and handle the response
13497
func unaryRPC(_ req: SendMsg) async throws -> RecvMsg {
13598
return try await withCheckedThrowingContinuation { continuation in
@@ -166,10 +129,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166129
logger.error("failed to close read file handle: \(error)")
167130
}
168131
}
132+
133+
enum IncomingMessage {
134+
case message(RecvMsg)
135+
case RPC(RPCRequest<SendMsg, RecvMsg>)
136+
}
137+
}
138+
139+
extension Speaker: AsyncSequence, AsyncIteratorProtocol {
140+
typealias Element = IncomingMessage
141+
142+
public nonisolated func makeAsyncIterator() -> Speaker<SendMsg, RecvMsg> {
143+
self
144+
}
145+
146+
func next() async throws -> IncomingMessage? {
147+
for try await msg in try await receiver.messages() {
148+
guard msg.hasRpc else {
149+
return .message(msg)
150+
}
151+
guard msg.rpc.msgID == 0 else {
152+
return .RPC(RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender))
153+
}
154+
guard msg.rpc.responseTo == 0 else {
155+
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
156+
do throws(RPCError) {
157+
try await self.secretary.route(reply: msg)
158+
} catch {
159+
logger.error(
160+
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
161+
}
162+
continue
163+
}
164+
}
165+
return nil
166+
}
169167
}
170168

171-
/// A class that performs the initial VPN protocol handshake and version negotiation.
172-
class Handshaker: @unchecked Sendable {
169+
/// An actor performs the initial VPN protocol handshake and version negotiation.
170+
actor Handshaker {
173171
private let writeFD: FileHandle
174172
private let dispatch: DispatchIO
175173
private var theirData: Data = .init()
@@ -193,17 +191,19 @@ class Handshaker: @unchecked Sendable {
193191
func handshake() async throws -> ProtoVersion {
194192
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
195193
// waiting to write with nobody reading.
196-
async let theirs = try withCheckedThrowingContinuation { cont in
197-
continuation = cont
198-
// send in a nil read to kick us off
199-
handleRead(false, nil, 0)
194+
let readTask = Task {
195+
try await withCheckedThrowingContinuation { cont in
196+
self.continuation = cont
197+
// send in a nil read to kick us off
198+
self.handleRead(false, nil, 0)
199+
}
200200
}
201201

202202
let vStr = versions.map { $0.description }.joined(separator: ",")
203203
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n")
204204
try writeFD.write(contentsOf: ours.data(using: .utf8)!)
205205

206-
let theirData = try await theirs
206+
let theirData = try await readTask.value
207207
guard let theirsString = String(bytes: theirData, encoding: .utf8) else {
208208
throw HandshakeError.invalidHeader("<unparsable: \(theirData)")
209209
}

Coder Desktop/ProtoTests/SpeakerTests.swift

+33-70
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,19 @@
22
import Foundation
33
import Testing
44

5-
/// A concrete, test class for the abstract Speaker, which overrides the handlers to send things to
6-
/// continuations we set in the test.
7-
class TestTunnel: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>, @unchecked Sendable {
8-
private var msgHandler: CheckedContinuation<Vpn_ManagerMessage, Error>?
9-
override func handleMessage(_ msg: Vpn_ManagerMessage) {
10-
msgHandler?.resume(returning: msg)
11-
}
12-
13-
/// Runs the given closure asynchronously and returns the next non-RPC message received.
14-
func expectMessage(with closure:
15-
@escaping @Sendable () async -> Void) async throws -> Vpn_ManagerMessage
16-
{
17-
return try await withCheckedThrowingContinuation { continuation in
18-
msgHandler = continuation
19-
Task {
20-
await closure()
21-
}
22-
}
23-
}
24-
25-
private var rpcHandler: CheckedContinuation<RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>, Error>?
26-
override func handleRPC(_ req: RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>) {
27-
rpcHandler?.resume(returning: req)
28-
}
29-
30-
/// Runs the given closure asynchronously and return the next non-RPC message received
31-
func expectRPC(with closure:
32-
@escaping @Sendable () async -> Void) async throws ->
33-
RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>
34-
{
35-
return try await withCheckedThrowingContinuation { continuation in
36-
rpcHandler = continuation
37-
Task {
38-
await closure()
39-
}
40-
}
41-
}
42-
}
43-
445
@Suite(.timeLimit(.minutes(1)))
456
struct SpeakerTests: Sendable {
467
let pipeMT = Pipe()
478
let pipeTM = Pipe()
48-
let uut: TestTunnel
9+
let uut: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>
4910
let sender: Sender<Vpn_ManagerMessage>
5011
let dispatch: DispatchIO
5112
let receiver: Receiver<Vpn_TunnelMessage>
5213
let handshaker: Handshaker
5314

5415
init() {
5516
let queue = DispatchQueue.global(qos: .utility)
56-
uut = TestTunnel(
17+
uut = Speaker(
5718
writeFD: pipeTM.fileHandleForWriting,
5819
readFD: pipeMT.fileHandleForReading
5920
)
@@ -79,39 +40,40 @@ struct SpeakerTests: Sendable {
7940
}
8041

8142
@Test func handleSingleMessage() async throws {
82-
async let readDone: () = try uut.readLoop()
83-
84-
let got = try await uut.expectMessage {
85-
var s = Vpn_ManagerMessage()
86-
s.start = Vpn_StartRequest()
87-
await #expect(throws: Never.self) {
88-
try await sender.send(s)
89-
}
43+
var s = Vpn_ManagerMessage()
44+
s.start = Vpn_StartRequest()
45+
await #expect(throws: Never.self) {
46+
try await sender.send(s)
47+
}
48+
let got = try #require(await uut.next())
49+
guard case let .message(msg) = got else {
50+
Issue.record("Received unexpected message from speaker")
51+
return
9052
}
91-
#expect(got.msg == .start(Vpn_StartRequest()))
53+
#expect(msg.msg == .start(Vpn_StartRequest()))
9254
try await sender.close()
93-
try await readDone
9455
}
9556

9657
@Test func handleRPC() async throws {
97-
async let readDone: () = try uut.readLoop()
98-
99-
let got = try await uut.expectRPC {
100-
var s = Vpn_ManagerMessage()
101-
s.start = Vpn_StartRequest()
102-
s.rpc = Vpn_RPC()
103-
s.rpc.msgID = 33
104-
await #expect(throws: Never.self) {
105-
try await sender.send(s)
106-
}
58+
var s = Vpn_ManagerMessage()
59+
s.start = Vpn_StartRequest()
60+
s.rpc = Vpn_RPC()
61+
s.rpc.msgID = 33
62+
await #expect(throws: Never.self) {
63+
try await sender.send(s)
64+
}
65+
let got = try #require(await uut.next())
66+
guard case let .RPC(req) = got else {
67+
Issue.record("Received unexpected message from speaker")
68+
return
10769
}
108-
#expect(got.msg.msg == .start(Vpn_StartRequest()))
109-
#expect(got.msg.rpc.msgID == 33)
70+
#expect(req.msg.msg == .start(Vpn_StartRequest()))
71+
#expect(req.msg.rpc.msgID == 33)
11072
var reply = Vpn_TunnelMessage()
11173
reply.start = Vpn_StartResponse()
11274
reply.rpc.responseTo = 33
113-
try await got.sendReply(reply)
114-
uut.closeWrite()
75+
try await req.sendReply(reply)
76+
await uut.closeWrite()
11577

11678
var count = 0
11779
await #expect(throws: Never.self) {
@@ -122,12 +84,13 @@ struct SpeakerTests: Sendable {
12284
#expect(count == 1)
12385
}
12486
try await sender.close()
125-
try await readDone
12687
}
12788

12889
@Test func sendRPCs() async throws {
129-
async let readDone: () = try uut.readLoop()
130-
90+
// Speaker must be reading from the receiver for `unaryRPC` to return
91+
let readDone = Task {
92+
for try await _ in uut {}
93+
}
13194
async let managerDone = Task {
13295
var count = 0
13396
for try await req in try await receiver.messages() {
@@ -148,9 +111,9 @@ struct SpeakerTests: Sendable {
148111
let got = try await uut.unaryRPC(req)
149112
#expect(got.networkSettings.errorMessage == "test \(i)")
150113
}
151-
uut.closeWrite()
114+
await uut.closeWrite()
152115
_ = await managerDone
153116
try await sender.close()
154-
try await readDone
117+
try await readDone.value
155118
}
156119
}

0 commit comments

Comments
 (0)