Skip to content

Commit 7562a10

Browse files
authored
Merge pull request #403 from apelisse/thread-safe-cached
Thread safe cached
2 parents 9b4dcd3 + 62c762f commit 7562a10

File tree

2 files changed

+126
-48
lines changed

2 files changed

+126
-48
lines changed

Diff for: pkg/cached/cache.go

+50-46
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License.
1919
// operations are not repeated unnecessarily. The operations can be
2020
// created as a tree, and replaced dynamically as needed.
2121
//
22+
// All the operations in this module are thread-safe.
23+
//
2224
// # Dependencies and types of caches
2325
//
2426
// This package uses a source/transform/sink model of caches to build
@@ -34,15 +36,6 @@ limitations under the License.
3436
// replaced with a new one, and saves the previous results in case an
3537
// error pops-up.
3638
//
37-
// # Atomicity
38-
//
39-
// Most of the operations are not atomic/thread-safe, except for
40-
// [Replaceable.Replace] which can be performed while the objects are
41-
// being read. Specifically, `Get` methods are NOT thread-safe. Never
42-
// call `Get()` without a lock on a multi-threaded environment, since
43-
// it's usually performing updates to caches that will require write
44-
// operations.
45-
//
4639
// # Etags
4740
//
4841
// Etags in this library is a cache version identifier. It doesn't
@@ -57,6 +50,7 @@ package cached
5750

5851
import (
5952
"fmt"
53+
"sync"
6054
"sync/atomic"
6155
)
6256

