@@ -6,7 +6,7 @@ let newLine = 0x0A
6
6
let headerPreamble = " codervpn "
7
7
8
8
/// A message that has the `rpc` property for recording participation in a unary RPC.
9
- protocol RPCMessage {
9
+ protocol RPCMessage : Sendable {
10
10
var rpc : Vpn_RPC { get set }
11
11
/// Returns true if `rpc` has been explicitly set.
12
12
var hasRpc : Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
49
49
}
50
50
}
51
51
52
- /// An abstract base class for implementations that need to communicate using the VPN protocol.
53
- class Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
52
+ /// An actor that communicates using the VPN protocol
53
+ actor Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
54
54
private let logger = Logger ( subsystem: " com.coder.Coder-Desktop " , category: " proto " )
55
55
private let writeFD : FileHandle
56
56
private let readFD : FileHandle
@@ -59,6 +59,8 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
59
59
private let sender : Sender < SendMsg >
60
60
private let receiver : Receiver < RecvMsg >
61
61
private let secretary = RPCSecretary < RecvMsg > ( )
62
+ private var messageBuffer : MessageBuffer = . init( )
63
+ private var readLoopTask : Task < Void , any Error > ?
62
64
let role : ProtoRole
63
65
64
66
/// Creates an instance that communicates over the provided file handles.
@@ -93,41 +95,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
93
95
try _ = await hndsh. handshake ( )
94
96
}
95
97
96
- /// Reads and handles protocol messages.
97
- func readLoop( ) async throws {
98
- for try await msg in try await receiver. messages ( ) {
99
- guard msg. hasRpc else {
100
- handleMessage ( msg)
101
- continue
102
- }
103
- guard msg. rpc. msgID == 0 else {
104
- let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender)
105
- handleRPC ( req)
106
- continue
107
- }
108
- guard msg. rpc. responseTo == 0 else {
109
- logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
110
- do throws ( RPCError) {
111
- try await self . secretary. route ( reply: msg)
112
- } catch {
113
- logger. error (
114
- " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
98
+ public func start( ) {
99
+ guard readLoopTask == nil else {
100
+ logger. error ( " speaker is already running " )
101
+ return
102
+ }
103
+ readLoopTask = Task {
104
+ do throws ( ReceiveError) {
105
+ for try await msg in try await self . receiver. messages ( ) {
106
+ guard msg. hasRpc else {
107
+ await messageBuffer. push ( . message( msg) )
108
+ continue
109
+ }
110
+ guard msg. rpc. msgID == 0 else {
111
+ let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: self . sender)
112
+ await messageBuffer. push ( . RPC( req) )
113
+ continue
114
+ }
115
+ guard msg. rpc. responseTo == 0 else {
116
+ self . logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
117
+ do throws ( RPCError) {
118
+ try await self . secretary. route ( reply: msg)
119
+ } catch {
120
+ self . logger. error (
121
+ " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
122
+ }
123
+ continue
124
+ }
115
125
}
116
- continue
126
+ } catch {
127
+ self . logger. error ( " failed to receive messages: \( error) " )
117
128
}
118
129
}
119
130
}
120
131
121
- /// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122
- func handleMessage( _ msg: RecvMsg) {
123
- // just log
124
- logger. debug ( " got non-RPC message \( msg. textFormatString ( ) ) " )
125
- }
126
-
127
- /// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128
- func handleRPC( _ req: RPCRequest < SendMsg , RecvMsg > ) {
129
- // just log
130
- logger. debug ( " got RPC message \( req. msg. textFormatString ( ) ) " )
132
+ func wait( ) async throws {
133
+ guard let task = readLoopTask else {
134
+ return
135
+ }
136
+ try await task. value
131
137
}
132
138
133
139
/// Send a unary RPC message and handle the response
@@ -166,10 +172,51 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166
172
logger. error ( " failed to close read file handle: \( error) " )
167
173
}
168
174
}
175
+
176
+ enum IncomingMessage {
177
+ case message( RecvMsg )
178
+ case RPC( RPCRequest < SendMsg , RecvMsg > )
179
+ }
180
+
181
+ private actor MessageBuffer {
182
+ private var messages : [ IncomingMessage ] = [ ]
183
+ private var continuations : [ CheckedContinuation < IncomingMessage ? , Never > ] = [ ]
184
+
185
+ func push( _ message: IncomingMessage ? ) {
186
+ if let continuation = continuations. first {
187
+ continuations. removeFirst ( )
188
+ continuation. resume ( returning: message)
189
+ } else if let message = message {
190
+ messages. append ( message)
191
+ }
192
+ }
193
+
194
+ func next( ) async -> IncomingMessage ? {
195
+ if let message = messages. first {
196
+ messages. removeFirst ( )
197
+ return message
198
+ }
199
+ return await withCheckedContinuation { continuation in
200
+ continuations. append ( continuation)
201
+ }
202
+ }
203
+ }
169
204
}
170
205
171
- /// A class that performs the initial VPN protocol handshake and version negotiation.
172
- class Handshaker : @unchecked Sendable {
206
+ extension Speaker : AsyncSequence , AsyncIteratorProtocol {
207
+ typealias Element = IncomingMessage
208
+
209
+ public nonisolated func makeAsyncIterator( ) -> Speaker < SendMsg , RecvMsg > {
210
+ self
211
+ }
212
+
213
+ func next( ) async throws -> IncomingMessage ? {
214
+ return await messageBuffer. next ( )
215
+ }
216
+ }
217
+
218
+ /// An actor performs the initial VPN protocol handshake and version negotiation.
219
+ actor Handshaker {
173
220
private let writeFD : FileHandle
174
221
private let dispatch : DispatchIO
175
222
private var theirData : Data = . init( )
@@ -193,17 +240,19 @@ class Handshaker: @unchecked Sendable {
193
240
func handshake( ) async throws -> ProtoVersion {
194
241
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
195
242
// waiting to write with nobody reading.
196
- async let theirs = try withCheckedThrowingContinuation { cont in
197
- continuation = cont
198
- // send in a nil read to kick us off
199
- handleRead ( false , nil , 0 )
243
+ let readTask = Task {
244
+ try await withCheckedThrowingContinuation { cont in
245
+ self . continuation = cont
246
+ // send in a nil read to kick us off
247
+ self . handleRead ( false , nil , 0 )
248
+ }
200
249
}
201
250
202
251
let vStr = versions. map { $0. description } . joined ( separator: " , " )
203
252
let ours = String ( format: " \( headerPreamble) \( role) \( vStr) \n " )
204
253
try writeFD. write ( contentsOf: ours. data ( using: . utf8) !)
205
254
206
- let theirData = try await theirs
255
+ let theirData = try await readTask . value
207
256
guard let theirsString = String ( bytes: theirData, encoding: . utf8) else {
208
257
throw HandshakeError . invalidHeader ( " <unparsable: \( theirData) " )
209
258
}
0 commit comments