@@ -6,7 +6,7 @@ let newLine = 0x0A
6
6
let headerPreamble = " codervpn "
7
7
8
8
/// A message that has the `rpc` property for recording participation in a unary RPC.
9
- protocol RPCMessage {
9
+ protocol RPCMessage : Sendable {
10
10
var rpc : Vpn_RPC { get set }
11
11
/// Returns true if `rpc` has been explicitly set.
12
12
var hasRpc : Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
49
49
}
50
50
}
51
51
52
- /// An abstract base class for implementations that need to communicate using the VPN protocol.
53
- class Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
52
+ /// An actor that communicates using the VPN protocol
53
+ actor Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
54
54
private let logger = Logger ( subsystem: " com.coder.Coder-Desktop " , category: " proto " )
55
55
private let writeFD : FileHandle
56
56
private let readFD : FileHandle
@@ -93,43 +93,6 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
93
93
try _ = await hndsh. handshake ( )
94
94
}
95
95
96
- /// Reads and handles protocol messages.
97
- func readLoop( ) async throws {
98
- for try await msg in try await receiver. messages ( ) {
99
- guard msg. hasRpc else {
100
- handleMessage ( msg)
101
- continue
102
- }
103
- guard msg. rpc. msgID == 0 else {
104
- let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender)
105
- handleRPC ( req)
106
- continue
107
- }
108
- guard msg. rpc. responseTo == 0 else {
109
- logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
110
- do throws ( RPCError) {
111
- try await self . secretary. route ( reply: msg)
112
- } catch {
113
- logger. error (
114
- " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
115
- }
116
- continue
117
- }
118
- }
119
- }
120
-
121
- /// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122
- func handleMessage( _ msg: RecvMsg) {
123
- // just log
124
- logger. debug ( " got non-RPC message \( msg. textFormatString ( ) ) " )
125
- }
126
-
127
- /// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128
- func handleRPC( _ req: RPCRequest < SendMsg , RecvMsg > ) {
129
- // just log
130
- logger. debug ( " got RPC message \( req. msg. textFormatString ( ) ) " )
131
- }
132
-
133
96
/// Send a unary RPC message and handle the response
134
97
func unaryRPC( _ req: SendMsg ) async throws -> RecvMsg {
135
98
return try await withCheckedThrowingContinuation { continuation in
@@ -166,10 +129,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166
129
logger. error ( " failed to close read file handle: \( error) " )
167
130
}
168
131
}
132
+
133
+ enum IncomingMessage {
134
+ case message( RecvMsg )
135
+ case RPC( RPCRequest < SendMsg , RecvMsg > )
136
+ }
137
+ }
138
+
139
+ extension Speaker : AsyncSequence , AsyncIteratorProtocol {
140
+ typealias Element = IncomingMessage
141
+
142
+ public nonisolated func makeAsyncIterator( ) -> Speaker < SendMsg , RecvMsg > {
143
+ self
144
+ }
145
+
146
+ func next( ) async throws -> IncomingMessage ? {
147
+ for try await msg in try await receiver. messages ( ) {
148
+ guard msg. hasRpc else {
149
+ return . message( msg)
150
+ }
151
+ guard msg. rpc. msgID == 0 else {
152
+ return . RPC( RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender) )
153
+ }
154
+ guard msg. rpc. responseTo == 0 else {
155
+ logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
156
+ do throws ( RPCError) {
157
+ try await self . secretary. route ( reply: msg)
158
+ } catch {
159
+ logger. error (
160
+ " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
161
+ }
162
+ continue
163
+ }
164
+ }
165
+ return nil
166
+ }
169
167
}
170
168
171
- /// A class that performs the initial VPN protocol handshake and version negotiation.
172
- class Handshaker : @ unchecked Sendable {
169
+ /// An actor performs the initial VPN protocol handshake and version negotiation.
170
+ actor Handshaker {
173
171
private let writeFD: FileHandle
174
172
private let dispatch : DispatchIO
175
173
private var theirData : Data = . init( )
@@ -193,17 +191,19 @@ class Handshaker: @unchecked Sendable {
193
191
func handshake( ) async throws -> ProtoVersion {
194
192
// kick off the read async before we try to write, synchronously, so we don't deadlock, both
195
193
// waiting to write with nobody reading.
196
- async let theirs = try withCheckedThrowingContinuation { cont in
197
- continuation = cont
198
- // send in a nil read to kick us off
199
- handleRead ( false , nil , 0 )
194
+ let readTask = Task {
195
+ try await withCheckedThrowingContinuation { cont in
196
+ self . continuation = cont
197
+ // send in a nil read to kick us off
198
+ self . handleRead ( false , nil , 0 )
199
+ }
200
200
}
201
201
202
202
let vStr = versions. map { $0. description } . joined ( separator: " , " )
203
203
let ours = String ( format: " \( headerPreamble) \( role) \( vStr) \n " )
204
204
try writeFD. write ( contentsOf: ours. data ( using: . utf8) !)
205
205
206
- let theirData = try await theirs
206
+ let theirData = try await readTask . value
207
207
guard let theirsString = String ( bytes: theirData, encoding: . utf8) else {
208
208
throw HandshakeError . invalidHeader ( " <unparsable: \( theirData) " )
209
209
}
0 commit comments