Skip to content

Commit 8296c6b

Browse files
committed
cached: Make package entirely thread-safe
1 parent 9b4dcd3 commit 8296c6b

File tree

2 files changed

+106
-45
lines changed

2 files changed

+106
-45
lines changed

Diff for: pkg/cached/cache.go

+30-43
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License.
1919
// operations are not repeated unnecessarily. The operations can be
2020
// created as a tree, and replaced dynamically as needed.
2121
//
22+
// All the operations in this module are thread-safe.
23+
//
2224
// # Dependencies and types of caches
2325
//
2426
// This package uses a source/transform/sink model of caches to build
@@ -34,15 +36,6 @@ limitations under the License.
3436
// replaced with a new one, and saves the previous results in case an
3537
// error pops-up.
3638
//
37-
// # Atomicity
38-
//
39-
// Most of the operations are not atomic/thread-safe, except for
40-
// [Replaceable.Replace] which can be performed while the objects are
41-
// being read. Specifically, `Get` methods are NOT thread-safe. Never
42-
// call `Get()` without a lock on a multi-threaded environment, since
43-
// it's usually performing updates to caches that will require write
44-
// operations.
45-
//
4639
// # Etags
4740
//
4841
// Etags in this library is a cache version identifier. It doesn't
@@ -57,6 +50,7 @@ package cached
5750

5851
import (
5952
"fmt"
53+
"sync"
6054
"sync/atomic"
6155
)
6256

