Skip to content

Commit b499985

Browse files
authored
feat: Add waiter for object (#777)
**What problem does this PR solve?**: Implements a wait for a check to pass against a typed object. We'll use this in some lifecycle handers, e.g., in a future change for deploying ServiceLoadBalancer configuration to the remote cluster. (This is a copy of #762. I had to close that after #755 added required checks that can't be run from PRs from public forks. ) **Which issue(s) this PR fixes**: Fixes # **How Has This Been Tested?**: <!-- Please describe the tests that you ran to verify your changes. Provide output from the tests and any manual steps needed to replicate the tests. --> **Special notes for your reviewer**: <!-- Use this to provide any additional information to the reviewers. This may include: - Best way to review the PR. - Where the author wants the most review attention on. - etc. -->
1 parent a1470d4 commit b499985

File tree

2 files changed

+254
-0
lines changed

2 files changed

+254
-0
lines changed

pkg/wait/wait.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2024 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package wait
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"time"
9+
10+
apierrors "k8s.io/apimachinery/pkg/api/errors"
11+
"k8s.io/apimachinery/pkg/util/wait"
12+
"sigs.k8s.io/controller-runtime/pkg/client"
13+
)
14+
15+
// CheckFailedError is used to determine whether the wait failed because wraps an error returned by a failed check.
16+
type CheckFailedError struct {
17+
cause error
18+
}
19+
20+
func (e *CheckFailedError) Error() string {
21+
return fmt.Sprintf("check failed: %s", e.cause)
22+
}
23+
24+
func (e *CheckFailedError) Is(target error) bool {
25+
_, ok := target.(*CheckFailedError)
26+
return ok
27+
}
28+
29+
func (e *CheckFailedError) Unwrap() error {
30+
return e.cause
31+
}
32+
33+
type ForObjectInput[T client.Object] struct {
34+
Reader client.Reader
35+
Target T
36+
Check func(ctx context.Context, obj T) (bool, error)
37+
Interval time.Duration
38+
Timeout time.Duration
39+
}
40+
41+
func ForObject[T client.Object](
42+
ctx context.Context,
43+
input ForObjectInput[T],
44+
) error {
45+
key := client.ObjectKeyFromObject(input.Target)
46+
47+
var getErr error
48+
waitErr := wait.PollUntilContextTimeout(
49+
ctx,
50+
input.Interval,
51+
input.Timeout,
52+
true,
53+
func(checkCtx context.Context) (bool, error) {
54+
if getErr = input.Reader.Get(checkCtx, key, input.Target); getErr != nil {
55+
if apierrors.IsNotFound(getErr) {
56+
return false, nil
57+
}
58+
return false, getErr
59+
}
60+
61+
if ok, err := input.Check(checkCtx, input.Target); err != nil {
62+
return false, &CheckFailedError{cause: err}
63+
} else {
64+
// Retry if check fails.
65+
return ok, nil
66+
}
67+
})
68+
69+
if wait.Interrupted(waitErr) {
70+
if getErr != nil {
71+
return fmt.Errorf("%w; last get error: %w", waitErr, getErr)
72+
}
73+
return fmt.Errorf("%w: check never passed", waitErr)
74+
}
75+
// waitErr is a CheckFailedError
76+
return waitErr
77+
}

