Skip to content

Commit 94bb74b

Browse files
authored
Merge pull request #2291 from vincepri/tls-opts-getcertificate
🌱 Handle TLSOpts.GetCertificate in webhook
2 parents 0ef0753 + bd12701 commit 94bb74b

File tree

2 files changed

+80
-24
lines changed

2 files changed

+80
-24
lines changed

pkg/webhook/server.go

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ type Server struct {
6060
CertDir string
6161

6262
// CertName is the server certificate name. Defaults to tls.crt.
63+
//
64+
// Note: This option should only be set when TLSOpts does not override GetCertificate.
6365
CertName string
6466

6567
// KeyName is the server key name. Defaults to tls.key.
68+
//
69+
// Note: This option should only be set when TLSOpts does not override GetCertificate.
6670
KeyName string
6771

6872
// ClientCAName is the CA certificate name which server used to verify remote(client)'s certificate.
@@ -169,32 +173,40 @@ func (s *Server) Start(ctx context.Context) error {
169173
baseHookLog := log.WithName("webhooks")
170174
baseHookLog.Info("Starting webhook server")
171175

172-
certPath := filepath.Join(s.CertDir, s.CertName)
173-
keyPath := filepath.Join(s.CertDir, s.KeyName)
174-
175-
certWatcher, err := certwatcher.New(certPath, keyPath)
176-
if err != nil {
177-
return err
178-
}
179-
180-
go func() {
181-
if err := certWatcher.Start(ctx); err != nil {
182-
log.Error(err, "certificate watcher error")
183-
}
184-
}()
185-
186176
tlsMinVersion, err := tlsVersion(s.TLSMinVersion)
187177
if err != nil {
188178
return err
189179
}
190180

191181
cfg := &tls.Config{ //nolint:gosec
192-
NextProtos: []string{"h2"},
193-
GetCertificate: certWatcher.GetCertificate,
194-
MinVersion: tlsMinVersion,
182+
NextProtos: []string{"h2"},
183+
MinVersion: tlsMinVersion,
184+
}
185+
// fallback TLS config ready, will now mutate if passer wants full control over it
186+
for _, op := range s.TLSOpts {
187+
op(cfg)
188+
}
189+
190+
if cfg.GetCertificate == nil {
191+
certPath := filepath.Join(s.CertDir, s.CertName)
192+
keyPath := filepath.Join(s.CertDir, s.KeyName)
193+
194+
// Create the certificate watcher and
195+
// set the config's GetCertificate on the TLSConfig
196+
certWatcher, err := certwatcher.New(certPath, keyPath)
197+
if err != nil {
198+
return err
199+
}
200+
cfg.GetCertificate = certWatcher.GetCertificate
201+
202+
go func() {
203+
if err := certWatcher.Start(ctx); err != nil {
204+
log.Error(err, "certificate watcher error")
205+
}
206+
}()
195207
}
196208

197-
// load CA to verify client certificate
209+
// Load CA to verify client certificate, if configured.
198210
if s.ClientCAName != "" {
199211
certPool := x509.NewCertPool()
200212
clientCABytes, err := os.ReadFile(filepath.Join(s.CertDir, s.ClientCAName))
@@ -211,11 +223,6 @@ func (s *Server) Start(ctx context.Context) error {
211223
cfg.ClientAuth = tls.RequireAndVerifyClientCert
212224
}
213225

214-
// fallback TLS config ready, will now mutate if passer wants full control over it
215-
for _, op := range s.TLSOpts {
216-
op(cfg)
217-
}
218-
219226
listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)), cfg)
220227
if err != nil {
221228
return err

pkg/webhook/server_test.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
"io"
2424
"net"
2525
"net/http"
26+
"path"
27+
"reflect"
2628

2729
. "github.com/onsi/ginkgo/v2"
2830
. "github.com/onsi/gomega"
@@ -181,7 +183,7 @@ var _ = Describe("Webhook Server", func() {
181183
}
182184
server.Register("/somepath", &testHandler{})
183185
doneCh := genericStartServer(func(ctx context.Context) {
184-
Expect(server.Start(ctx))
186+
Expect(server.Start(ctx)).To(Succeed())
185187
})
186188

187189
Eventually(func() ([]byte, error) {
@@ -199,6 +201,53 @@ var _ = Describe("Webhook Server", func() {
199201
ctxCancel()
200202
Eventually(doneCh, "4s").Should(BeClosed())
201203
})
204+
205+
It("should prefer GetCertificate through TLSOpts", func() {
206+
var finalCfg *tls.Config
207+
finalCert, err := tls.LoadX509KeyPair(
208+
path.Join(servingOpts.LocalServingCertDir, "tls.crt"),
209+
path.Join(servingOpts.LocalServingCertDir, "tls.key"),
210+
)
211+
Expect(err).NotTo(HaveOccurred())
212+
finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
213+
return &finalCert, nil
214+
}
215+
server = &webhook.Server{
216+
Host: servingOpts.LocalServingHost,
217+
Port: servingOpts.LocalServingPort,
218+
CertDir: servingOpts.LocalServingCertDir,
219+
TLSMinVersion: "1.2",
220+
TLSOpts: []func(*tls.Config){
221+
func(cfg *tls.Config) {
222+
cfg.GetCertificate = finalGetCertificate
223+
// save cfg after changes to test against
224+
finalCfg = cfg
225+
},
226+
},
227+
}
228+
server.Register("/somepath", &testHandler{})
229+
doneCh := genericStartServer(func(ctx context.Context) {
230+
Expect(server.Start(ctx)).To(Succeed())
231+
})
232+
233+
Eventually(func() ([]byte, error) {
234+
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
235+
Expect(err).NotTo(HaveOccurred())
236+
defer resp.Body.Close()
237+
return io.ReadAll(resp.Body)
238+
}).Should(Equal([]byte("gadzooks!")))
239+
Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
240+
// We can't compare the functions directly, but we can compare their pointers
241+
if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() {
242+
Fail("GetCertificate was not set properly, or overwritten")
243+
}
244+
cert, err := finalCfg.GetCertificate(nil)
245+
Expect(err).NotTo(HaveOccurred())
246+
Expect(cert).To(BeEquivalentTo(&finalCert))
247+
248+
ctxCancel()
249+
Eventually(doneCh, "4s").Should(BeClosed())
250+
})
202251
})
203252

204253
type testHandler struct {

0 commit comments

Comments
 (0)