@@ -100,13 +94,6 @@ func (r Result[T]) Get() Result[T] {
10094
type Data[T any] interface {
10195
// Returns the cached data, as well as an "etag" to identify the
10296
// version of the cache, or an error if something happened.
103-
//
104-
// # Important note
105-
//
106-
// This method is NEVER thread-safe, never assume it is OK to
107-
// call `Get()` without holding a proper mutex in a
108-
// multi-threaded environment, especially since `Get()` will
109-
// usually update the cache and perform write operations.
11097
Get() Result[T]
11198
}
11299

@@ -155,6 +142,7 @@ func NewMerger[K comparable, T, V any](mergeFn func(results map[K]Result[T]) Res
155142
}
156143

157144
type listMerger[T, V any] struct {
145+
lock sync.Mutex
158146
mergeFn func([]Result[T]) Result[V]
159147
caches []Data[T]
160148
cacheResults []Result[T]
@@ -183,15 +171,15 @@ func NewListMerger[T, V any](mergeFn func(results []Result[T]) Result[V], caches
183171
caches: caches,
184172
}
185173
}
186-
func (c *listMerger[T, V]) prepareResults() []Result[T] {
174+
func (c *listMerger[T, V]) prepareResultsLocked() []Result[T] {
187175
cacheResults := make([]Result[T], 0, len(c.caches))
188176
for _, cache := range c.caches {
189177
cacheResults = append(cacheResults, cache.Get())
190178
}
191179
return cacheResults
192180
}
193181

194-
func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
182+
func (c *listMerger[T, V]) needsRunningLocked(results []Result[T]) bool {
195183
if c.cacheResults == nil {
196184
return true
197185
}
@@ -211,8 +199,10 @@ func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
211199
}
212200

213201
func (c *listMerger[T, V]) Get() Result[V] {
214-
cacheResults := c.prepareResults()
215-
if c.needsRunning(cacheResults) {
202+
c.lock.Lock()
203+
defer c.lock.Unlock()
204+
cacheResults := c.prepareResultsLocked()
205+
if c.needsRunningLocked(cacheResults) {
216206
c.cacheResults = cacheResults
217207
c.result = c.mergeFn(c.cacheResults)
218208
}
@@ -238,7 +228,7 @@ func NewTransformer[T, V any](transformerFn func(Result[T]) Result[V], source Da
238228

239229
// NewSource creates a new cache that generates some data. This
240230
// will always be called since we don't know the origin of the data and
241-
// if it needs to be updated or not.
231+
// if it needs to be updated or not. sourceFn MUST be thread-safe.
242232
func NewSource[T any](sourceFn func() Result[T]) Data[T] {
243233
c := source[T](sourceFn)
244234
return &c
@@ -259,25 +249,24 @@ func NewStaticSource[T any](staticFn func() Result[T]) Data[T] {
259249
}
260250

261251
type static[T any] struct {
252+
once sync.Once
262253
fn func() Result[T]
263-
result *Result[T]
254+
result Result[T]
264255
}
265256

266257
func (c *static[T]) Get() Result[T] {
267-
if c.result == nil {
268-
result := c.fn()
269-
c.result = &result
270-
}
271-
return *c.result
258+
c.once.Do(func() {
259+
c.result = c.fn()
260+
})
261+
return c.result
272262
}
273263

274-
// Replaceable is a cache that carries the result even when the
275-
// cache is replaced. The cache can be replaced atomically (without any
276-
// lock held). This is the type that should typically be stored in
264+
// Replaceable is a cache that carries the result even when the cache is
265+
// replaced. This is the type that should typically be stored in
277266
// structs.
278267
type Replaceable[T any] struct {
279268
cache atomic.Pointer[Data[T]]
280-
result *Result[T]
269+
result atomic.Pointer[Result[T]]
281270
}
282271

283272
// Get retrieves the data from the underlying source. [Replaceable]
@@ -286,23 +275,21 @@ type Replaceable[T any] struct {
286275
// previously had returned a success, that success will be returned
287276
// instead. If the cache fails but we never returned a success, that
288277
// failure is returned.
289-
//
290-
// # Important note
291-
//
292-
// As all implementations of Get, this implementation is NOT
293-
// thread-safe. Please properly lock a mutex before calling this method
294-
// if you are in a multi-threaded environment, since this method will
295-
// update the cache and perform write operations.
296278
func (c *Replaceable[T]) Get() Result[T] {
297279
result := (*c.cache.Load()).Get()
298-
if result.Err != nil && c.result != nil && c.result.Err == nil {
299-
return *c.result
280+
281+
for {
282+
cResult := c.result.Load()
283+
if result.Err != nil && cResult != nil && cResult.Err == nil {
284+
return *cResult
285+
}
286+
if c.result.CompareAndSwap(cResult, &result) {
287+
return result
288+
}
300289
}
301-
c.result = &result
302-
return *c.result
303290
}
304291

305-
// Replace changes the cache in a thread-safe way.
292+
// Replace changes the cache.
306293
func (c *Replaceable[T]) Replace(cache Data[T]) {
307294
c.cache.Swap(&cache)
308295
}

Diff for: pkg/cached/cache_test.go

+76-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ import (
2121
"encoding/json"
2222
"errors"
2323
"fmt"
24+
"math/rand"
2425
"sort"
2526
"strings"
27+
"sync"
2628
"testing"
29+
"time"
2730

2831
"k8s.io/kube-openapi/pkg/cached"
2932
)
@@ -1012,9 +1015,7 @@ func TestListMergerAlternateSourceError(t *testing.T) {
10121015
}
10131016

10141017
func TestListDAG(t *testing.T) {
1015-
count := 0
10161018
source := cached.NewSource(func() cached.Result[[]byte] {
1017-
count += 1
10181019
return cached.NewResultOK([]byte("source"), "source")
10191020
})
10201021
transformer1 := cached.NewTransformer(func(result cached.Result[[]byte]) cached.Result[[]byte] {
@@ -1054,3 +1055,76 @@ func TestListDAG(t *testing.T) {
10541055
t.Fatalf("expected etag = %v, got %v", want, result.Etag)
10551056
}
10561057
}
1058+
1059+
func randomString(length uint) string {
1060+
bytes := make([]byte, 6)
1061+
rand.Read(bytes)
1062+
return string(bytes)
1063+
1064+
}
1065+
1066+
func NewRandomSource() cached.Data[int64] {
1067+
return cached.NewStaticSource(func() cached.Result[int64] {
1068+
bytes := make([]byte, 6)
1069+
rand.Read(bytes)
1070+
return cached.NewResultOK(rand.Int63(), randomString(10))
1071+
})
1072+
}
1073+
1074+
func repeatedGet(data cached.Data[int64], end time.Time, wg *sync.WaitGroup) {
1075+
for time.Now().Before(end) {
1076+
_ = data.Get()
1077+
}
1078+
wg.Done()
1079+
}
1080+
1081+
func TestThreadSafe(t *testing.T) {
1082+
end := time.Now().Add(time.Second)
1083+
wg := sync.WaitGroup{}
1084+
static := NewRandomSource()
1085+
wg.Add(1)
1086+
go repeatedGet(static, end, &wg)
1087+
result := cached.NewResultOK(rand.Int63(), randomString(10))
1088+
wg.Add(1)
1089+
go repeatedGet(result, end, &wg)
1090+
replaceable := cached.Replaceable[int64]{}
1091+
replaceable.Replace(NewRandomSource())
1092+
wg.Add(1)
1093+
go repeatedGet(&replaceable, end, &wg)
1094+
wg.Add(1)
1095+
go func(r *cached.Replaceable[int64], end time.Time, wg *sync.WaitGroup) {
1096+
for time.Now().Before(end) {
1097+
r.Replace(NewRandomSource())
1098+
}
1099+
wg.Done()
1100+
}(&replaceable, end, &wg)
1101+
merger := cached.NewMerger(func(results map[string]cached.Result[int64]) cached.Result[int64] {
1102+
sum := int64(0)
1103+
for _, result := range results {
1104+
sum += result.Data
1105+
}
1106+
return cached.NewResultOK(sum, randomString(10))
1107+
}, map[string]cached.Data[int64]{
1108+
"one": NewRandomSource(),
1109+
"two": NewRandomSource(),
1110+
})
1111+
wg.Add(1)
1112+
go repeatedGet(merger, end, &wg)
1113+
transformer := cached.NewTransformer(func(result cached.Result[int64]) cached.Result[int64] {
1114+
return cached.NewResultOK(result.Data+5, randomString(10))
1115+
}, NewRandomSource())
1116+
wg.Add(1)
1117+
go repeatedGet(transformer, end, &wg)
1118+
1119+
listmerger := cached.NewListMerger(func(results []cached.Result[int64]) cached.Result[int64] {
1120+
sum := int64(0)
1121+
for i := range results {
1122+
sum += results[i].Data
1123+
}
1124+
return cached.NewResultOK(sum, randomString(10))
1125+
}, []cached.Data[int64]{static, result, &replaceable, merger, transformer})
1126+
wg.Add(1)
1127+
go repeatedGet(listmerger, end, &wg)
1128+
1129+
wg.Wait()
1130+
}

0 commit comments

Comments
 (0)