Skip to content

Commit ef873ba

Browse files
authored
fix: carefully retry restarting HNS if it hangs (#3529)
* fix: carefully retry restarting HNS if it hangs Signed-off-by: Evan Baker <[email protected]> * retry start, check stop pending Signed-off-by: Evan Baker <[email protected]> --------- Signed-off-by: Evan Baker <[email protected]>
1 parent a88584e commit ef873ba

File tree

2 files changed

+290
-17
lines changed

2 files changed

+290
-17
lines changed

Diff for: platform/os_windows.go

+87-17
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/Azure/azure-container-networking/log"
1818
"github.com/Azure/azure-container-networking/platform/windows/adapter"
1919
"github.com/Azure/azure-container-networking/platform/windows/adapter/mellanox"
20+
"github.com/avast/retry-go/v4"
2021
"github.com/pkg/errors"
2122
"go.uber.org/zap"
2223
"golang.org/x/sys/windows"
@@ -302,32 +303,101 @@ func restartHNS(ctx context.Context) error {
302303
}
303304
defer service.Close()
304305
// Stop the service
305-
_, err = service.Control(svc.Stop)
306-
if err != nil {
307-
return errors.Wrap(err, "could not stop service")
306+
log.Printf("Stopping HNS service")
307+
_ = retry.Do(
308+
tryStopServiceFn(ctx, service),
309+
retry.UntilSucceeded(),
310+
retry.Context(ctx),
311+
)
312+
// Start the service again
313+
log.Printf("Starting HNS service")
314+
_ = retry.Do(
315+
tryStartServiceFn(ctx, service),
316+
retry.UntilSucceeded(),
317+
retry.Context(ctx),
318+
)
319+
log.Printf("HNS service started")
320+
return nil
321+
}
322+
323+
// managedService is the set of service operations that tryStartServiceFn and
// tryStopServiceFn need. restartHNS passes its opened service handle through
// this interface, and the tests substitute a scripted mock for it.
type managedService interface {
324+
// Control sends a control command to the service; the only command sent
// through this interface in this file is svc.Stop.
Control(control svc.Cmd) (svc.Status, error)
325+
// Query returns the service status; callers in this file read its State field.
Query() (svc.Status, error)
326+
// Start requests a service start; callers in this file pass no arguments.
Start(args ...string) error
327+
}
328+
329+
// tryStartServiceFn returns a func() error (the shape restartHNS hands to
// retry.Do) that makes one attempt to bring the service to the Running state.
//
// Each call of the returned function:
//  1. queries the current service state;
//  2. calls service.Start unless the state is already Running or StartPending;
//  3. polls the state every 500ms until Running is observed.
//
// It returns nil once Running is observed. It returns a wrapped error when a
// query fails, when Start fails, or when ctx is done while polling. The polling
// loop has no timeout of its own: only ctx bounds the wait.
func tryStartServiceFn(ctx context.Context, service managedService) func() error {
330+
// shouldStart reports whether a start request must still be issued, i.e. the
// service is neither Running nor StartPending.
shouldStart := func(state svc.State) bool {
331+
return !(state == svc.Running || state == svc.StartPending)
308332
}
309-
// Wait for the service to stop
310-
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
311-
defer ticker.Stop()
312-
for { // hacky cancellable do-while
333+
return func() error {
313334
status, err := service.Query()
314335
if err != nil {
315336
return errors.Wrap(err, "could not query service status")
316337
}
317-
if status.State == svc.Stopped {
318-
break
338+
if shouldStart(status.State) {
339+
err = service.Start()
340+
if err != nil {
341+
return errors.Wrap(err, "could not start service")
342+
}
319343
}
320-
select {
321-
case <-ctx.Done():
322-
return errors.New("context cancelled")
323-
case <-ticker.C:
344+
// Wait for the service to start
345+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
346+
defer ticker.Stop()
347+
for {
348+
// status and err are redeclared here and shadow the pair from the initial query.
status, err := service.Query()
349+
if err != nil {
350+
return errors.Wrap(err, "could not query service status")
351+
}
352+
if status.State == svc.Running {
353+
log.Printf("service started")
354+
break
355+
}
356+
// Give up as soon as ctx is done; otherwise block until the next tick and poll again.
select {
357+
case <-ctx.Done():
358+
return errors.Wrap(ctx.Err(), "context cancelled")
359+
case <-ticker.C:
360+
}
324361
}
362+
return nil
325363
}
326-
// Start the service again
327-
if err := service.Start(); err != nil {
328-
return errors.Wrap(err, "could not start service")
364+
}
365+
366+
// tryStopServiceFn returns a func() error (the shape restartHNS hands to
// retry.Do) that makes one attempt to bring the service to the Stopped state.
//
// Each call of the returned function:
//  1. queries the current service state;
//  2. sends svc.Stop via service.Control unless the state is already Stopped
//     or StopPending;
//  3. polls the state every 500ms until Stopped is observed.
//
// It returns nil once Stopped is observed. It returns a wrapped error when a
// query fails, when the stop command fails, or when ctx is done while polling.
// The polling loop has no timeout of its own: only ctx bounds the wait.
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
367+
// shouldStop reports whether a stop command must still be sent, i.e. the
// service is neither Stopped nor StopPending.
shouldStop := func(state svc.State) bool {
368+
return !(state == svc.Stopped || state == svc.StopPending)
369+
}
370+
return func() error {
371+
status, err := service.Query()
372+
if err != nil {
373+
return errors.Wrap(err, "could not query service status")
374+
}
375+
if shouldStop(status.State) {
376+
// The status returned by Control is discarded; the state is re-read by the polling loop below.
_, err = service.Control(svc.Stop)
377+
if err != nil {
378+
return errors.Wrap(err, "could not stop service")
379+
}
380+
}
381+
// Wait for the service to stop
382+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
383+
defer ticker.Stop()
384+
for {
385+
// status and err are redeclared here and shadow the pair from the initial query.
status, err := service.Query()
386+
if err != nil {
387+
return errors.Wrap(err, "could not query service status")
388+
}
389+
if status.State == svc.Stopped {
390+
log.Printf("service stopped")
391+
break
392+
}
393+
// Give up as soon as ctx is done; otherwise block until the next tick and poll again.
select {
394+
case <-ctx.Done():
395+
return errors.Wrap(ctx.Err(), "context cancelled")
396+
case <-ticker.C:
397+
}
398+
}
399+
return nil
329400
}
330-
return nil
331401
}
332402

333403
func HasMellanoxAdapter() bool {

Diff for: platform/os_windows_test.go

+203
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/golang/mock/gomock"
1212
"github.com/stretchr/testify/assert"
1313
"github.com/stretchr/testify/require"
14+
"golang.org/x/sys/windows/svc"
1415
)
1516

1617
var errTestFailure = errors.New("test failure")
@@ -146,3 +147,205 @@ func TestExecuteCommandTimeout(t *testing.T) {
146147
_, err := client.ExecuteCommand(context.Background(), "ping", "-t", "localhost")
147148
require.Error(t, err)
148149
}
150+
151+
// mockManagedService is a test double providing Query, Control and Start (the
// method set of managedService); its behavior is scripted through function fields.
type mockManagedService struct {
152+
// queryFuncs is consumed as a queue: each Query call removes and invokes the first entry.
queryFuncs []func() (svc.Status, error)
153+
// controlFunc handles every Control call; Control panics if this is nil.
controlFunc func(svc.Cmd) (svc.Status, error)
154+
// startFunc handles every Start call; Start panics if this is nil.
startFunc func(args ...string) error
155+
}
156+
157+
// Query removes the first function from queryFuncs and returns its result.
// It panics with an index-out-of-range error when queryFuncs is empty, i.e.
// when the code under test queries more often than the test case scripted.
func (m *mockManagedService) Query() (svc.Status, error) {
158+
queryFunc := m.queryFuncs[0]
159+
// Drop the consumed entry so the next call sees the next scripted response.
m.queryFuncs = m.queryFuncs[1:]
160+
return queryFunc()
161+
}
162+
163+
// Control forwards cmd to controlFunc and returns its result unchanged.
// There is no nil check: calling it while controlFunc is nil panics.
func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
164+
return m.controlFunc(cmd)
165+
}
166+
167+
// Start forwards args to startFunc and returns its error unchanged.
// There is no nil check: calling it while startFunc is nil panics.
func (m *mockManagedService) Start(args ...string) error {
168+
return m.startFunc(args...)
169+
}
170+
171+
// TestTryStopServiceFn invokes the function returned by tryStopServiceFn once
// per case against a mockManagedService and checks only whether an error is
// returned. queryFuncs scripts the result of each successive Query call: the
// first entry answers the initial state check, later entries answer the polls.
// controlFunc is left nil in cases where Control must not be reached.
func TestTryStopServiceFn(t *testing.T) {
172+
tests := []struct {
173+
name string
174+
queryFuncs []func() (svc.Status, error)
175+
controlFunc func(svc.Cmd) (svc.Status, error)
176+
expectError bool
177+
}{
178+
{
179+
// Initial state is Stopped, so no stop command is sent; the first poll sees Stopped.
name: "Service already stopped",
180+
queryFuncs: []func() (svc.Status, error){
181+
func() (svc.Status, error) {
182+
return svc.Status{State: svc.Stopped}, nil
183+
},
184+
func() (svc.Status, error) {
185+
return svc.Status{State: svc.Stopped}, nil
186+
},
187+
},
188+
controlFunc: nil,
189+
expectError: false,
190+
},
191+
{
192+
// Initial state is Running, Control succeeds, and the first poll sees Stopped.
name: "Service running and stops successfully",
193+
queryFuncs: []func() (svc.Status, error){
194+
func() (svc.Status, error) {
195+
return svc.Status{State: svc.Running}, nil
196+
},
197+
func() (svc.Status, error) {
198+
return svc.Status{State: svc.Stopped}, nil
199+
},
200+
},
201+
controlFunc: func(svc.Cmd) (svc.Status, error) {
202+
return svc.Status{State: svc.Stopped}, nil
203+
},
204+
expectError: false,
205+
},
206+
{
207+
// One invocation only: Control is called once, then the polls see Running,
// Running, Stopped, so the case waits on two 500ms ticks.
name: "Service running and stops after multiple attempts",
208+
queryFuncs: []func() (svc.Status, error){
209+
func() (svc.Status, error) {
210+
return svc.Status{State: svc.Running}, nil
211+
},
212+
func() (svc.Status, error) {
213+
return svc.Status{State: svc.Running}, nil
214+
},
215+
func() (svc.Status, error) {
216+
return svc.Status{State: svc.Running}, nil
217+
},
218+
func() (svc.Status, error) {
219+
return svc.Status{State: svc.Stopped}, nil
220+
},
221+
},
222+
controlFunc: func(svc.Cmd) (svc.Status, error) {
223+
return svc.Status{State: svc.Stopped}, nil
224+
},
225+
expectError: false,
226+
},
227+
{
228+
// Control returns an error, so the function returns before any polling query.
name: "Service running and fails to stop",
229+
queryFuncs: []func() (svc.Status, error){
230+
func() (svc.Status, error) {
231+
return svc.Status{State: svc.Running}, nil
232+
},
233+
},
234+
controlFunc: func(svc.Cmd) (svc.Status, error) {
235+
return svc.Status{State: svc.Running}, errors.New("failed to stop service") //nolint:err113 // test error
236+
},
237+
expectError: true,
238+
},
239+
{
240+
// The initial query fails, so Control is never reached.
name: "Service query fails",
241+
queryFuncs: []func() (svc.Status, error){
242+
func() (svc.Status, error) {
243+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
244+
},
245+
},
246+
controlFunc: nil,
247+
expectError: true,
248+
},
249+
}
250+
for _, tt := range tests {
251+
t.Run(tt.name, func(t *testing.T) {
252+
// startFunc is left unset; tryStopServiceFn never calls Start.
service := &mockManagedService{
253+
queryFuncs: tt.queryFuncs,
254+
controlFunc: tt.controlFunc,
255+
}
256+
// Build the attempt function and run it exactly once with a context that is never cancelled.
err := tryStopServiceFn(context.Background(), service)()
257+
if tt.expectError {
258+
assert.Error(t, err)
259+
return
260+
}
261+
assert.NoError(t, err)
262+
})
263+
}
264+
}
265+
266+
// TestTryStartServiceFn invokes the function returned by tryStartServiceFn once
// per case against a mockManagedService and checks only whether an error is
// returned. queryFuncs scripts the result of each successive Query call: the
// first entry answers the initial state check, later entries answer the polls.
// startFunc is left nil in cases where Start must not be reached.
func TestTryStartServiceFn(t *testing.T) {
267+
tests := []struct {
268+
name string
269+
queryFuncs []func() (svc.Status, error)
270+
startFunc func(...string) error
271+
expectError bool
272+
}{
273+
{
274+
// Initial state is Running, so Start is not called; the first poll sees Running.
name: "Service already running",
275+
queryFuncs: []func() (svc.Status, error){
276+
func() (svc.Status, error) {
277+
return svc.Status{State: svc.Running}, nil
278+
},
279+
func() (svc.Status, error) {
280+
return svc.Status{State: svc.Running}, nil
281+
},
282+
},
283+
startFunc: nil,
284+
expectError: false,
285+
},
286+
{
287+
// Initial state is StartPending, so Start is not called; the first poll sees Running.
name: "Service already starting",
288+
queryFuncs: []func() (svc.Status, error){
289+
func() (svc.Status, error) {
290+
return svc.Status{State: svc.StartPending}, nil
291+
},
292+
func() (svc.Status, error) {
293+
return svc.Status{State: svc.Running}, nil
294+
},
295+
},
296+
startFunc: nil,
297+
expectError: false,
298+
},
299+
{
300+
// Initial state is Stopped, Start succeeds, and the first poll sees Running.
name: "Service starts successfully",
301+
queryFuncs: []func() (svc.Status, error){
302+
func() (svc.Status, error) {
303+
return svc.Status{State: svc.Stopped}, nil
304+
},
305+
func() (svc.Status, error) {
306+
return svc.Status{State: svc.Running}, nil
307+
},
308+
},
309+
startFunc: func(...string) error {
310+
return nil
311+
},
312+
expectError: false,
313+
},
314+
{
315+
// Start returns an error, so the function returns before any polling query.
name: "Service fails to start",
316+
queryFuncs: []func() (svc.Status, error){
317+
func() (svc.Status, error) {
318+
return svc.Status{State: svc.Stopped}, nil
319+
},
320+
},
321+
startFunc: func(...string) error {
322+
return errors.New("failed to start service") //nolint:err113 // test error
323+
},
324+
expectError: true,
325+
},
326+
{
327+
// The initial query fails, so Start is never reached.
name: "Service query fails",
328+
queryFuncs: []func() (svc.Status, error){
329+
func() (svc.Status, error) {
330+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
331+
},
332+
},
333+
startFunc: nil,
334+
expectError: true,
335+
},
336+
}
337+
for _, tt := range tests {
338+
t.Run(tt.name, func(t *testing.T) {
339+
// controlFunc is left unset; tryStartServiceFn never calls Control.
service := &mockManagedService{
340+
queryFuncs: tt.queryFuncs,
341+
startFunc: tt.startFunc,
342+
}
343+
// Build the attempt function and run it exactly once with a context that is never cancelled.
err := tryStartServiceFn(context.Background(), service)()
344+
if tt.expectError {
345+
assert.Error(t, err)
346+
return
347+
}
348+
assert.NoError(t, err)
349+
})
350+
}
351+
}

0 commit comments

Comments
 (0)