pkg/wait/wait_test.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Copyright 2024 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package wait
4+
5+
import (
6+
"context"
7+
"errors"
8+
"fmt"
9+
"testing"
10+
"time"
11+
12+
corev1 "k8s.io/api/core/v1"
13+
apierrors "k8s.io/apimachinery/pkg/api/errors"
14+
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
15+
"k8s.io/apimachinery/pkg/util/wait"
16+
"sigs.k8s.io/controller-runtime/pkg/client"
17+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
18+
)
19+
20+
var errBrokenReader = errors.New("broken")
21+
22+
type brokenReader struct{}
23+
24+
func (r *brokenReader) Get(
25+
ctx context.Context,
26+
key client.ObjectKey,
27+
obj client.Object,
28+
opts ...client.GetOption,
29+
) error {
30+
return errBrokenReader
31+
}
32+
33+
func (r *brokenReader) List(
34+
ctx context.Context,
35+
list client.ObjectList,
36+
opts ...client.ListOption,
37+
) error {
38+
return errBrokenReader
39+
}
40+
41+
var _ client.Reader = &brokenReader{}
42+
43+
func TestWait(t *testing.T) {
44+
tests := []struct {
45+
name string
46+
// We use the corev1.Namespace concrete type for the test, because we want to
47+
// verify behavior for a concrete type, and because the Wait function is
48+
// generic, and will behave identically for all concrete types.
49+
input ForObjectInput[*corev1.Namespace]
50+
errCheck func(error) bool
51+
}{
52+
{
53+
name: "time out while get does not find object; report get error",
54+
input: ForObjectInput[*corev1.Namespace]{
55+
Reader: fake.NewFakeClient(),
56+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
57+
return true, nil
58+
},
59+
Interval: time.Nanosecond,
60+
Timeout: time.Millisecond,
61+
Target: &corev1.Namespace{
62+
TypeMeta: v1.TypeMeta{
63+
Kind: "Namespace",
64+
APIVersion: "v1",
65+
},
66+
ObjectMeta: v1.ObjectMeta{
67+
Name: "example",
68+
},
69+
},
70+
},
71+
errCheck: func(err error) bool {
72+
return wait.Interrupted(err) &&
73+
apierrors.IsNotFound(err)
74+
},
75+
},
76+
{
77+
name: "return immediately when get fails; report get error",
78+
input: ForObjectInput[*corev1.Namespace]{
79+
Reader: &brokenReader{},
80+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
81+
return true, nil
82+
},
83+
Interval: time.Nanosecond,
84+
Timeout: time.Millisecond,
85+
Target: &corev1.Namespace{
86+
TypeMeta: v1.TypeMeta{
87+
Kind: "Namespace",
88+
APIVersion: "v1",
89+
},
90+
ObjectMeta: v1.ObjectMeta{
91+
Name: "example",
92+
},
93+
},
94+
},
95+
errCheck: func(err error) bool {
96+
return !wait.Interrupted(err) &&
97+
!apierrors.IsNotFound(err) &&
98+
errors.Is(err, errBrokenReader)
99+
},
100+
},
101+
{
102+
name: "time out while check returns false; no check error to report",
103+
input: ForObjectInput[*corev1.Namespace]{
104+
Reader: fake.NewFakeClient(
105+
&corev1.Namespace{
106+
TypeMeta: v1.TypeMeta{
107+
Kind: "Namespace",
108+
APIVersion: "v1",
109+
},
110+
ObjectMeta: v1.ObjectMeta{
111+
Name: "example",
112+
},
113+
},
114+
),
115+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
116+
return false, nil
117+
},
118+
Interval: time.Nanosecond,
119+
Timeout: time.Millisecond,
120+
Target: &corev1.Namespace{
121+
TypeMeta: v1.TypeMeta{
122+
Kind: "Namespace",
123+
APIVersion: "v1",
124+
},
125+
ObjectMeta: v1.ObjectMeta{
126+
Name: "example",
127+
},
128+
},
129+
},
130+
errCheck: wait.Interrupted,
131+
},
132+
{
133+
name: "return immediately when check returns an error; report the error",
134+
input: ForObjectInput[*corev1.Namespace]{
135+
Reader: fake.NewFakeClient(
136+
&corev1.Namespace{
137+
TypeMeta: v1.TypeMeta{
138+
Kind: "Namespace",
139+
APIVersion: "v1",
140+
},
141+
ObjectMeta: v1.ObjectMeta{
142+
Name: "example",
143+
},
144+
},
145+
),
146+
Check: func(_ context.Context, _ *corev1.Namespace) (bool, error) {
147+
return false, fmt.Errorf("condition failed")
148+
},
149+
Interval: time.Nanosecond,
150+
Timeout: time.Millisecond, Target: &corev1.Namespace{
151+
TypeMeta: v1.TypeMeta{
152+
Kind: "Namespace",
153+
APIVersion: "v1",
154+
},
155+
ObjectMeta: v1.ObjectMeta{
156+
Name: "example",
157+
},
158+
},
159+
},
160+
errCheck: func(err error) bool {
161+
return errors.Is(err, &CheckFailedError{}) &&
162+
!wait.Interrupted(err)
163+
},
164+
},
165+
}
166+
for _, tt := range tests {
167+
t.Run(tt.name, func(t *testing.T) {
168+
err := ForObject(
169+
context.Background(),
170+
tt.input,
171+
)
172+
if !tt.errCheck(err) {
173+
t.Errorf("error did not pass check: %s", err)
174+
}
175+
})
176+
}
177+
}

0 commit comments

Comments
 (0)