@@ -100,13 +94,6 @@ func (r Result[T]) Get() Result[T] {
10094
type Data[T any] interface {
10195
// Returns the cached data, as well as an "etag" to identify the
10296
// version of the cache, or an error if something happened.
103-
//
104-
// # Important note
105-
//
106-
// This method is NEVER thread-safe, never assume it is OK to
107-
// call `Get()` without holding a proper mutex in a
108-
// multi-threaded environment, especially since `Get()` will
109-
// usually update the cache and perform write operations.
11097
Get() Result[T]
11198
}
11299

@@ -155,6 +142,7 @@ func NewMerger[K comparable, T, V any](mergeFn func(results map[K]Result[T]) Res
155142
}
156143

157144
type listMerger[T, V any] struct {
145+
lock sync.Mutex
158146
mergeFn func([]Result[T]) Result[V]
159147
caches []Data[T]
160148
cacheResults []Result[T]
@@ -183,15 +171,32 @@ func NewListMerger[T, V any](mergeFn func(results []Result[T]) Result[V], caches
183171
caches: caches,
184172
}
185173
}
186-
func (c *listMerger[T, V]) prepareResults() []Result[T] {
187-
cacheResults := make([]Result[T], 0, len(c.caches))
188-
for _, cache := range c.caches {
189-
cacheResults = append(cacheResults, cache.Get())
174+
175+
func (c *listMerger[T, V]) prepareResultsLocked() []Result[T] {
176+
cacheResults := make([]Result[T], len(c.caches))
177+
ch := make(chan struct {
178+
int
179+
Result[T]
180+
}, len(c.caches))
181+
for i := range c.caches {
182+
go func(index int) {
183+
ch <- struct {
184+
int
185+
Result[T]
186+
}{
187+
index,
188+
c.caches[index].Get(),
189+
}
190+
}(i)
191+
}
192+
for i := 0; i < len(c.caches); i++ {
193+
res := <-ch
194+
cacheResults[res.int] = res.Result
190195
}
191196
return cacheResults
192197
}
193198

194-
func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
199+
func (c *listMerger[T, V]) needsRunningLocked(results []Result[T]) bool {
195200
if c.cacheResults == nil {
196201
return true
197202
}
@@ -211,8 +216,10 @@ func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
211216
}
212217

213218
func (c *listMerger[T, V]) Get() Result[V] {
214-
cacheResults := c.prepareResults()
215-
if c.needsRunning(cacheResults) {
219+
c.lock.Lock()
220+
defer c.lock.Unlock()
221+
cacheResults := c.prepareResultsLocked()
222+
if c.needsRunningLocked(cacheResults) {
216223
c.cacheResults = cacheResults
217224
c.result = c.mergeFn(c.cacheResults)
218225
}
@@ -238,7 +245,7 @@ func NewTransformer[T, V any](transformerFn func(Result[T]) Result[V], source Da
238245

239246
// NewSource creates a new cache that generates some data. This
240247
// will always be called since we don't know the origin of the data and
241-
// if it needs to be updated or not.
248+
// if it needs to be updated or not. sourceFn MUST be thread-safe.
242249
func NewSource[T any](sourceFn func() Result[T]) Data[T] {
243250
c := source[T](sourceFn)
244251
return &c
@@ -259,25 +266,24 @@ func NewStaticSource[T any](staticFn func() Result[T]) Data[T] {
259266
}
260267

261268
type static[T any] struct {
269+
once sync.Once
262270
fn func() Result[T]
263-
result *Result[T]
271+
result Result[T]
264272
}
265273

266274
func (c *static[T]) Get() Result[T] {
267-
if c.result == nil {
268-
result := c.fn()
269-
c.result = &result
270-
}
271-
return *c.result
275+
c.once.Do(func() {
276+
c.result = c.fn()
277+
})
278+
return c.result
272279
}
273280

274-
// Replaceable is a cache that carries the result even when the
275-
// cache is replaced. The cache can be replaced atomically (without any
276-
// lock held). This is the type that should typically be stored in
281+
// Replaceable is a cache that carries the result even when the cache is
282+
// replaced. This is the type that should typically be stored in
277283
// structs.
278284
type Replaceable[T any] struct {
279285
cache atomic.Pointer[Data[T]]
280-
result *Result[T]
286+
result atomic.Pointer[Result[T]]
281287
}
282288

283289
// Get retrieves the data from the underlying source. [Replaceable]
@@ -286,23 +292,21 @@ type Replaceable[T any] struct {
286292
// previously had returned a success, that success will be returned
287293
// instead. If the cache fails but we never returned a success, that
288294
// failure is returned.
289-
//
290-
// # Important note
291-
//
292-
// As all implementations of Get, this implementation is NOT
293-
// thread-safe. Please properly lock a mutex before calling this method
294-
// if you are in a multi-threaded environment, since this method will
295-
// update the cache and perform write operations.
296295
func (c *Replaceable[T]) Get() Result[T] {
297296
result := (*c.cache.Load()).Get()
298-
if result.Err != nil && c.result != nil && c.result.Err == nil {
299-
return *c.result
297+
298+
for {
299+
cResult := c.result.Load()
300+
if result.Err != nil && cResult != nil && cResult.Err == nil {
301+
return *cResult
302+
}
303+
if c.result.CompareAndSwap(cResult, &result) {
304+
return result
305+
}
300306
}
301-
c.result = &result
302-
return *c.result
303307
}
304308

305-
// Replace changes the cache in a thread-safe way.
309+
// Replace changes the cache.
306310
func (c *Replaceable[T]) Replace(cache Data[T]) {
307311
c.cache.Swap(&cache)
308312
}

Diff for: pkg/cached/cache_test.go

+76-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ import (
2121
"encoding/json"
2222
"errors"
2323
"fmt"
24+
"math/rand"
2425
"sort"
2526
"strings"
27+
"sync"
2628
"testing"
29+
"time"
2730

2831
"k8s.io/kube-openapi/pkg/cached"
2932
)
@@ -1012,9 +1015,7 @@ func TestListMergerAlternateSourceError(t *testing.T) {
10121015
}
10131016

10141017
func TestListDAG(t *testing.T) {
1015-
count := 0
10161018
source := cached.NewSource(func() cached.Result[[]byte] {
1017-
count += 1
10181019
return cached.NewResultOK([]byte("source"), "source")
10191020
})
10201021
transformer1 := cached.NewTransformer(func(result cached.Result[[]byte]) cached.Result[[]byte] {
@@ -1054,3 +1055,76 @@ func TestListDAG(t *testing.T) {
10541055
t.Fatalf("expected etag = %v, got %v", want, result.Etag)
10551056
}
10561057
}
1058+
1059+
func randomString(length uint) string {
1060+
bytes := make([]byte, 6)
1061+
rand.Read(bytes)
1062+
return string(bytes)
1063+
1064+
}
1065+
1066+
func NewRandomSource() cached.Data[int64] {
1067+
return cached.NewStaticSource(func() cached.Result[int64] {
1068+
bytes := make([]byte, 6)
1069+
rand.Read(bytes)
1070+
return cached.NewResultOK(rand.Int63(), randomString(10))
1071+
})
1072+
}
1073+
1074+
func repeatedGet(data cached.Data[int64], end time.Time, wg *sync.WaitGroup) {
1075+
for time.Now().Before(end) {
1076+
_ = data.Get()
1077+
}
1078+
wg.Done()
1079+
}
1080+
1081+
func TestThreadSafe(t *testing.T) {
1082+
end := time.Now().Add(time.Second)
1083+
wg := sync.WaitGroup{}
1084+
static := NewRandomSource()
1085+
wg.Add(1)
1086+
go repeatedGet(static, end, &wg)
1087+
result := cached.NewResultOK(rand.Int63(), randomString(10))
1088+
wg.Add(1)
1089+
go repeatedGet(result, end, &wg)
1090+
replaceable := cached.Replaceable[int64]{}
1091+
replaceable.Replace(NewRandomSource())
1092+
wg.Add(1)
1093+
go repeatedGet(&replaceable, end, &wg)
1094+
wg.Add(1)
1095+
go func(r *cached.Replaceable[int64], end time.Time, wg *sync.WaitGroup) {
1096+
for time.Now().Before(end) {
1097+
r.Replace(NewRandomSource())
1098+
}
1099+
wg.Done()
1100+
}(&replaceable, end, &wg)
1101+
merger := cached.NewMerger(func(results map[string]cached.Result[int64]) cached.Result[int64] {
1102+
sum := int64(0)
1103+
for _, result := range results {
1104+
sum += result.Data
1105+
}
1106+
return cached.NewResultOK(sum, randomString(10))
1107+
}, map[string]cached.Data[int64]{
1108+
"one": NewRandomSource(),
1109+
"two": NewRandomSource(),
1110+
})
1111+
wg.Add(1)
1112+
go repeatedGet(merger, end, &wg)
1113+
transformer := cached.NewTransformer(func(result cached.Result[int64]) cached.Result[int64] {
1114+
return cached.NewResultOK(result.Data+5, randomString(10))
1115+
}, NewRandomSource())
1116+
wg.Add(1)
1117+
go repeatedGet(transformer, end, &wg)
1118+
1119+
listmerger := cached.NewListMerger(func(results []cached.Result[int64]) cached.Result[int64] {
1120+
sum := int64(0)
1121+
for i := range results {
1122+
sum += results[i].Data
1123+
}
1124+
return cached.NewResultOK(sum, randomString(10))
1125+
}, []cached.Data[int64]{static, result, &replaceable, merger, transformer})
1126+
wg.Add(1)
1127+
go repeatedGet(listmerger, end, &wg)
1128+
1129+
wg.Wait()
1130+
}

0 commit comments

Comments
 (0)