diff --git a/Coder Desktop/Proto/Receiver.swift b/Coder Desktop/Proto/Receiver.swift index 6a279e6..2797bad 100644 --- a/Coder Desktop/Proto/Receiver.swift +++ b/Coder Desktop/Proto/Receiver.swift @@ -57,7 +57,7 @@ actor Receiver { /// Starts reading protocol messages from the `DispatchIO` channel and returns them as an `AsyncStream` of messages. /// On read or decoding error, it logs and closes the stream. - func messages() throws -> AsyncStream { + func messages() throws(ReceiveError) -> AsyncStream { if running { throw ReceiveError.alreadyRunning } diff --git a/Coder Desktop/Proto/Speaker.swift b/Coder Desktop/Proto/Speaker.swift index ca0740d..6751aee 100644 --- a/Coder Desktop/Proto/Speaker.swift +++ b/Coder Desktop/Proto/Speaker.swift @@ -6,7 +6,7 @@ let newLine = 0x0A let headerPreamble = "codervpn" /// A message that has the `rpc` property for recording participation in a unary RPC. -protocol RPCMessage { +protocol RPCMessage: Sendable { var rpc: Vpn_RPC { get set } /// Returns true if `rpc` has been explicitly set. var hasRpc: Bool { get } @@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable { } } -/// An abstract base class for implementations that need to communicate using the VPN protocol. -class Speaker { +/// An actor that communicates using the VPN protocol +actor Speaker { private let logger = Logger(subsystem: "com.coder.Coder-Desktop", category: "proto") private let writeFD: FileHandle private let readFD: FileHandle @@ -93,43 +93,6 @@ class Speaker { try _ = await hndsh.handshake() } - /// Reads and handles protocol messages. - func readLoop() async throws { - for try await msg in try await receiver.messages() { - guard msg.hasRpc else { - handleMessage(msg) - continue - } - guard msg.rpc.msgID == 0 else { - let req = RPCRequest(req: msg, sender: sender) - handleRPC(req) - continue - } - guard msg.rpc.responseTo == 0 else { - logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)") - do throws(RPCError) { - try await self.secretary.route(reply: msg) - } catch { - logger.error( - "couldn't route RPC reply for \(msg.rpc.responseTo): \(error)") - } - continue - } - } - } - - /// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers. - func handleMessage(_ msg: RecvMsg) { - // just log - logger.debug("got non-RPC message \(msg.textFormatString())") - } - - /// Handle a single RPC request. It is expected that subclasses override this method with their own handlers. - func handleRPC(_ req: RPCRequest) { - // just log - logger.debug("got RPC message \(req.msg.textFormatString())") - } - /// Send a unary RPC message and handle the response func unaryRPC(_ req: SendMsg) async throws -> RecvMsg { return try await withCheckedThrowingContinuation { continuation in @@ -166,10 +129,45 @@ class Speaker { logger.error("failed to close read file handle: \(error)") } } + + enum IncomingMessage { + case message(RecvMsg) + case RPC(RPCRequest) + } +} + +extension Speaker: AsyncSequence, AsyncIteratorProtocol { + typealias Element = IncomingMessage + + public nonisolated func makeAsyncIterator() -> Speaker { + self + } + + func next() async throws -> IncomingMessage? { + for try await msg in try await receiver.messages() { + guard msg.hasRpc else { + return .message(msg) + } + guard msg.rpc.msgID == 0 else { + return .RPC(RPCRequest(req: msg, sender: sender)) + } + guard msg.rpc.responseTo == 0 else { + logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)") + do throws(RPCError) { + try await self.secretary.route(reply: msg) + } catch { + logger.error( + "couldn't route RPC reply for \(msg.rpc.responseTo): \(error)") + } + continue + } + } + return nil + } } -/// A class that performs the initial VPN protocol handshake and version negotiation. -class Handshaker: @unchecked Sendable { +/// An actor performs the initial VPN protocol handshake and version negotiation. +actor Handshaker { private let writeFD: FileHandle private let dispatch: DispatchIO private var theirData: Data = .init() @@ -193,17 +191,19 @@ class Handshaker: @unchecked Sendable { func handshake() async throws -> ProtoVersion { // kick off the read async before we try to write, synchronously, so we don't deadlock, both // waiting to write with nobody reading. - async let theirs = try withCheckedThrowingContinuation { cont in - continuation = cont - // send in a nil read to kick us off - handleRead(false, nil, 0) + let readTask = Task { + try await withCheckedThrowingContinuation { cont in + self.continuation = cont + // send in a nil read to kick us off + self.handleRead(false, nil, 0) + } } let vStr = versions.map { $0.description }.joined(separator: ",") let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n") try writeFD.write(contentsOf: ours.data(using: .utf8)!) - let theirData = try await theirs + let theirData = try await readTask.value guard let theirsString = String(bytes: theirData, encoding: .utf8) else { throw HandshakeError.invalidHeader(", @unchecked Sendable { - private var msgHandler: CheckedContinuation? - override func handleMessage(_ msg: Vpn_ManagerMessage) { - msgHandler?.resume(returning: msg) - } - - /// Runs the given closure asynchronously and returns the next non-RPC message received. - func expectMessage(with closure: - @escaping @Sendable () async -> Void) async throws -> Vpn_ManagerMessage - { - return try await withCheckedThrowingContinuation { continuation in - msgHandler = continuation - Task { - await closure() - } - } - } - - private var rpcHandler: CheckedContinuation, Error>? - override func handleRPC(_ req: RPCRequest) { - rpcHandler?.resume(returning: req) - } - - /// Runs the given closure asynchronously and return the next non-RPC message received - func expectRPC(with closure: - @escaping @Sendable () async -> Void) async throws -> - RPCRequest - { - return try await withCheckedThrowingContinuation { continuation in - rpcHandler = continuation - Task { - await closure() - } - } - } -} - @Suite(.timeLimit(.minutes(1))) struct SpeakerTests: Sendable { let pipeMT = Pipe() let pipeTM = Pipe() - let uut: TestTunnel + let uut: Speaker let sender: Sender let dispatch: DispatchIO let receiver: Receiver @@ -53,7 +14,7 @@ struct SpeakerTests: Sendable { init() { let queue = DispatchQueue.global(qos: .utility) - uut = TestTunnel( + uut = Speaker( writeFD: pipeTM.fileHandleForWriting, readFD: pipeMT.fileHandleForReading ) @@ -79,39 +40,40 @@ struct SpeakerTests: Sendable { } @Test func handleSingleMessage() async throws { - async let readDone: () = try uut.readLoop() - - let got = try await uut.expectMessage { - var s = Vpn_ManagerMessage() - s.start = Vpn_StartRequest() - await #expect(throws: Never.self) { - try await sender.send(s) - } + var s = Vpn_ManagerMessage() + s.start = Vpn_StartRequest() + await #expect(throws: Never.self) { + try await sender.send(s) + } + let got = try #require(await uut.next()) + guard case let .message(msg) = got else { + Issue.record("Received unexpected message from speaker") + return } - #expect(got.msg == .start(Vpn_StartRequest())) + #expect(msg.msg == .start(Vpn_StartRequest())) try await sender.close() - try await readDone } @Test func handleRPC() async throws { - async let readDone: () = try uut.readLoop() - - let got = try await uut.expectRPC { - var s = Vpn_ManagerMessage() - s.start = Vpn_StartRequest() - s.rpc = Vpn_RPC() - s.rpc.msgID = 33 - await #expect(throws: Never.self) { - try await sender.send(s) - } + var s = Vpn_ManagerMessage() + s.start = Vpn_StartRequest() + s.rpc = Vpn_RPC() + s.rpc.msgID = 33 + await #expect(throws: Never.self) { + try await sender.send(s) + } + let got = try #require(await uut.next()) + guard case let .RPC(req) = got else { + Issue.record("Received unexpected message from speaker") + return } - #expect(got.msg.msg == .start(Vpn_StartRequest())) - #expect(got.msg.rpc.msgID == 33) + #expect(req.msg.msg == .start(Vpn_StartRequest())) + #expect(req.msg.rpc.msgID == 33) var reply = Vpn_TunnelMessage() reply.start = Vpn_StartResponse() reply.rpc.responseTo = 33 - try await got.sendReply(reply) - uut.closeWrite() + try await req.sendReply(reply) + await uut.closeWrite() var count = 0 await #expect(throws: Never.self) { @@ -122,12 +84,13 @@ struct SpeakerTests: Sendable { #expect(count == 1) } try await sender.close() - try await readDone } @Test func sendRPCs() async throws { - async let readDone: () = try uut.readLoop() - + // Speaker must be reading from the receiver for `unaryRPC` to return + let readDone = Task { + for try await _ in uut {} + } async let managerDone = Task { var count = 0 for try await req in try await receiver.messages() { @@ -148,9 +111,9 @@ struct SpeakerTests: Sendable { let got = try await uut.unaryRPC(req) #expect(got.networkSettings.errorMessage == "test \(i)") } - uut.closeWrite() + await uut.closeWrite() _ = await managerDone try await sender.close() - try await readDone + try await readDone.value } }