@@ -8,11 +8,13 @@ import (
8
8
"bytes"
9
9
"crypto/rand"
10
10
"errors"
11
+ "io"
11
12
"net"
12
13
"os"
13
14
"os/exec"
14
15
"path/filepath"
15
16
"strconv"
17
+ "sync"
16
18
"testing"
17
19
"time"
18
20
@@ -196,6 +198,63 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
196
198
197
199
}
198
200
201
+ func TestMalformedRequests (t * testing.T ) {
202
+ keyringAgent := NewKeyring ()
203
+ listener , err := netListener ()
204
+ if err != nil {
205
+ t .Fatalf ("netListener: %v" , err )
206
+ }
207
+ defer listener .Close ()
208
+
209
+ testCase := func (t * testing.T , requestBytes []byte , wantServerErr bool ) {
210
+ var wg sync.WaitGroup
211
+ wg .Add (1 )
212
+ go func () {
213
+ defer wg .Done ()
214
+ c , err := listener .Accept ()
215
+ if err != nil {
216
+ t .Errorf ("listener.Accept: %v" , err )
217
+ return
218
+ }
219
+ defer c .Close ()
220
+
221
+ err = ServeAgent (keyringAgent , c )
222
+ if err == nil {
223
+ t .Error ("ServeAgent should have returned an error to malformed input" )
224
+ } else {
225
+ if (err != io .EOF ) != wantServerErr {
226
+ t .Errorf ("ServeAgent returned expected error: %v" , err )
227
+ }
228
+ }
229
+ }()
230
+
231
+ c , err := net .Dial ("tcp" , listener .Addr ().String ())
232
+ if err != nil {
233
+ t .Fatalf ("net.Dial: %v" , err )
234
+ }
235
+ _ , err = c .Write (requestBytes )
236
+ if err != nil {
237
+ t .Errorf ("Unexpected error writing raw bytes on connection: %v" , err )
238
+ }
239
+ c .Close ()
240
+ wg .Wait ()
241
+ }
242
+
243
+ var testCases = []struct {
244
+ name string
245
+ requestBytes []byte
246
+ wantServerErr bool
247
+ }{
248
+ {"Empty request" , []byte {}, false },
249
+ {"Short header" , []byte {0x00 }, true },
250
+ {"Empty body" , []byte {0x00 , 0x00 , 0x00 , 0x00 }, true },
251
+ {"Short body" , []byte {0x00 , 0x00 , 0x00 , 0x01 }, false },
252
+ }
253
+ for _ , tc := range testCases {
254
+ t .Run (tc .name , func (t * testing.T ) { testCase (t , tc .requestBytes , tc .wantServerErr ) })
255
+ }
256
+ }
257
+
199
258
func TestAgent (t * testing.T ) {
200
259
for _ , keyType := range []string {"rsa" , "dsa" , "ecdsa" , "ed25519" } {
201
260
testOpenSSHAgent (t , testPrivateKeys [keyType ], nil , 0 )
@@ -215,17 +274,26 @@ func TestCert(t *testing.T) {
215
274
testKeyringAgent (t , testPrivateKeys ["rsa" ], cert , 0 )
216
275
}
217
276
218
- // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
219
- // therefore is buffered (net.Pipe deadlocks if both sides start with
220
- // a write.)
221
- func netPipe () (net.Conn , net.Conn , error ) {
277
+ // netListener creates a localhost network listener.
278
+ func netListener () (net.Listener , error ) {
222
279
listener , err := net .Listen ("tcp" , "127.0.0.1:0" )
223
280
if err != nil {
224
281
listener , err = net .Listen ("tcp" , "[::1]:0" )
225
282
if err != nil {
226
- return nil , nil , err
283
+ return nil , err
227
284
}
228
285
}
286
+ return listener , nil
287
+ }
288
+
289
+ // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
290
+ // therefore is buffered (net.Pipe deadlocks if both sides start with
291
+ // a write.)
292
+ func netPipe () (net.Conn , net.Conn , error ) {
293
+ listener , err := netListener ()
294
+ if err != nil {
295
+ return nil , nil , err
296
+ }
229
297
defer listener .Close ()
230
298
c1 , err := net .Dial ("tcp" , listener .Addr ().String ())
231
299
if err != nil {
0 commit comments