Skip to content

Commit 3c78f08

Browse files
committed
feat: Add waiter for object
Implements a wait for a check to pass against a typed object.
1 parent e106072 commit 3c78f08

File tree

2 files changed

+211
-0
lines changed

2 files changed

+211
-0
lines changed

pkg/wait/wait.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2024 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package wait
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"time"
9+
10+
"k8s.io/apimachinery/pkg/util/wait"
11+
"sigs.k8s.io/controller-runtime/pkg/client"
12+
)
13+
14+
// CheckFailedError is used to determine whether the wait failed because wraps an error returned by a failed check.
15+
type CheckFailedError struct {
16+
cause error
17+
}
18+
19+
func (e *CheckFailedError) Error() string {
20+
return fmt.Sprintf("check failed: %s", e.cause)
21+
}
22+
23+
func (e *CheckFailedError) Is(target error) bool {
24+
_, ok := target.(*CheckFailedError)
25+
return ok
26+
}
27+
28+
func (e *CheckFailedError) Unwrap() error {
29+
return e.cause
30+
}
31+
32+
type ForObjectInput[T client.Object] struct {
33+
Reader client.Reader
34+
Target T
35+
Check func(ctx context.Context, obj T) (bool, error)
36+
Interval time.Duration
37+
Timeout time.Duration
38+
}
39+
40+
func ForObject[T client.Object](
41+
ctx context.Context,
42+
input ForObjectInput[T],
43+
) error {
44+
key := client.ObjectKeyFromObject(input.Target)
45+
46+
var getErr error
47+
waitErr := wait.PollUntilContextTimeout(
48+
ctx,
49+
input.Interval,
50+
input.Timeout,
51+
true,
52+
func(checkCtx context.Context) (bool, error) {
53+
if getErr = input.Reader.Get(checkCtx, key, input.Target); getErr != nil {
54+
// Retry if get fails.
55+
return false, nil
56+
}
57+
58+
if ok, err := input.Check(checkCtx, input.Target); err != nil {
59+
return false, &CheckFailedError{cause: err}
60+
} else {
61+
// Retry if check fails.
62+
return ok, nil
63+
}
64+
})
65+
66+
if wait.Interrupted(waitErr) {
67+
if getErr != nil {
68+
return fmt.Errorf("%w; last get error: %w", waitErr, getErr)
69+
}
70+
return fmt.Errorf("%w: check never passed", waitErr)
71+
}
72+
// waitErr is a CheckFailedError
73+
return waitErr
74+
}

pkg/wait/wait_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright 2024 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package wait
4+
5+
import (
6+
"context"
7+
"errors"
8+
"fmt"
9+
"testing"
10+
"time"
11+
12+
corev1 "k8s.io/api/core/v1"
13+
apierrors "k8s.io/apimachinery/pkg/api/errors"
14+
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
15+
"k8s.io/apimachinery/pkg/util/wait"
16+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
17+
)
18+
19+
func TestWait(t *testing.T) {
20+
// We use the corev1.Namespace concrete type for the test, because we want to
21+
// verify behavior for a concrete type, and because the Wait function is
22+
// generic, and will behave identically for all concrete types.
23+
type args struct {
24+
input ForObjectInput[*corev1.Namespace]
25+
}
26+
tests := []struct {
27+
name string
28+
args args
29+
errCheck func(error) bool
30+
}{
31+
{
32+
name: "time out while get fails; report get error",
33+
args: args{
34+
input: ForObjectInput[*corev1.Namespace]{
35+
Reader: fake.NewFakeClient(),
36+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
37+
return true, nil
38+
},
39+
Interval: time.Nanosecond,
40+
Timeout: time.Millisecond,
41+
Target: &corev1.Namespace{
42+
TypeMeta: v1.TypeMeta{
43+
Kind: "Namespace",
44+
APIVersion: "v1",
45+
},
46+
ObjectMeta: v1.ObjectMeta{
47+
Name: "example",
48+
},
49+
},
50+
},
51+
},
52+
errCheck: func(err error) bool {
53+
return wait.Interrupted(err) &&
54+
apierrors.IsNotFound(err)
55+
},
56+
},
57+
{
58+
name: "time out while check returns false; no check error to report",
59+
args: args{
60+
input: ForObjectInput[*corev1.Namespace]{
61+
Reader: fake.NewFakeClient(
62+
&corev1.Namespace{
63+
TypeMeta: v1.TypeMeta{
64+
Kind: "Namespace",
65+
APIVersion: "v1",
66+
},
67+
ObjectMeta: v1.ObjectMeta{
68+
Name: "example",
69+
},
70+
},
71+
),
72+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
73+
return false, nil
74+
},
75+
Interval: time.Nanosecond,
76+
Timeout: time.Millisecond,
77+
Target: &corev1.Namespace{
78+
TypeMeta: v1.TypeMeta{
79+
Kind: "Namespace",
80+
APIVersion: "v1",
81+
},
82+
ObjectMeta: v1.ObjectMeta{
83+
Name: "example",
84+
},
85+
},
86+
},
87+
},
88+
errCheck: wait.Interrupted,
89+
},
90+
{
91+
name: "return immediately when check returns an error; report the error",
92+
args: args{
93+
input: ForObjectInput[*corev1.Namespace]{
94+
Reader: fake.NewFakeClient(
95+
&corev1.Namespace{
96+
TypeMeta: v1.TypeMeta{
97+
Kind: "Namespace",
98+
APIVersion: "v1",
99+
},
100+
ObjectMeta: v1.ObjectMeta{
101+
Name: "example",
102+
},
103+
},
104+
),
105+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
106+
return false, fmt.Errorf("condition failed")
107+
},
108+
Interval: time.Nanosecond,
109+
Timeout: time.Millisecond, Target: &corev1.Namespace{
110+
TypeMeta: v1.TypeMeta{
111+
Kind: "Namespace",
112+
APIVersion: "v1",
113+
},
114+
ObjectMeta: v1.ObjectMeta{
115+
Name: "example",
116+
},
117+
},
118+
},
119+
},
120+
errCheck: func(err error) bool {
121+
return errors.Is(err, &CheckFailedError{}) &&
122+
!wait.Interrupted(err)
123+
},
124+
},
125+
}
126+
for _, tt := range tests {
127+
t.Run(tt.name, func(t *testing.T) {
128+
err := ForObject(
129+
context.Background(),
130+
tt.args.input,
131+
)
132+
if !tt.errCheck(err) {
133+
t.Errorf("error did not pass check: %s", err)
134+
}
135+
})
136+
}
137+
}

0 commit comments

Comments
 (0)