Skip to content

Commit 64f00d2

Browse files
ssh/agent: add checking for empty SSH requests
Previously empty SSH requests would cause a panic.
1 parent e363607 commit 64f00d2

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

Diff for: ssh/agent/client_test.go

+73-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ import (
88
"bytes"
99
"crypto/rand"
1010
"errors"
11+
"io"
1112
"net"
1213
"os"
1314
"os/exec"
1415
"path/filepath"
1516
"strconv"
17+
"sync"
1618
"testing"
1719
"time"
1820

@@ -173,6 +175,63 @@ func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Ce
173175

174176
}
175177

178+
func TestMalformedRequests(t *testing.T) {
179+
keyringAgent := NewKeyring()
180+
listener, err := netListener()
181+
if err != nil {
182+
t.Fatalf("netListener: %v", err)
183+
}
184+
defer listener.Close()
185+
186+
testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
187+
var wg sync.WaitGroup
188+
wg.Add(1)
189+
go func() {
190+
defer wg.Done()
191+
c, err := listener.Accept()
192+
if err != nil {
193+
t.Errorf("listener.Accept: %v", err)
194+
return
195+
}
196+
defer c.Close()
197+
198+
err = ServeAgent(keyringAgent, c)
199+
if err == nil {
200+
t.Error("ServeAgent should have returned an error to malformed input")
201+
} else {
202+
if (err != io.EOF) != wantServerErr {
203+
t.Errorf("ServeAgent returned expected error: %v", err)
204+
}
205+
}
206+
}()
207+
208+
c, err := net.Dial("tcp", listener.Addr().String())
209+
if err != nil {
210+
t.Fatalf("net.Dial: %v", err)
211+
}
212+
_, err = c.Write(requestBytes)
213+
if err != nil {
214+
t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
215+
}
216+
c.Close()
217+
wg.Wait()
218+
}
219+
220+
var testCases = []struct {
221+
name string
222+
requestBytes []byte
223+
wantServerErr bool
224+
}{
225+
{"Empty request", []byte{}, false},
226+
{"Short header", []byte{0x00}, true},
227+
{"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
228+
{"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
229+
}
230+
for _, tc := range testCases {
231+
t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
232+
}
233+
}
234+
176235
func TestAgent(t *testing.T) {
177236
for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
178237
testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
@@ -192,17 +251,26 @@ func TestCert(t *testing.T) {
192251
testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
193252
}
194253

195-
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
196-
// therefore is buffered (net.Pipe deadlocks if both sides start with
197-
// a write.)
198-
func netPipe() (net.Conn, net.Conn, error) {
254+
// netListener creates a localhost network listener.
255+
func netListener() (net.Listener, error) {
199256
listener, err := net.Listen("tcp", "127.0.0.1:0")
200257
if err != nil {
201258
listener, err = net.Listen("tcp", "[::1]:0")
202259
if err != nil {
203-
return nil, nil, err
260+
return nil, err
204261
}
205262
}
263+
return listener, nil
264+
}
265+
266+
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
267+
// therefore is buffered (net.Pipe deadlocks if both sides start with
268+
// a write.)
269+
func netPipe() (net.Conn, net.Conn, error) {
270+
listener, err := netListener()
271+
if err != nil {
272+
return nil, nil, err
273+
}
206274
defer listener.Close()
207275
c1, err := net.Dial("tcp", listener.Addr().String())
208276
if err != nil {

Diff for: ssh/agent/server.go

+3
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ func ServeAgent(agent Agent, c io.ReadWriter) error {
497497
return err
498498
}
499499
l := binary.BigEndian.Uint32(length[:])
500+
if l == 0 {
501+
return fmt.Errorf("agent: request size is 0")
502+
}
500503
if l > maxAgentResponseBytes {
501504
// We also cap requests.
502505
return fmt.Errorf("agent: request too large: %d", l)

0 commit comments

Comments
 (0)