@@ -8,11 +8,13 @@ import (
8
8
"bytes"
9
9
"crypto/rand"
10
10
"errors"
11
+ "io"
11
12
"net"
12
13
"os"
13
14
"os/exec"
14
15
"path/filepath"
15
16
"strconv"
17
+ "sync"
16
18
"testing"
17
19
"time"
18
20
@@ -173,6 +175,63 @@ func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Ce
173
175
174
176
}
175
177
178
+ func TestMalformedRequests (t * testing.T ) {
179
+ keyringAgent := NewKeyring ()
180
+ listener , err := netListener ()
181
+ if err != nil {
182
+ t .Fatalf ("netListener: %v" , err )
183
+ }
184
+ defer listener .Close ()
185
+
186
+ testCase := func (t * testing.T , requestBytes []byte , wantServerErr bool ) {
187
+ var wg sync.WaitGroup
188
+ wg .Add (1 )
189
+ go func () {
190
+ defer wg .Done ()
191
+ c , err := listener .Accept ()
192
+ if err != nil {
193
+ t .Errorf ("listener.Accept: %v" , err )
194
+ return
195
+ }
196
+ defer c .Close ()
197
+
198
+ err = ServeAgent (keyringAgent , c )
199
+ if err == nil {
200
+ t .Error ("ServeAgent should have returned an error to malformed input" )
201
+ } else {
202
+ if (err != io .EOF ) != wantServerErr {
203
+ t .Errorf ("ServeAgent returned expected error: %v" , err )
204
+ }
205
+ }
206
+ }()
207
+
208
+ c , err := net .Dial ("tcp" , listener .Addr ().String ())
209
+ if err != nil {
210
+ t .Fatalf ("net.Dial: %v" , err )
211
+ }
212
+ _ , err = c .Write (requestBytes )
213
+ if err != nil {
214
+ t .Errorf ("Unexpected error writing raw bytes on connection: %v" , err )
215
+ }
216
+ c .Close ()
217
+ wg .Wait ()
218
+ }
219
+
220
+ var testCases = []struct {
221
+ name string
222
+ requestBytes []byte
223
+ wantServerErr bool
224
+ }{
225
+ {"Empty request" , []byte {}, false },
226
+ {"Short header" , []byte {0x00 }, true },
227
+ {"Empty body" , []byte {0x00 , 0x00 , 0x00 , 0x00 }, true },
228
+ {"Short body" , []byte {0x00 , 0x00 , 0x00 , 0x01 }, false },
229
+ }
230
+ for _ , tc := range testCases {
231
+ t .Run (tc .name , func (t * testing.T ) { testCase (t , tc .requestBytes , tc .wantServerErr ) })
232
+ }
233
+ }
234
+
176
235
func TestAgent (t * testing.T ) {
177
236
for _ , keyType := range []string {"rsa" , "dsa" , "ecdsa" , "ed25519" } {
178
237
testOpenSSHAgent (t , testPrivateKeys [keyType ], nil , 0 )
@@ -192,17 +251,26 @@ func TestCert(t *testing.T) {
192
251
testKeyringAgent (t , testPrivateKeys ["rsa" ], cert , 0 )
193
252
}
194
253
195
- // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
196
- // therefore is buffered (net.Pipe deadlocks if both sides start with
197
- // a write.)
198
- func netPipe () (net.Conn , net.Conn , error ) {
254
+ // netListener creates a localhost network listener.
255
+ func netListener () (net.Listener , error ) {
199
256
listener , err := net .Listen ("tcp" , "127.0.0.1:0" )
200
257
if err != nil {
201
258
listener , err = net .Listen ("tcp" , "[::1]:0" )
202
259
if err != nil {
203
- return nil , nil , err
260
+ return nil , err
204
261
}
205
262
}
263
+ return listener , nil
264
+ }
265
+
266
+ // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
267
+ // therefore is buffered (net.Pipe deadlocks if both sides start with
268
+ // a write.)
269
+ func netPipe () (net.Conn , net.Conn , error ) {
270
+ listener , err := netListener ()
271
+ if err != nil {
272
+ return nil , nil , err
273
+ }
206
274
defer listener .Close ()
207
275
c1 , err := net .Dial ("tcp" , listener .Addr ().String ())
208
276
if err != nil {
0 commit comments