Skip to content

Commit 9e6b4d7

Browse files
dandersonMaisem Alibradfitz
committed
types/lazy: helpers for lazily computed values
Co-authored-by: Maisem Ali <[email protected]> Co-authored-by: Brad Fitzpatrick <[email protected]> Signed-off-by: David Anderson <[email protected]>
1 parent 5bca44d commit 9e6b4d7

File tree

4 files changed

+477
-0
lines changed

4 files changed

+477
-0
lines changed

types/lazy/lazy.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
// Package lazy provides types for lazily initialized values.
5+
package lazy
6+
7+
import "sync"
8+
9+
// SyncValue is a lazily computed value.
10+
//
11+
// Use either Get or GetErr, depending on whether your fill function returns an
12+
// error.
13+
//
14+
// Recursive use of a SyncValue from its own fill function will deadlock.
15+
//
16+
// SyncValue is safe for concurrent use.
17+
type SyncValue[T any] struct {
18+
once sync.Once
19+
v T
20+
err error
21+
}
22+
23+
// Set attempts to set z's value to val, and reports whether it succeeded.
24+
// Set only succeeds if none of Get/GetErr/Set have been called before.
25+
func (z *SyncValue[T]) Set(val T) bool {
26+
var wasSet bool
27+
z.once.Do(func() {
28+
z.v = val
29+
wasSet = true
30+
})
31+
return wasSet
32+
}
33+
34+
// MustSet sets z's value to val, or panics if z already has a value.
35+
func (z *SyncValue[T]) MustSet(val T) {
36+
if !z.Set(val) {
37+
panic("Set after already filled")
38+
}
39+
}
40+
41+
// Get returns z's value, calling fill to compute it if necessary.
42+
// f is called at most once.
43+
func (z *SyncValue[T]) Get(fill func() T) T {
44+
z.once.Do(func() { z.v = fill() })
45+
return z.v
46+
}
47+
48+
// GetErr returns z's value, calling fill to compute it if necessary.
49+
// f is called at most once, and z remembers both of fill's outputs.
50+
func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) {
51+
z.once.Do(func() { z.v, z.err = fill() })
52+
return z.v, z.err
53+
}
54+
55+
// SyncFunc wraps a function to make it lazy.
56+
//
57+
// The returned function calls fill the first time it's called, and returns
58+
// fill's result on every subsequent call.
59+
//
60+
// The returned function is safe for concurrent use.
61+
func SyncFunc[T any](fill func() T) func() T {
62+
var (
63+
once sync.Once
64+
v T
65+
)
66+
return func() T {
67+
once.Do(func() { v = fill() })
68+
return v
69+
}
70+
}
71+
72+
// SyncFuncErr wraps a function to make it lazy.
73+
//
74+
// The returned function calls fill the first time it's called, and returns
75+
// fill's results on every subsequent call.
76+
//
77+
// The returned function is safe for concurrent use.
78+
func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) {
79+
var (
80+
once sync.Once
81+
v T
82+
err error
83+
)
84+
return func() (T, error) {
85+
once.Do(func() { v, err = fill() })
86+
return v, err
87+
}
88+
}

