From 070bf70597abb4df0af2e6b13125981853f44f90 Mon Sep 17 00:00:00 2001 From: Alexis MacAskill Date: Mon, 22 May 2023 21:02:24 +0000 Subject: [PATCH] filter out context errors --- pkg/gce-pd-csi-driver/utils.go | 37 +++++++++++++++-- pkg/gce-pd-csi-driver/utils_test.go | 61 +++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/pkg/gce-pd-csi-driver/utils.go b/pkg/gce-pd-csi-driver/utils.go index c81fc5be3..e2f4f70ce 100644 --- a/pkg/gce-pd-csi-driver/utils.go +++ b/pkg/gce-pd-csi-driver/utils.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "net/http" + "strings" "context" @@ -219,14 +220,21 @@ func containsZone(zones []string, zone string) bool { } // CodeForError returns a pointer to the grpc error code that maps to the http -// error code for the passed in user googleapi error. Returns codes.Internal if -// the given error is not a googleapi error caused by the user. The following -// http error codes are considered user errors: +// error code for the passed in user googleapi error or context error. Returns +// codes.Internal if the given error is not a googleapi error caused by the user. +// The following http error codes are considered user errors: // (1) http 400 Bad Request, returns grpc InvalidArgument, // (2) http 403 Forbidden, returns grpc PermissionDenied, // (3) http 404 Not Found, returns grpc NotFound // (4) http 429 Too Many Requests, returns grpc ResourceExhausted +// The following errors are considered context errors: +// (1) "context deadline exceeded", returns grpc DeadlineExceeded, +// (2) "context canceled", returns grpc Canceled func CodeForError(err error) *codes.Code { + if code := isContextError(err); code != nil { + return code + } + internalErrorCode := codes.Internal // Upwrap the error var apiErr *googleapi.Error @@ -243,9 +251,32 @@ func CodeForError(err error) *codes.Code { if code, ok := userErrors[apiErr.Code]; ok { return &code } + return &internalErrorCode } +// isContextError returns a pointer to the grpc error code DeadlineExceeded +// if the passed in error contains the "context deadline exceeded" string and returns +// the grpc error code Canceled if the error contains the "context canceled" string. +func isContextError(err error) *codes.Code { + if err == nil { + return nil + } + + errStr := err.Error() + if strings.Contains(errStr, context.DeadlineExceeded.Error()) { + return errCodePtr(codes.DeadlineExceeded) + } + if strings.Contains(errStr, context.Canceled.Error()) { + return errCodePtr(codes.Canceled) + } + return nil +} + +func errCodePtr(code codes.Code) *codes.Code { + return &code +} + func LoggedError(msg string, err error) error { klog.Errorf(msg+"%v", err.Error()) return status.Errorf(*CodeForError(err), msg+"%v", err.Error()) diff --git a/pkg/gce-pd-csi-driver/utils_test.go b/pkg/gce-pd-csi-driver/utils_test.go index ab92e45ee..406012c5d 100644 --- a/pkg/gce-pd-csi-driver/utils_test.go +++ b/pkg/gce-pd-csi-driver/utils_test.go @@ -18,7 +18,9 @@ limitations under the License. package gceGCEDriver import ( + "context" "errors" + "fmt" "net/http" "testing" @@ -319,6 +321,16 @@ func TestCodeForError(t *testing.T) { inputErr: &googleapi.Error{Code: http.StatusInternalServerError, Message: "Internal error"}, expCode: &internalErrorCode, }, + { + name: "context canceled error", + inputErr: context.Canceled, + expCode: errCodePtr(codes.Canceled), + }, + { + name: "context deadline exceeded error", + inputErr: context.DeadlineExceeded, + expCode: errCodePtr(codes.DeadlineExceeded), + }, } for _, tc := range testCases { @@ -329,3 +341,52 @@ func TestCodeForError(t *testing.T) { } } } + +func TestIsContextError(t *testing.T) { + cases := []struct { + name string + err error + expectedErrCode *codes.Code + }{ + { + name: "deadline exceeded error", + err: context.DeadlineExceeded, + expectedErrCode: errCodePtr(codes.DeadlineExceeded), + }, + { + name: "contains 'context deadline exceeded'", + err: fmt.Errorf("got error: %w", context.DeadlineExceeded), + expectedErrCode: errCodePtr(codes.DeadlineExceeded), + }, + { + name: "context canceled error", + err: context.Canceled, + expectedErrCode: errCodePtr(codes.Canceled), + }, + { + name: "contains 'context canceled'", + err: fmt.Errorf("got error: %w", context.Canceled), + expectedErrCode: errCodePtr(codes.Canceled), + }, + { + name: "does not contain 'context canceled' or 'context deadline exceeded'", + err: fmt.Errorf("unknown error"), + expectedErrCode: nil, + }, + { + name: "nil error", + err: nil, + expectedErrCode: nil, + }, + } + + for _, test := range cases { + errCode := isContextError(test.err) + if (test.expectedErrCode == nil) != (errCode == nil) { + t.Errorf("test %v failed: got %v, expected %v", test.name, errCode, test.expectedErrCode) + } + if test.expectedErrCode != nil && *errCode != *test.expectedErrCode { + t.Errorf("test %v failed: got %v, expected %v", test.name, errCode, test.expectedErrCode) + } + } +}