Skip to content

Commit bc7914c

Browse files
authored
Merge pull request #2301 from vincepri/certwatcher-callback
🌱 Add certwatcher callback
2 parents 62e6867 + aeedfbf commit bc7914c

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

pkg/certwatcher/certwatcher.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ type CertWatcher struct {
4444

4545
certPath string
4646
keyPath string
47+
48+
// callback is a function to be invoked when the certificate changes.
49+
callback func(tls.Certificate)
4750
}
4851

4952
// New returns a new CertWatcher watching the given certificate and key.
@@ -68,6 +71,17 @@ func New(certPath, keyPath string) (*CertWatcher, error) {
6871
return cw, nil
6972
}
7073

74+
// RegisterCallback registers a callback to be invoked when the certificate changes.
75+
func (cw *CertWatcher) RegisterCallback(callback func(tls.Certificate)) {
76+
cw.Lock()
77+
defer cw.Unlock()
78+
// If the current certificate is not nil, invoke the callback immediately.
79+
if cw.currentCert != nil {
80+
callback(*cw.currentCert)
81+
}
82+
cw.callback = callback
83+
}
84+
7185
// GetCertificate fetches the currently loaded certificate, which may be nil.
7286
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
7387
cw.RLock()
@@ -146,6 +160,14 @@ func (cw *CertWatcher) ReadCertificate() error {
146160

147161
log.Info("Updated current TLS certificate")
148162

163+
// If a callback is registered, invoke it with the new certificate.
164+
cw.RLock()
165+
defer cw.RUnlock()
166+
if cw.callback != nil {
167+
go func() {
168+
cw.callback(cert)
169+
}()
170+
}
149171
return nil
150172
}
151173

pkg/certwatcher/certwatcher_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ import (
2020
"context"
2121
"crypto/rand"
2222
"crypto/rsa"
23+
"crypto/tls"
2324
"crypto/x509"
2425
"crypto/x509/pkix"
2526
"encoding/pem"
2627
"fmt"
2728
"math/big"
2829
"net"
2930
"os"
31+
"sync/atomic"
3032
"time"
3133

3234
. "github.com/onsi/ginkgo/v2"
@@ -97,6 +99,11 @@ var _ = Describe("CertWatcher", func() {
9799

98100
It("should reload currentCert when changed", func() {
99101
doneCh := startWatcher()
102+
called := atomic.Int64{}
103+
watcher.RegisterCallback(func(crt tls.Certificate) {
104+
called.Add(1)
105+
Expect(crt.Certificate).ToNot(BeEmpty())
106+
})
100107

101108
firstcert, _ := watcher.GetCertificate(nil)
102109

@@ -111,6 +118,7 @@ var _ = Describe("CertWatcher", func() {
111118

112119
ctxCancel()
113120
Eventually(doneCh, "4s").Should(BeClosed())
121+
Expect(called.Load()).To(BeNumerically(">=", 1))
114122
})
115123

116124
Context("prometheus metric read_certificate_total", func() {

0 commit comments

Comments
 (0)