@@ -11,6 +11,7 @@ import (
11
11
"github.com/golang/mock/gomock"
12
12
"github.com/stretchr/testify/assert"
13
13
"github.com/stretchr/testify/require"
14
+ "golang.org/x/sys/windows/svc"
14
15
)
15
16
16
17
var errTestFailure = errors .New ("test failure" )
@@ -146,3 +147,205 @@ func TestExecuteCommandTimeout(t *testing.T) {
146
147
_ , err := client .ExecuteCommand (context .Background (), "ping" , "-t" , "localhost" )
147
148
require .Error (t , err )
148
149
}
150
+
151
+ type mockManagedService struct {
152
+ queryFuncs []func () (svc.Status , error )
153
+ controlFunc func (svc.Cmd ) (svc.Status , error )
154
+ startFunc func (args ... string ) error
155
+ }
156
+
157
+ func (m * mockManagedService ) Query () (svc.Status , error ) {
158
+ queryFunc := m .queryFuncs [0 ]
159
+ m .queryFuncs = m .queryFuncs [1 :]
160
+ return queryFunc ()
161
+ }
162
+
163
+ func (m * mockManagedService ) Control (cmd svc.Cmd ) (svc.Status , error ) {
164
+ return m .controlFunc (cmd )
165
+ }
166
+
167
+ func (m * mockManagedService ) Start (args ... string ) error {
168
+ return m .startFunc (args ... )
169
+ }
170
+
171
+ func TestTryStopServiceFn (t * testing.T ) {
172
+ tests := []struct {
173
+ name string
174
+ queryFuncs []func () (svc.Status , error )
175
+ controlFunc func (svc.Cmd ) (svc.Status , error )
176
+ expectError bool
177
+ }{
178
+ {
179
+ name : "Service already stopped" ,
180
+ queryFuncs : []func () (svc.Status , error ){
181
+ func () (svc.Status , error ) {
182
+ return svc.Status {State : svc .Stopped }, nil
183
+ },
184
+ func () (svc.Status , error ) {
185
+ return svc.Status {State : svc .Stopped }, nil
186
+ },
187
+ },
188
+ controlFunc : nil ,
189
+ expectError : false ,
190
+ },
191
+ {
192
+ name : "Service running and stops successfully" ,
193
+ queryFuncs : []func () (svc.Status , error ){
194
+ func () (svc.Status , error ) {
195
+ return svc.Status {State : svc .Running }, nil
196
+ },
197
+ func () (svc.Status , error ) {
198
+ return svc.Status {State : svc .Stopped }, nil
199
+ },
200
+ },
201
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
202
+ return svc.Status {State : svc .Stopped }, nil
203
+ },
204
+ expectError : false ,
205
+ },
206
+ {
207
+ name : "Service running and stops after multiple attempts" ,
208
+ queryFuncs : []func () (svc.Status , error ){
209
+ func () (svc.Status , error ) {
210
+ return svc.Status {State : svc .Running }, nil
211
+ },
212
+ func () (svc.Status , error ) {
213
+ return svc.Status {State : svc .Running }, nil
214
+ },
215
+ func () (svc.Status , error ) {
216
+ return svc.Status {State : svc .Running }, nil
217
+ },
218
+ func () (svc.Status , error ) {
219
+ return svc.Status {State : svc .Stopped }, nil
220
+ },
221
+ },
222
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
223
+ return svc.Status {State : svc .Stopped }, nil
224
+ },
225
+ expectError : false ,
226
+ },
227
+ {
228
+ name : "Service running and fails to stop" ,
229
+ queryFuncs : []func () (svc.Status , error ){
230
+ func () (svc.Status , error ) {
231
+ return svc.Status {State : svc .Running }, nil
232
+ },
233
+ },
234
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
235
+ return svc.Status {State : svc .Running }, errors .New ("failed to stop service" ) //nolint:err113 // test error
236
+ },
237
+ expectError : true ,
238
+ },
239
+ {
240
+ name : "Service query fails" ,
241
+ queryFuncs : []func () (svc.Status , error ){
242
+ func () (svc.Status , error ) {
243
+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
244
+ },
245
+ },
246
+ controlFunc : nil ,
247
+ expectError : true ,
248
+ },
249
+ }
250
+ for _ , tt := range tests {
251
+ t .Run (tt .name , func (t * testing.T ) {
252
+ service := & mockManagedService {
253
+ queryFuncs : tt .queryFuncs ,
254
+ controlFunc : tt .controlFunc ,
255
+ }
256
+ err := tryStopServiceFn (context .Background (), service )()
257
+ if tt .expectError {
258
+ assert .Error (t , err )
259
+ return
260
+ }
261
+ assert .NoError (t , err )
262
+ })
263
+ }
264
+ }
265
+
266
+ func TestTryStartServiceFn (t * testing.T ) {
267
+ tests := []struct {
268
+ name string
269
+ queryFuncs []func () (svc.Status , error )
270
+ startFunc func (... string ) error
271
+ expectError bool
272
+ }{
273
+ {
274
+ name : "Service already running" ,
275
+ queryFuncs : []func () (svc.Status , error ){
276
+ func () (svc.Status , error ) {
277
+ return svc.Status {State : svc .Running }, nil
278
+ },
279
+ func () (svc.Status , error ) {
280
+ return svc.Status {State : svc .Running }, nil
281
+ },
282
+ },
283
+ startFunc : nil ,
284
+ expectError : false ,
285
+ },
286
+ {
287
+ name : "Service already starting" ,
288
+ queryFuncs : []func () (svc.Status , error ){
289
+ func () (svc.Status , error ) {
290
+ return svc.Status {State : svc .StartPending }, nil
291
+ },
292
+ func () (svc.Status , error ) {
293
+ return svc.Status {State : svc .Running }, nil
294
+ },
295
+ },
296
+ startFunc : nil ,
297
+ expectError : false ,
298
+ },
299
+ {
300
+ name : "Service starts successfully" ,
301
+ queryFuncs : []func () (svc.Status , error ){
302
+ func () (svc.Status , error ) {
303
+ return svc.Status {State : svc .Stopped }, nil
304
+ },
305
+ func () (svc.Status , error ) {
306
+ return svc.Status {State : svc .Running }, nil
307
+ },
308
+ },
309
+ startFunc : func (... string ) error {
310
+ return nil
311
+ },
312
+ expectError : false ,
313
+ },
314
+ {
315
+ name : "Service fails to start" ,
316
+ queryFuncs : []func () (svc.Status , error ){
317
+ func () (svc.Status , error ) {
318
+ return svc.Status {State : svc .Stopped }, nil
319
+ },
320
+ },
321
+ startFunc : func (... string ) error {
322
+ return errors .New ("failed to start service" ) //nolint:err113 // test error
323
+ },
324
+ expectError : true ,
325
+ },
326
+ {
327
+ name : "Service query fails" ,
328
+ queryFuncs : []func () (svc.Status , error ){
329
+ func () (svc.Status , error ) {
330
+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
331
+ },
332
+ },
333
+ startFunc : nil ,
334
+ expectError : true ,
335
+ },
336
+ }
337
+ for _ , tt := range tests {
338
+ t .Run (tt .name , func (t * testing.T ) {
339
+ service := & mockManagedService {
340
+ queryFuncs : tt .queryFuncs ,
341
+ startFunc : tt .startFunc ,
342
+ }
343
+ err := tryStartServiceFn (context .Background (), service )()
344
+ if tt .expectError {
345
+ assert .Error (t , err )
346
+ return
347
+ }
348
+ assert .NoError (t , err )
349
+ })
350
+ }
351
+ }
0 commit comments