Skip to content

Commit df01cb2

Browse files
committed
ssh: invert algorithm choices on the server
At the protocol level, SSH lets client and server specify different algorithms for the read and write half of the connection. This has never worked correctly, as Client-to-Server was always interpreted as the "write" side, even if we were the server. This has never been a problem because, apparently, there are no clients that insist on different algorithm choices running against Go SSH servers. Since the SSH package does not expose a mechanism to specify algorithms for read/write separately, there is end-to-end for this change, so add a unittest instead. Change-Id: Ie3aa781630a3bb7a3b0e3754cb67b3ce12581544 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/172538 Reviewed-by: Filippo Valsorda <[email protected]> Run-TryBot: Filippo Valsorda <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent b43e412 commit df01cb2

File tree

3 files changed

+192
-9
lines changed

3 files changed

+192
-9
lines changed

Diff for: ssh/common.go

+13-7
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ func findCommon(what string, client []string, server []string) (common string, e
109109
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
110110
}
111111

112+
// directionAlgorithms records algorithm choices in one direction (either read or write)
112113
type directionAlgorithms struct {
113114
Cipher string
114115
MAC string
@@ -137,7 +138,7 @@ type algorithms struct {
137138
r directionAlgorithms
138139
}
139140

140-
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
141+
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
141142
result := &algorithms{}
142143

143144
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
@@ -150,32 +151,37 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor
150151
return
151152
}
152153

153-
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
154+
stoc, ctos := &result.w, &result.r
155+
if isClient {
156+
ctos, stoc = stoc, ctos
157+
}
158+
159+
ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
154160
if err != nil {
155161
return
156162
}
157163

158-
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
164+
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
159165
if err != nil {
160166
return
161167
}
162168

163-
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
169+
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
164170
if err != nil {
165171
return
166172
}
167173

168-
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
174+
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
169175
if err != nil {
170176
return
171177
}
172178

173-
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
179+
ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
174180
if err != nil {
175181
return
176182
}
177183

178-
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
184+
stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
179185
if err != nil {
180186
return
181187
}

Diff for: ssh/common_test.go

