Skip to content

Commit 9cd0187

Browse files
FiloSottilegopherbot
authored andcommitted
curve25519: use crypto/ecdh on Go 1.20
For golang/go#52221 Change-Id: I27e867d4cc89cd52c8d510f0dbab4e89b7cd4763 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/451115 Auto-Submit: Filippo Valsorda <[email protected]> Reviewed-by: Cherry Mui <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Run-TryBot: Filippo Valsorda <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]>
1 parent c6a20f9 commit 9cd0187

File tree

5 files changed

+173
-107
lines changed

5 files changed

+173
-107
lines changed

curve25519/curve25519.go

+6-93
Original file line numberDiff line numberDiff line change
@@ -5,71 +5,18 @@
55
// Package curve25519 provides an implementation of the X25519 function, which
66
// performs scalar multiplication on the elliptic curve known as Curve25519.
77
// See RFC 7748.
8+
//
9+
// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
10+
// in the crypto/ecdh package.
811
package curve25519 // import "golang.org/x/crypto/curve25519"
912

10-
import (
11-
"crypto/subtle"
12-
"errors"
13-
"strconv"
14-
15-
"golang.org/x/crypto/curve25519/internal/field"
16-
)
17-
1813
// ScalarMult sets dst to the product scalar * point.
1914
//
2015
// Deprecated: when provided a low-order point, ScalarMult will set dst to all
2116
// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
2217
// will return an error.
2318
func ScalarMult(dst, scalar, point *[32]byte) {
24-
var e [32]byte
25-
26-
copy(e[:], scalar[:])
27-
e[0] &= 248
28-
e[31] &= 127
29-
e[31] |= 64
30-
31-
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
32-
x1.SetBytes(point[:])
33-
x2.One()
34-
x3.Set(&x1)
35-
z3.One()
36-
37-
swap := 0
38-
for pos := 254; pos >= 0; pos-- {
39-
b := e[pos/8] >> uint(pos&7)
40-
b &= 1
41-
swap ^= int(b)
42-
x2.Swap(&x3, swap)
43-
z2.Swap(&z3, swap)
44-
swap = int(b)
45-
46-
tmp0.Subtract(&x3, &z3)
47-
tmp1.Subtract(&x2, &z2)
48-
x2.Add(&x2, &z2)
49-
z2.Add(&x3, &z3)
50-
z3.Multiply(&tmp0, &x2)
51-
z2.Multiply(&z2, &tmp1)
52-
tmp0.Square(&tmp1)
53-
tmp1.Square(&x2)
54-
x3.Add(&z3, &z2)
55-
z2.Subtract(&z3, &z2)
56-
x2.Multiply(&tmp1, &tmp0)
57-
tmp1.Subtract(&tmp1, &tmp0)
58-
z2.Square(&z2)
59-
60-
z3.Mult32(&tmp1, 121666)
61-
x3.Square(&x3)
62-
tmp0.Add(&tmp0, &z3)
63-
z3.Multiply(&x1, &z2)
64-
z2.Multiply(&tmp1, &tmp0)
65-
}
66-
67-
x2.Swap(&x3, swap)
68-
z2.Swap(&z3, swap)
69-
70-
z2.Invert(&z2)
71-
x2.Multiply(&x2, &z2)
72-
copy(dst[:], x2.Bytes())
19+
scalarMult(dst, scalar, point)
7320
}
7421

