Skip to content

Commit a340159

Browse files
committed
feat: Add waiter for object
Implements a wait for a check to pass against a typed object.
1 parent e106072 commit a340159

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

pkg/wait/wait.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package wait
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
"k8s.io/apimachinery/pkg/util/wait"
9+
"sigs.k8s.io/controller-runtime/pkg/client"
10+
)
11+
12+
// CheckFailedError is used to determine whether the wait failed because wraps an error returned by a failed check.
13+
type CheckFailedError struct {
14+
cause error
15+
}
16+
17+
func (e *CheckFailedError) Error() string {
18+
return fmt.Sprintf("check failed: %s", e.cause)
19+
}
20+
21+
func (e *CheckFailedError) Is(target error) bool {
22+
_, ok := target.(*CheckFailedError)
23+
return ok
24+
}
25+
26+
func (e *CheckFailedError) Unwrap() error {
27+
return e.cause
28+
}
29+
30+
type ForObjectInput[T client.Object] struct {
31+
Reader client.Reader
32+
Target T
33+
Check func(ctx context.Context, obj T) (bool, error)
34+
Interval time.Duration
35+
Timeout time.Duration
36+
}
37+
38+
func ForObject[T client.Object](
39+
ctx context.Context,
40+
input ForObjectInput[T],
41+
) error {
42+
key := client.ObjectKeyFromObject(input.Target)
43+
44+
var getErr error
45+
waitErr := wait.PollUntilContextTimeout(
46+
ctx,
47+
input.Interval,
48+
input.Timeout,
49+
true,
50+
func(checkCtx context.Context) (bool, error) {
51+
if getErr = input.Reader.Get(checkCtx, key, input.Target); getErr != nil {
52+
// Retry if get fails.
53+
return false, nil
54+
}
55+
56+
if ok, err := input.Check(checkCtx, input.Target); err != nil {
57+
return false, &CheckFailedError{cause: err}
58+
} else {
59+
// Retry if check fails.
60+
return ok, nil
61+
}
62+
})
63+
64+
if wait.Interrupted(waitErr) {
65+
if getErr != nil {
66+
return fmt.Errorf("%w; last get error: %w", waitErr, getErr)
67+
}
68+
return fmt.Errorf("%w: check never passed", waitErr)
69+
}
70+
// waitErr is a CheckFailedError
71+
return waitErr
72+
}

pkg/wait/wait_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package wait
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"testing"
8+
"time"
9+
10+
corev1 "k8s.io/api/core/v1"
11+
apierrors "k8s.io/apimachinery/pkg/api/errors"
12+
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
13+
"k8s.io/apimachinery/pkg/util/wait"
14+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
15+
)
16+
17+
func TestWait(t *testing.T) {
18+
// We use the corev1.Namespace concrete type for the test, because we want to
19+
// verify behavior for a concrete type, and because the Wait function is
20+
// generic, and will behave identically for all concrete types.
21+
type args struct {
22+
input ForObjectInput[*corev1.Namespace]
23+
}
24+
tests := []struct {
25+
name string
26+
args args
27+
errCheck func(error) bool
28+
}{
29+
{
30+
name: "time out while get fails; report get error",
31+
args: args{
32+
input: ForObjectInput[*corev1.Namespace]{
33+
Reader: fake.NewFakeClient(),
34+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
35+
return true, nil
36+
},
37+
Interval: time.Nanosecond,
38+
Timeout: time.Millisecond,
39+
Target: &corev1.Namespace{
40+
TypeMeta: v1.TypeMeta{
41+
Kind: "Namespace",
42+
APIVersion: "v1",
43+
},
44+
ObjectMeta: v1.ObjectMeta{
45+
Name: "example",
46+
},
47+
},
48+
},
49+
},
50+
errCheck: func(err error) bool {
51+
return wait.Interrupted(err) &&
52+
apierrors.IsNotFound(err)
53+
},
54+
},
55+
{
56+
name: "time out while check returns false; no check error to report",
57+
args: args{
58+
input: ForObjectInput[*corev1.Namespace]{
59+
Reader: fake.NewFakeClient(
60+
&corev1.Namespace{
61+
TypeMeta: v1.TypeMeta{
62+
Kind: "Namespace",
63+
APIVersion: "v1",
64+
},
65+
ObjectMeta: v1.ObjectMeta{
66+
Name: "example",
67+
},
68+
},
69+
),
70+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
71+
return false, nil
72+
},
73+
Interval: time.Nanosecond,
74+
Timeout: time.Millisecond,
75+
Target: &corev1.Namespace{
76+
TypeMeta: v1.TypeMeta{
77+
Kind: "Namespace",
78+
APIVersion: "v1",
79+
},
80+
ObjectMeta: v1.ObjectMeta{
81+
Name: "example",
82+
},
83+
},
84+
},
85+
},
86+
errCheck: wait.Interrupted,
87+
},
88+
{
89+
name: "return immediately when check returns an error; report the error",
90+
args: args{
91+
input: ForObjectInput[*corev1.Namespace]{
92+
Reader: fake.NewFakeClient(
93+
&corev1.Namespace{
94+
TypeMeta: v1.TypeMeta{
95+
Kind: "Namespace",
96+
APIVersion: "v1",
97+
},
98+
ObjectMeta: v1.ObjectMeta{
99+
Name: "example",
100+
},
101+
},
102+
),
103+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
104+
return false, fmt.Errorf("condition failed")
105+
},
106+
Interval: time.Nanosecond,
107+
Timeout: time.Millisecond, Target: &corev1.Namespace{
108+
TypeMeta: v1.TypeMeta{
109+
Kind: "Namespace",
110+
APIVersion: "v1",
111+
},
112+
ObjectMeta: v1.ObjectMeta{
113+
Name: "example",
114+
},
115+
},
116+
},
117+
},
118+
errCheck: func(err error) bool {
119+
return errors.Is(err, &CheckFailedError{}) &&
120+
!wait.Interrupted(err)
121+
},
122+
},
123+
}
124+
for _, tt := range tests {
125+
t.Run(tt.name, func(t *testing.T) {
126+
err := ForObject(
127+
context.Background(),
128+
tt.args.input,
129+
)
130+
if !tt.errCheck(err) {
131+
t.Errorf("error did not pass check: %s", err)
132+
}
133+
})
134+
}
135+
}

0 commit comments

Comments
 (0)