types/lazy/sync_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package lazy
5+
6+
import (
7+
"errors"
8+
"sync"
9+
"testing"
10+
)
11+
12+
func TestSyncValue(t *testing.T) {
13+
var lt SyncValue[int]
14+
n := int(testing.AllocsPerRun(1000, func() {
15+
got := lt.Get(fortyTwo)
16+
if got != 42 {
17+
t.Fatalf("got %v; want 42", got)
18+
}
19+
}))
20+
if n != 0 {
21+
t.Errorf("allocs = %v; want 0", n)
22+
}
23+
}
24+
25+
func TestSyncValueErr(t *testing.T) {
26+
var lt SyncValue[int]
27+
n := int(testing.AllocsPerRun(1000, func() {
28+
got, err := lt.GetErr(func() (int, error) {
29+
return 42, nil
30+
})
31+
if got != 42 || err != nil {
32+
t.Fatalf("got %v, %v; want 42, nil", got, err)
33+
}
34+
}))
35+
if n != 0 {
36+
t.Errorf("allocs = %v; want 0", n)
37+
}
38+
39+
var lterr SyncValue[int]
40+
wantErr := errors.New("test error")
41+
n = int(testing.AllocsPerRun(1000, func() {
42+
got, err := lterr.GetErr(func() (int, error) {
43+
return 0, wantErr
44+
})
45+
if got != 0 || err != wantErr {
46+
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
47+
}
48+
}))
49+
if n != 0 {
50+
t.Errorf("allocs = %v; want 0", n)
51+
}
52+
}
53+
54+
func TestSyncValueSet(t *testing.T) {
55+
var lt SyncValue[int]
56+
if !lt.Set(42) {
57+
t.Fatalf("Set failed")
58+
}
59+
if lt.Set(43) {
60+
t.Fatalf("Set succeeded after first Set")
61+
}
62+
n := int(testing.AllocsPerRun(1000, func() {
63+
got := lt.Get(fortyTwo)
64+
if got != 42 {
65+
t.Fatalf("got %v; want 42", got)
66+
}
67+
}))
68+
if n != 0 {
69+
t.Errorf("allocs = %v; want 0", n)
70+
}
71+
}
72+
73+
func TestSyncValueMustSet(t *testing.T) {
74+
var lt SyncValue[int]
75+
lt.MustSet(42)
76+
defer func() {
77+
if e := recover(); e == nil {
78+
t.Errorf("unexpected success; want panic")
79+
}
80+
}()
81+
lt.MustSet(43)
82+
}
83+
84+
func TestSyncValueConcurrent(t *testing.T) {
85+
var (
86+
lt SyncValue[int]
87+
wg sync.WaitGroup
88+
start = make(chan struct{})
89+
routines = 10000
90+
)
91+
wg.Add(routines)
92+
for i := 0; i < routines; i++ {
93+
go func() {
94+
defer wg.Done()
95+
// Every goroutine waits for the go signal, so that more of them
96+
// have a chance to race on the initial Get than with sequential
97+
// goroutine starts.
98+
<-start
99+
got := lt.Get(fortyTwo)
100+
if got != 42 {
101+
t.Errorf("got %v; want 42", got)
102+
}
103+
}()
104+
}
105+
close(start)
106+
wg.Wait()
107+
}
108+
109+
func TestSyncFunc(t *testing.T) {
110+
f := SyncFunc(fortyTwo)
111+
112+
n := int(testing.AllocsPerRun(1000, func() {
113+
got := f()
114+
if got != 42 {
115+
t.Fatalf("got %v; want 42", got)
116+
}
117+
}))
118+
if n != 0 {
119+
t.Errorf("allocs = %v; want 0", n)
120+
}
121+
}
122+
123+
func TestSyncFuncErr(t *testing.T) {
124+
f := SyncFuncErr(func() (int, error) {
125+
return 42, nil
126+
})
127+
n := int(testing.AllocsPerRun(1000, func() {
128+
got, err := f()
129+
if got != 42 || err != nil {
130+
t.Fatalf("got %v, %v; want 42, nil", got, err)
131+
}
132+
}))
133+
if n != 0 {
134+
t.Errorf("allocs = %v; want 0", n)
135+
}
136+
137+
wantErr := errors.New("test error")
138+
f = SyncFuncErr(func() (int, error) {
139+
return 0, wantErr
140+
})
141+
n = int(testing.AllocsPerRun(1000, func() {
142+
got, err := f()
143+
if got != 0 || err != wantErr {
144+
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
145+
}
146+
}))
147+
if n != 0 {
148+
t.Errorf("allocs = %v; want 0", n)
149+
}
150+
}

types/lazy/unsync.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package lazy
5+
6+
// GValue is a lazily computed value.
7+
//
8+
// Use either Get or GetErr, depending on whether your fill function returns an
9+
// error.
10+
//
11+
// Recursive use of a GValue from its own fill function will panic.
12+
//
13+
// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine,
14+
// which isn't strictly true if you provide your own synchronization between
15+
// goroutines, but in practice most of our callers have been using it within
16+
// a single goroutine.)
17+
type GValue[T any] struct {
18+
done bool
19+
calling bool
20+
V T
21+
err error
22+
}
23+
24+
// Set attempts to set z's value to val, and reports whether it succeeded.
25+
// Set only succeeds if none of Get/GetErr/Set have been called before.
26+
func (z *GValue[T]) Set(v T) bool {
27+
if z.done {
28+
return false
29+
}
30+
if z.calling {
31+
panic("Set while Get fill is running")
32+
}
33+
z.V = v
34+
z.done = true
35+
return true
36+
}
37+
38+
// MustSet sets z's value to val, or panics if z already has a value.
39+
func (z *GValue[T]) MustSet(val T) {
40+
if !z.Set(val) {
41+
panic("Set after already filled")
42+
}
43+
}
44+
45+
// Get returns z's value, calling fill to compute it if necessary.
46+
// f is called at most once.
47+
func (z *GValue[T]) Get(fill func() T) T {
48+
if !z.done {
49+
if z.calling {
50+
panic("recursive lazy fill")
51+
}
52+
z.calling = true
53+
z.V = fill()
54+
z.done = true
55+
z.calling = false
56+
}
57+
return z.V
58+
}
59+
60+
// GetErr returns z's value, calling fill to compute it if necessary.
61+
// f is called at most once, and z remembers both of fill's outputs.
62+
func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) {
63+
if !z.done {
64+
if z.calling {
65+
panic("recursive lazy fill")
66+
}
67+
z.calling = true
68+
z.V, z.err = fill()
69+
z.done = true
70+
z.calling = false
71+
}
72+
return z.V, z.err
73+
}
74+
75+
// GFunc wraps a function to make it lazy.
76+
//
77+
// The returned function calls fill the first time it's called, and returns
78+
// fill's result on every subsequent call.
79+
//
80+
// The returned function is not safe for concurrent use.
81+
func GFunc[T any](fill func() T) func() T {
82+
var v GValue[T]
83+
return func() T {
84+
return v.Get(fill)
85+
}
86+
}
87+
88+
// SyncFuncErr wraps a function to make it lazy.
89+
//
90+
// The returned function calls fill the first time it's called, and returns
91+
// fill's results on every subsequent call.
92+
//
93+
// The returned function is not safe for concurrent use.
94+
func GFuncErr[T any](fill func() (T, error)) func() (T, error) {
95+
var v GValue[T]
96+
return func() (T, error) {
97+
return v.GetErr(fill)
98+
}
99+
}

0 commit comments

Comments
 (0)