Skip to content

chore: refactor speaker & handshaker into actors #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Coder Desktop/Proto/Receiver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ actor Receiver<RecvMsg: Message> {

/// Starts reading protocol messages from the `DispatchIO` channel and returns them as an `AsyncStream` of messages.
/// On read or decoding error, it logs and closes the stream.
func messages() throws -> AsyncStream<RecvMsg> {
func messages() throws(ReceiveError) -> AsyncStream<RecvMsg> {
if running {
throw ReceiveError.alreadyRunning
}
Expand Down
94 changes: 47 additions & 47 deletions Coder Desktop/Proto/Speaker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ let newLine = 0x0A
let headerPreamble = "codervpn"

/// A message that has the `rpc` property for recording participation in a unary RPC.
protocol RPCMessage {
protocol RPCMessage: Sendable {
var rpc: Vpn_RPC { get set }
/// Returns true if `rpc` has been explicitly set.
var hasRpc: Bool { get }
Expand Down Expand Up @@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
}
}

/// An abstract base class for implementations that need to communicate using the VPN protocol.
class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
/// An actor that communicates using the VPN protocol
actor Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
private let logger = Logger(subsystem: "com.coder.Coder-Desktop", category: "proto")
private let writeFD: FileHandle
private let readFD: FileHandle
Expand Down Expand Up @@ -93,43 +93,6 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
try _ = await hndsh.handshake()
}

/// Reads and handles protocol messages.
func readLoop() async throws {
for try await msg in try await receiver.messages() {
guard msg.hasRpc else {
handleMessage(msg)
continue
}
guard msg.rpc.msgID == 0 else {
let req = RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender)
handleRPC(req)
continue
}
guard msg.rpc.responseTo == 0 else {
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
do throws(RPCError) {
try await self.secretary.route(reply: msg)
} catch {
logger.error(
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
}
continue
}
}
}

/// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
func handleMessage(_ msg: RecvMsg) {
// just log
logger.debug("got non-RPC message \(msg.textFormatString())")
}

/// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
func handleRPC(_ req: RPCRequest<SendMsg, RecvMsg>) {
// just log
logger.debug("got RPC message \(req.msg.textFormatString())")
}

/// Send a unary RPC message and handle the response
func unaryRPC(_ req: SendMsg) async throws -> RecvMsg {
return try await withCheckedThrowingContinuation { continuation in
Expand Down Expand Up @@ -166,10 +129,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
logger.error("failed to close read file handle: \(error)")
}
}

enum IncomingMessage {
case message(RecvMsg)
case RPC(RPCRequest<SendMsg, RecvMsg>)
}
}

extension Speaker: AsyncSequence, AsyncIteratorProtocol {
typealias Element = IncomingMessage

public nonisolated func makeAsyncIterator() -> Speaker<SendMsg, RecvMsg> {
self
}

func next() async throws -> IncomingMessage? {
for try await msg in try await receiver.messages() {
guard msg.hasRpc else {
return .message(msg)
}
guard msg.rpc.msgID == 0 else {
return .RPC(RPCRequest<SendMsg, RecvMsg>(req: msg, sender: sender))
}
guard msg.rpc.responseTo == 0 else {
logger.debug("got RPC reply for msgID \(msg.rpc.responseTo)")
do throws(RPCError) {
try await self.secretary.route(reply: msg)
} catch {
logger.error(
"couldn't route RPC reply for \(msg.rpc.responseTo): \(error)")
}
continue
}
}
return nil
}
}