+176
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// Copyright 2019 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package ssh
6+
7+
import (
8+
"reflect"
9+
"testing"
10+
)
11+
12+
func TestFindAgreedAlgorithms(t *testing.T) {
13+
initKex := func(k *kexInitMsg) {
14+
if k.KexAlgos == nil {
15+
k.KexAlgos = []string{"kex1"}
16+
}
17+
if k.ServerHostKeyAlgos == nil {
18+
k.ServerHostKeyAlgos = []string{"hostkey1"}
19+
}
20+
if k.CiphersClientServer == nil {
21+
k.CiphersClientServer = []string{"cipher1"}
22+
23+
}
24+
if k.CiphersServerClient == nil {
25+
k.CiphersServerClient = []string{"cipher1"}
26+
27+
}
28+
if k.MACsClientServer == nil {
29+
k.MACsClientServer = []string{"mac1"}
30+
31+
}
32+
if k.MACsServerClient == nil {
33+
k.MACsServerClient = []string{"mac1"}
34+
35+
}
36+
if k.CompressionClientServer == nil {
37+
k.CompressionClientServer = []string{"compression1"}
38+
39+
}
40+
if k.CompressionServerClient == nil {
41+
k.CompressionServerClient = []string{"compression1"}
42+
43+
}
44+
if k.LanguagesClientServer == nil {
45+
k.LanguagesClientServer = []string{"language1"}
46+
47+
}
48+
if k.LanguagesServerClient == nil {
49+
k.LanguagesServerClient = []string{"language1"}
50+
51+
}
52+
}
53+
54+
initDirAlgs := func(a *directionAlgorithms) {
55+
if a.Cipher == "" {
56+
a.Cipher = "cipher1"
57+
}
58+
if a.MAC == "" {
59+
a.MAC = "mac1"
60+
}
61+
if a.Compression == "" {
62+
a.Compression = "compression1"
63+
}
64+
}
65+
66+
initAlgs := func(a *algorithms) {
67+
if a.kex == "" {
68+
a.kex = "kex1"
69+
}
70+
if a.hostKey == "" {
71+
a.hostKey = "hostkey1"
72+
}
73+
initDirAlgs(&a.r)
74+
initDirAlgs(&a.w)
75+
}
76+
77+
type testcase struct {
78+
name string
79+
clientIn, serverIn kexInitMsg
80+
wantClient, wantServer algorithms
81+
wantErr bool
82+
}
83+
84+
cases := []testcase{
85+
testcase{
86+
name: "standard",
87+
},
88+
89+
testcase{
90+
name: "no common hostkey",
91+
serverIn: kexInitMsg{
92+
ServerHostKeyAlgos: []string{"hostkey2"},
93+
},
94+
wantErr: true,
95+
},
96+
97+
testcase{
98+
name: "no common kex",
99+
serverIn: kexInitMsg{
100+
KexAlgos: []string{"kex2"},
101+
},
102+
wantErr: true,
103+
},
104+
105+
testcase{
106+
name: "no common cipher",
107+
serverIn: kexInitMsg{
108+
CiphersClientServer: []string{"cipher2"},
109+
},
110+
wantErr: true,
111+
},
112+
113+
testcase{
114+
name: "client decides cipher",
115+
serverIn: kexInitMsg{
116+
CiphersClientServer: []string{"cipher1", "cipher2"},
117+
CiphersServerClient: []string{"cipher2", "cipher3"},
118+
},
119+
clientIn: kexInitMsg{
120+
CiphersClientServer: []string{"cipher2", "cipher1"},
121+
CiphersServerClient: []string{"cipher3", "cipher2"},
122+
},
123+
wantClient: algorithms{
124+
r: directionAlgorithms{
125+
Cipher: "cipher3",
126+
},
127+
w: directionAlgorithms{
128+
Cipher: "cipher2",
129+
},
130+
},
131+
wantServer: algorithms{
132+
w: directionAlgorithms{
133+
Cipher: "cipher3",
134+
},
135+
r: directionAlgorithms{
136+
Cipher: "cipher2",
137+
},
138+
},
139+
},
140+
141+
// TODO(hanwen): fix and add tests for AEAD ignoring
142+
// the MACs field
143+
}
144+
145+
for i := range cases {
146+
initKex(&cases[i].clientIn)
147+
initKex(&cases[i].serverIn)
148+
initAlgs(&cases[i].wantClient)
149+
initAlgs(&cases[i].wantServer)
150+
}
151+
152+
for _, c := range cases {
153+
t.Run(c.name, func(t *testing.T) {
154+
serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
155+
clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
156+
157+
serverHasErr := serverErr != nil
158+
clientHasErr := clientErr != nil
159+
if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
160+
t.Fatalf("got client/server error (%v, %v), want hasError %v",
161+
clientErr, serverErr, c.wantErr)
162+
163+
}
164+
if c.wantErr {
165+
return
166+
}
167+
168+
if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
169+
t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
170+
}
171+
if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
172+
t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
173+
}
174+
})
175+
}
176+
}

Diff for: ssh/handshake.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -543,15 +543,16 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
543543

544544
clientInit := otherInit
545545
serverInit := t.sentInitMsg
546-
if len(t.hostKeys) == 0 {
546+
isClient := len(t.hostKeys) == 0
547+
if isClient {
547548
clientInit, serverInit = serverInit, clientInit
548549

549550
magics.clientKexInit = t.sentInitPacket
550551
magics.serverKexInit = otherInitPacket
551552
}
552553

553554
var err error
554-
t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
555+
t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
555556
if err != nil {
556557
return err
557558
}

0 commit comments

Comments
 (0)