diff --git a/platform/os_windows.go b/platform/os_windows.go
index 0a6e3e78d1..d84e335faa 100644
--- a/platform/os_windows.go
+++ b/platform/os_windows.go
@@ -16,6 +16,7 @@ import (
 	"github.com/Azure/azure-container-networking/log"
 	"github.com/Azure/azure-container-networking/platform/windows/adapter"
 	"github.com/Azure/azure-container-networking/platform/windows/adapter/mellanox"
+	"github.com/avast/retry-go/v4"
 	"github.com/pkg/errors"
 	"go.uber.org/zap"
 	"golang.org/x/sys/windows"
@@ -232,32 +233,105 @@ func restartHNS(ctx context.Context) error {
 	}
 	defer service.Close()
 	// Stop the service
-	_, err = service.Control(svc.Stop)
-	if err != nil {
-		return errors.Wrap(err, "could not stop service")
-	}
-	// Wait for the service to stop
-	ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
-	defer ticker.Stop()
-	for { // hacky cancellable do-while
-		status, err := service.Query()
-		if err != nil {
-			return errors.Wrap(err, "could not query service status")
-		}
-		if status.State == svc.Stopped {
-			break
-		}
-		select {
-		case <-ctx.Done():
-			return errors.New("context cancelled")
-		case <-ticker.C:
-		}
-	}
-	// Start the service again
-	if err := service.Start(); err != nil {
-		return errors.Wrap(err, "could not start service")
-	}
-	return nil
+	log.Printf("Stopping HNS service")
+	err = retry.Do(
+		tryStopServiceFn(ctx, service),
+		retry.UntilSucceeded(),
+		retry.Context(ctx),
+		retry.DelayType(retry.BackOffDelay),
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not stop HNS service")
+	}
+	// Start the service again
+	log.Printf("Starting HNS service")
+	err = retry.Do(
+		tryStartServiceFn(ctx, service),
+		retry.UntilSucceeded(),
+		retry.Context(ctx),
+		retry.DelayType(retry.BackOffDelay),
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not start HNS service")
+	}
+	log.Printf("HNS service started")
+	return nil
+}
+
+// managedService lists the service operations the restart helpers depend on,
+// so that tests can supply a fake implementation.
+type managedService interface {
+	Control(control svc.Cmd) (svc.Status, error)
+	Query() (svc.Status, error)
+	Start(args ...string) error
+}
+
+// waitForServiceState polls the service every 500ms until it reports the
+// wanted state. It returns the context error once ctx is done or 90 seconds
+// have passed, and a wrapped error if a status query fails.
+func waitForServiceState(ctx context.Context, service managedService, want svc.State) error {
+	deadline, cancel := context.WithTimeout(ctx, 90*time.Second)
+	defer cancel()
+	ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
+	defer ticker.Stop()
+	for {
+		status, err := service.Query()
+		if err != nil {
+			return errors.Wrap(err, "could not query service status")
+		}
+		if status.State == want {
+			return nil
+		}
+		select {
+		case <-deadline.Done():
+			return deadline.Err() //nolint:wrapcheck // error has sufficient context
+		case <-ticker.C:
+		}
+	}
+}
+
+// tryStartServiceFn returns a function, suitable for retry.Do, that starts the
+// service unless it is already running or start-pending and then waits for it
+// to report the running state.
+func tryStartServiceFn(ctx context.Context, service managedService) func() error {
+	return func() error {
+		status, err := service.Query()
+		if err != nil {
+			return errors.Wrap(err, "could not query service status")
+		}
+		if status.State != svc.Running && status.State != svc.StartPending {
+			if err = service.Start(); err != nil {
+				return errors.Wrap(err, "could not start service")
+			}
+		}
+		if err = waitForServiceState(ctx, service, svc.Running); err != nil {
+			return err
+		}
+		log.Printf("service started")
+		return nil
+	}
+}
+
+// tryStopServiceFn returns a function, suitable for retry.Do, that stops the
+// service unless it is already stopped or stop-pending and then waits for it
+// to report the stopped state.
+func tryStopServiceFn(ctx context.Context, service managedService) func() error {
+	return func() error {
+		status, err := service.Query()
+		if err != nil {
+			return errors.Wrap(err, "could not query service status")
+		}
+		if status.State != svc.Stopped && status.State != svc.StopPending {
+			if _, err = service.Control(svc.Stop); err != nil {
+				return errors.Wrap(err, "could not stop service")
+			}
+		}
+		if err = waitForServiceState(ctx, service, svc.Stopped); err != nil {
+			return err
+		}
+		log.Printf("service stopped")
+		return nil
+	}
 }
 
 func HasMellanoxAdapter() bool {
diff --git a/platform/os_windows_test.go b/platform/os_windows_test.go
index a6e44d2fa5..435e2ca516 100644
--- a/platform/os_windows_test.go
+++ b/platform/os_windows_test.go
@@ -1,6 +1,7 @@
 package platform
 
 import (
+	"context"
 	"errors"
 	"os/exec"
 	"testing"
@@ -9,6 +10,7 @@ import (
 	"github.com/golang/mock/gomock"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/sys/windows/svc"
 )
 
 var errTestFailure = errors.New("test failure")
@@ -98,3 +100,226 @@ func TestExecuteCommandError(t *testing.T) {
 	assert.ErrorAs(t, err, &xErr)
 	assert.Equal(t, 1, xErr.ExitCode())
 }
+
+// mockManagedService is a scripted managedService fake: each Query call
+// consumes the next entry of queryFuncs.
+type mockManagedService struct {
+	queryFuncs  []func() (svc.Status, error)
+	controlFunc func(svc.Cmd) (svc.Status, error)
+	startFunc   func(args ...string) error
+}
+
+func (m *mockManagedService) Query() (svc.Status, error) {
+	queryFunc := m.queryFuncs[0]
+	m.queryFuncs = m.queryFuncs[1:]
+	return queryFunc()
+}
+
+func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
+	return m.controlFunc(cmd)
+}
+
+func (m *mockManagedService) Start(args ...string) error {
+	return m.startFunc(args...)
+}
+
+func TestTryStopServiceFn(t *testing.T) {
+	tests := []struct {
+		name        string
+		queryFuncs  []func() (svc.Status, error)
+		controlFunc func(svc.Cmd) (svc.Status, error)
+		expectError bool
+	}{
+		{
+			name: "Service already stopped",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+			},
+			controlFunc: nil,
+			expectError: false,
+		},
+		{
+			name: "Service running and stops successfully",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+			},
+			controlFunc: func(svc.Cmd) (svc.Status, error) {
+				return svc.Status{State: svc.Stopped}, nil
+			},
+			expectError: false,
+		},
+		{
+			name: "Service running and stops after multiple attempts",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+			},
+			controlFunc: func(svc.Cmd) (svc.Status, error) {
+				return svc.Status{State: svc.Stopped}, nil
+			},
+			expectError: false,
+		},
+		{
+			name: "Service running and fails to stop",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+			},
+			controlFunc: func(svc.Cmd) (svc.Status, error) {
+				return svc.Status{State: svc.Running}, errors.New("failed to stop service") //nolint:err113 // test error
+			},
+			expectError: true,
+		},
+		{
+			name: "Service query fails",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
+				},
+			},
+			controlFunc: nil,
+			expectError: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			service := &mockManagedService{
+				queryFuncs:  tt.queryFuncs,
+				controlFunc: tt.controlFunc,
+			}
+			err := tryStopServiceFn(context.Background(), service)()
+			if tt.expectError {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+		})
+	}
+}
+
+func TestTryStartServiceFn(t *testing.T) {
+	tests := []struct {
+		name        string
+		queryFuncs  []func() (svc.Status, error)
+		startFunc   func(...string) error
+		expectError bool
+	}{
+		{
+			name: "Service already running",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+			},
+			startFunc:   nil,
+			expectError: false,
+		},
+		{
+			name: "Service already starting",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.StartPending}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+			},
+			startFunc:   nil,
+			expectError: false,
+		},
+		{
+			name: "Service starts successfully",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Running}, nil
+				},
+			},
+			startFunc: func(...string) error {
+				return nil
+			},
+			expectError: false,
+		},
+		{
+			name: "Service fails to start",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{State: svc.Stopped}, nil
+				},
+			},
+			startFunc: func(...string) error {
+				return errors.New("failed to start service") //nolint:err113 // test error
+			},
+			expectError: true,
+		},
+		{
+			name: "Service query fails",
+			queryFuncs: []func() (svc.Status, error){
+				func() (svc.Status, error) {
+					return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
+				},
+			},
+			startFunc:   nil,
+			expectError: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			service := &mockManagedService{
+				queryFuncs: tt.queryFuncs,
+				startFunc:  tt.startFunc,
+			}
+			err := tryStartServiceFn(context.Background(), service)()
+			if tt.expectError {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+		})
+	}
+}
+
+// TestTryServiceFnContextCancelled checks that both helpers return the context
+// error when ctx is cancelled before the service reaches the wanted state.
+func TestTryServiceFnContextCancelled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	pending := func(state svc.State) []func() (svc.Status, error) {
+		query := func() (svc.Status, error) {
+			return svc.Status{State: state}, nil
+		}
+		return []func() (svc.Status, error){query, query}
+	}
+
+	stopErr := tryStopServiceFn(ctx, &mockManagedService{queryFuncs: pending(svc.StopPending)})()
+	assert.ErrorIs(t, stopErr, context.Canceled)
+
+	startErr := tryStartServiceFn(ctx, &mockManagedService{queryFuncs: pending(svc.StartPending)})()
+	assert.ErrorIs(t, startErr, context.Canceled)
+}