/// A class that performs the initial VPN protocol handshake and version negotiation.
class Handshaker: @unchecked Sendable {
/// An actor performs the initial VPN protocol handshake and version negotiation.
actor Handshaker {
private let writeFD: FileHandle
private let dispatch: DispatchIO
private var theirData: Data = .init()
Expand All @@ -193,17 +191,19 @@ class Handshaker: @unchecked Sendable {
func handshake() async throws -> ProtoVersion {
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
// waiting to write with nobody reading.
async let theirs = try withCheckedThrowingContinuation { cont in
continuation = cont
// send in a nil read to kick us off
handleRead(false, nil, 0)
let readTask = Task {
try await withCheckedThrowingContinuation { cont in
self.continuation = cont
// send in a nil read to kick us off
self.handleRead(false, nil, 0)
}
}

let vStr = versions.map { $0.description }.joined(separator: ",")
let ours = String(format: "\(headerPreamble) \(role) \(vStr)\n")
try writeFD.write(contentsOf: ours.data(using: .utf8)!)

let theirData = try await theirs
let theirData = try await readTask.value
guard let theirsString = String(bytes: theirData, encoding: .utf8) else {
throw HandshakeError.invalidHeader("<unparsable: \(theirData)")
}
Expand Down
103 changes: 33 additions & 70 deletions Coder Desktop/ProtoTests/SpeakerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,19 @@
import Foundation
import Testing

/// A concrete, test class for the abstract Speaker, which overrides the handlers to send things to
/// continuations we set in the test.
class TestTunnel: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>, @unchecked Sendable {
private var msgHandler: CheckedContinuation<Vpn_ManagerMessage, Error>?
override func handleMessage(_ msg: Vpn_ManagerMessage) {
msgHandler?.resume(returning: msg)
}

/// Runs the given closure asynchronously and returns the next non-RPC message received.
func expectMessage(with closure:
@escaping @Sendable () async -> Void) async throws -> Vpn_ManagerMessage
{
return try await withCheckedThrowingContinuation { continuation in
msgHandler = continuation
Task {
await closure()
}
}
}

private var rpcHandler: CheckedContinuation<RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>, Error>?
override func handleRPC(_ req: RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>) {
rpcHandler?.resume(returning: req)
}

/// Runs the given closure asynchronously and return the next non-RPC message received
func expectRPC(with closure:
@escaping @Sendable () async -> Void) async throws ->
RPCRequest<Vpn_TunnelMessage, Vpn_ManagerMessage>
{
return try await withCheckedThrowingContinuation { continuation in
rpcHandler = continuation
Task {
await closure()
}
}
}
}

@Suite(.timeLimit(.minutes(1)))
struct SpeakerTests: Sendable {
let pipeMT = Pipe()
let pipeTM = Pipe()
let uut: TestTunnel
let uut: Speaker<Vpn_TunnelMessage, Vpn_ManagerMessage>
let sender: Sender<Vpn_ManagerMessage>
let dispatch: DispatchIO
let receiver: Receiver<Vpn_TunnelMessage>
let handshaker: Handshaker

init() {
let queue = DispatchQueue.global(qos: .utility)
uut = TestTunnel(
uut = Speaker(
writeFD: pipeTM.fileHandleForWriting,
readFD: pipeMT.fileHandleForReading
)
Expand All @@ -79,39 +40,40 @@ struct SpeakerTests: Sendable {
}

@Test func handleSingleMessage() async throws {
async let readDone: () = try uut.readLoop()

let got = try await uut.expectMessage {
var s = Vpn_ManagerMessage()
s.start = Vpn_StartRequest()
await #expect(throws: Never.self) {
try await sender.send(s)
}
var s = Vpn_ManagerMessage()
s.start = Vpn_StartRequest()
await #expect(throws: Never.self) {
try await sender.send(s)
}
let got = try #require(await uut.next())
guard case let .message(msg) = got else {
Issue.record("Received unexpected message from speaker")
return
}
#expect(got.msg == .start(Vpn_StartRequest()))
#expect(msg.msg == .start(Vpn_StartRequest()))
try await sender.close()
try await readDone
}

@Test func handleRPC() async throws {
async let readDone: () = try uut.readLoop()

let got = try await uut.expectRPC {
var s = Vpn_ManagerMessage()
s.start = Vpn_StartRequest()
s.rpc = Vpn_RPC()
s.rpc.msgID = 33
await #expect(throws: Never.self) {
try await sender.send(s)
}
var s = Vpn_ManagerMessage()
s.start = Vpn_StartRequest()
s.rpc = Vpn_RPC()
s.rpc.msgID = 33
await #expect(throws: Never.self) {
try await sender.send(s)
}
let got = try #require(await uut.next())
guard case let .RPC(req) = got else {
Issue.record("Received unexpected message from speaker")
return
}
#expect(got.msg.msg == .start(Vpn_StartRequest()))
#expect(got.msg.rpc.msgID == 33)
#expect(req.msg.msg == .start(Vpn_StartRequest()))
#expect(req.msg.rpc.msgID == 33)
var reply = Vpn_TunnelMessage()
reply.start = Vpn_StartResponse()
reply.rpc.responseTo = 33
try await got.sendReply(reply)
uut.closeWrite()
try await req.sendReply(reply)
await uut.closeWrite()

var count = 0
await #expect(throws: Never.self) {
Expand All @@ -122,12 +84,13 @@ struct SpeakerTests: Sendable {
#expect(count == 1)
}
try await sender.close()
try await readDone
}

@Test func sendRPCs() async throws {
async let readDone: () = try uut.readLoop()

// Speaker must be reading from the receiver for `unaryRPC` to return
let readDone = Task {
for try await _ in uut {}
}
async let managerDone = Task {
var count = 0
for try await req in try await receiver.messages() {
Expand All @@ -148,9 +111,9 @@ struct SpeakerTests: Sendable {
let got = try await uut.unaryRPC(req)
#expect(got.networkSettings.errorMessage == "test \(i)")
}
uut.closeWrite()
await uut.closeWrite()
_ = await managerDone
try await sender.close()
try await readDone
try await readDone.value
}
}