7522
// ScalarBaseMult sets dst to the product scalar * base where base is the
@@ -78,7 +25,7 @@ func ScalarMult(dst, scalar, point *[32]byte) {
7825
// It is recommended to use the X25519 function with Basepoint instead, as
7926
// copying into fixed size arrays can lead to unexpected bugs.
8027
func ScalarBaseMult(dst, scalar *[32]byte) {
81-
ScalarMult(dst, scalar, &basePoint)
28+
scalarBaseMult(dst, scalar)
8229
}
8330

8431
const (
@@ -91,21 +38,10 @@ const (
9138
// Basepoint is the canonical Curve25519 generator.
9239
var Basepoint []byte
9340

94-
var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
41+
var basePoint = [32]byte{9}
9542

9643
func init() { Basepoint = basePoint[:] }
9744

98-
func checkBasepoint() {
99-
if subtle.ConstantTimeCompare(Basepoint, []byte{
100-
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
101-
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
102-
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
103-
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
104-
}) != 1 {
105-
panic("curve25519: global Basepoint value was modified")
106-
}
107-
}
108-
10945
// X25519 returns the result of the scalar multiplication (scalar * point),
11046
// according to RFC 7748, Section 5. scalar, point and the return value are
11147
// slices of 32 bytes.
@@ -121,26 +57,3 @@ func X25519(scalar, point []byte) ([]byte, error) {
12157
var dst [32]byte
12258
return x25519(&dst, scalar, point)
12359
}
124-
125-
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
126-
var in [32]byte
127-
if l := len(scalar); l != 32 {
128-
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
129-
}
130-
if l := len(point); l != 32 {
131-
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
132-
}
133-
copy(in[:], scalar)
134-
if &point[0] == &Basepoint[0] {
135-
checkBasepoint()
136-
ScalarBaseMult(dst, &in)
137-
} else {
138-
var base, zero [32]byte
139-
copy(base[:], point)
140-
ScalarMult(dst, &in, &base)
141-
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
142-
return nil, errors.New("bad input point: low order point")
143-
}
144-
}
145-
return dst[:], nil
146-
}

curve25519/curve25519_compat.go

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright 2019 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build !go1.20
6+
7+
package curve25519
8+
9+
import (
10+
"crypto/subtle"
11+
"errors"
12+
"strconv"
13+
14+
"golang.org/x/crypto/curve25519/internal/field"
15+
)
16+
17+
func scalarMult(dst, scalar, point *[32]byte) {
18+
var e [32]byte
19+
20+
copy(e[:], scalar[:])
21+
e[0] &= 248
22+
e[31] &= 127
23+
e[31] |= 64
24+
25+
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
26+
x1.SetBytes(point[:])
27+
x2.One()
28+
x3.Set(&x1)
29+
z3.One()
30+
31+
swap := 0
32+
for pos := 254; pos >= 0; pos-- {
33+
b := e[pos/8] >> uint(pos&7)
34+
b &= 1
35+
swap ^= int(b)
36+
x2.Swap(&x3, swap)
37+
z2.Swap(&z3, swap)
38+
swap = int(b)
39+
40+
tmp0.Subtract(&x3, &z3)
41+
tmp1.Subtract(&x2, &z2)
42+
x2.Add(&x2, &z2)
43+
z2.Add(&x3, &z3)
44+
z3.Multiply(&tmp0, &x2)
45+
z2.Multiply(&z2, &tmp1)
46+
tmp0.Square(&tmp1)
47+
tmp1.Square(&x2)
48+
x3.Add(&z3, &z2)
49+
z2.Subtract(&z3, &z2)
50+
x2.Multiply(&tmp1, &tmp0)
51+
tmp1.Subtract(&tmp1, &tmp0)
52+
z2.Square(&z2)
53+
54+
z3.Mult32(&tmp1, 121666)
55+
x3.Square(&x3)
56+
tmp0.Add(&tmp0, &z3)
57+
z3.Multiply(&x1, &z2)
58+
z2.Multiply(&tmp1, &tmp0)
59+
}
60+
61+
x2.Swap(&x3, swap)
62+
z2.Swap(&z3, swap)
63+
64+
z2.Invert(&z2)
65+
x2.Multiply(&x2, &z2)
66+
copy(dst[:], x2.Bytes())
67+
}
68+
69+
func scalarBaseMult(dst, scalar *[32]byte) {
70+
checkBasepoint()
71+
scalarMult(dst, scalar, &basePoint)
72+
}
73+
74+
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
75+
var in [32]byte
76+
if l := len(scalar); l != 32 {
77+
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
78+
}
79+
if l := len(point); l != 32 {
80+
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
81+
}
82+
copy(in[:], scalar)
83+
if &point[0] == &Basepoint[0] {
84+
scalarBaseMult(dst, &in)
85+
} else {
86+
var base, zero [32]byte
87+
copy(base[:], point)
88+
scalarMult(dst, &in, &base)
89+
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
90+
return nil, errors.New("bad input point: low order point")
91+
}
92+
}
93+
return dst[:], nil
94+
}
95+
96+
func checkBasepoint() {
97+
if subtle.ConstantTimeCompare(Basepoint, []byte{
98+
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
99+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
100+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
101+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
102+
}) != 1 {
103+
panic("curve25519: global Basepoint value was modified")
104+
}
105+
}

curve25519/curve25519_go120.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2022 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.20
6+
7+
package curve25519
8+
9+
import "crypto/ecdh"
10+
11+
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
12+
curve := ecdh.X25519()
13+
pub, err := curve.NewPublicKey(point)
14+
if err != nil {
15+
return nil, err
16+
}
17+
priv, err := curve.NewPrivateKey(scalar)
18+
if err != nil {
19+
return nil, err
20+
}
21+
out, err := priv.ECDH(pub)
22+
if err != nil {
23+
return nil, err
24+
}
25+
copy(dst[:], out)
26+
return dst[:], nil
27+
}
28+
29+
func scalarMult(dst, scalar, point *[32]byte) {
30+
if _, err := x25519(dst, scalar[:], point[:]); err != nil {
31+
// The only error condition for x25519 when the inputs are 32 bytes long
32+
// is if the output would have been the all-zero value.
33+
for i := range dst {
34+
dst[i] = 0
35+
}
36+
}
37+
}
38+
39+
func scalarBaseMult(dst, scalar *[32]byte) {
40+
curve := ecdh.X25519()
41+
priv, err := curve.NewPrivateKey(scalar[:])
42+
if err != nil {
43+
panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
44+
}
45+
copy(dst[:], priv.PublicKey().Bytes())
46+
}

