@@ -18,7 +18,9 @@ limitations under the License.
18
18
package gceGCEDriver
19
19
20
20
import (
21
+ "context"
21
22
"errors"
23
+ "fmt"
22
24
"net/http"
23
25
"testing"
24
26
@@ -319,6 +321,16 @@ func TestCodeForError(t *testing.T) {
319
321
inputErr : & googleapi.Error {Code : http .StatusInternalServerError , Message : "Internal error" },
320
322
expCode : & internalErrorCode ,
321
323
},
324
+ {
325
+ name : "context canceled error" ,
326
+ inputErr : context .Canceled ,
327
+ expCode : errCodePtr (codes .Canceled ),
328
+ },
329
+ {
330
+ name : "context deadline exceeded error" ,
331
+ inputErr : context .DeadlineExceeded ,
332
+ expCode : errCodePtr (codes .DeadlineExceeded ),
333
+ },
322
334
}
323
335
324
336
for _ , tc := range testCases {
@@ -329,3 +341,52 @@ func TestCodeForError(t *testing.T) {
329
341
}
330
342
}
331
343
}
344
+
345
+ func TestIsContextError (t * testing.T ) {
346
+ cases := []struct {
347
+ name string
348
+ err error
349
+ expectedErrCode * codes.Code
350
+ }{
351
+ {
352
+ name : "deadline exceeded error" ,
353
+ err : context .DeadlineExceeded ,
354
+ expectedErrCode : errCodePtr (codes .DeadlineExceeded ),
355
+ },
356
+ {
357
+ name : "contains 'context deadline exceeded'" ,
358
+ err : fmt .Errorf ("got error: %w" , context .DeadlineExceeded ),
359
+ expectedErrCode : errCodePtr (codes .DeadlineExceeded ),
360
+ },
361
+ {
362
+ name : "context canceled error" ,
363
+ err : context .Canceled ,
364
+ expectedErrCode : errCodePtr (codes .Canceled ),
365
+ },
366
+ {
367
+ name : "contains 'context canceled'" ,
368
+ err : fmt .Errorf ("got error: %w" , context .Canceled ),
369
+ expectedErrCode : errCodePtr (codes .Canceled ),
370
+ },
371
+ {
372
+ name : "does not contain 'context canceled' or 'context deadline exceeded'" ,
373
+ err : fmt .Errorf ("unknown error" ),
374
+ expectedErrCode : nil ,
375
+ },
376
+ {
377
+ name : "nil error" ,
378
+ err : nil ,
379
+ expectedErrCode : nil ,
380
+ },
381
+ }
382
+
383
+ for _ , test := range cases {
384
+ errCode := isContextError (test .err )
385
+ if (test .expectedErrCode == nil ) != (errCode == nil ) {
386
+ t .Errorf ("test %v failed: got %v, expected %v" , test .name , errCode , test .expectedErrCode )
387
+ }
388
+ if test .expectedErrCode != nil && * errCode != * test .expectedErrCode {
389
+ t .Errorf ("test %v failed: got %v, expected %v" , test .name , errCode , test .expectedErrCode )
390
+ }
391
+ }
392
+ }
0 commit comments