curve25519/curve25519_test.go

+15-13
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
package curve25519
5+
package curve25519_test
66

77
import (
88
"bytes"
99
"crypto/rand"
1010
"encoding/hex"
1111
"testing"
12+
13+
"golang.org/x/crypto/curve25519"
1214
)
1315

1416
const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
@@ -19,7 +21,7 @@ func TestX25519Basepoint(t *testing.T) {
1921

2022
for i := 0; i < 200; i++ {
2123
var err error
22-
x, err = X25519(x, Basepoint)
24+
x, err = curve25519.X25519(x, curve25519.Basepoint)
2325
if err != nil {
2426
t.Fatal(err)
2527
}
@@ -32,12 +34,12 @@ func TestX25519Basepoint(t *testing.T) {
3234
}
3335

3436
func TestLowOrderPoints(t *testing.T) {
35-
scalar := make([]byte, ScalarSize)
37+
scalar := make([]byte, curve25519.ScalarSize)
3638
if _, err := rand.Read(scalar); err != nil {
3739
t.Fatal(err)
3840
}
3941
for i, p := range lowOrderPoints {
40-
out, err := X25519(scalar, p)
42+
out, err := curve25519.X25519(scalar, p)
4143
if err == nil {
4244
t.Errorf("%d: expected error, got nil", i)
4345
}
@@ -48,10 +50,10 @@ func TestLowOrderPoints(t *testing.T) {
4850
}
4951

5052
func TestTestVectors(t *testing.T) {
51-
t.Run("Legacy", func(t *testing.T) { testTestVectors(t, ScalarMult) })
53+
t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) })
5254
t.Run("X25519", func(t *testing.T) {
5355
testTestVectors(t, func(dst, scalar, point *[32]byte) {
54-
out, err := X25519(scalar[:], point[:])
56+
out, err := curve25519.X25519(scalar[:], point[:])
5557
if err != nil {
5658
t.Fatal(err)
5759
}
@@ -88,10 +90,10 @@ func TestHighBitIgnored(t *testing.T) {
8890
var hi0, hi1 [32]byte
8991

9092
u[31] &= 0x7f
91-
ScalarMult(&hi0, &s, &u)
93+
curve25519.ScalarMult(&hi0, &s, &u)
9294

9395
u[31] |= 0x80
94-
ScalarMult(&hi1, &s, &u)
96+
curve25519.ScalarMult(&hi1, &s, &u)
9597

9698
if !bytes.Equal(hi0[:], hi1[:]) {
9799
t.Errorf("high bit of group point should not affect result")
@@ -101,14 +103,14 @@ func TestHighBitIgnored(t *testing.T) {
101103
var benchmarkSink byte
102104

103105
func BenchmarkX25519Basepoint(b *testing.B) {
104-
scalar := make([]byte, ScalarSize)
106+
scalar := make([]byte, curve25519.ScalarSize)
105107
if _, err := rand.Read(scalar); err != nil {
106108
b.Fatal(err)
107109
}
108110

109111
b.ResetTimer()
110112
for i := 0; i < b.N; i++ {
111-
out, err := X25519(scalar, Basepoint)
113+
out, err := curve25519.X25519(scalar, curve25519.Basepoint)
112114
if err != nil {
113115
b.Fatal(err)
114116
}
@@ -117,11 +119,11 @@ func BenchmarkX25519Basepoint(b *testing.B) {
117119
}
118120

119121
func BenchmarkX25519(b *testing.B) {
120-
scalar := make([]byte, ScalarSize)
122+
scalar := make([]byte, curve25519.ScalarSize)
121123
if _, err := rand.Read(scalar); err != nil {
122124
b.Fatal(err)
123125
}
124-
point, err := X25519(scalar, Basepoint)
126+
point, err := curve25519.X25519(scalar, curve25519.Basepoint)
125127
if err != nil {
126128
b.Fatal(err)
127129
}
@@ -131,7 +133,7 @@ func BenchmarkX25519(b *testing.B) {
131133

132134
b.ResetTimer()
133135
for i := 0; i < b.N; i++ {
134-
out, err := X25519(scalar, point)
136+
out, err := curve25519.X25519(scalar, point)
135137
if err != nil {
136138
b.Fatal(err)
137139
}

0 commit comments

Comments
 (0)