From 41db77fdda6a1eed826faae37b5251da6c4288ec Mon Sep 17 00:00:00 2001 From: Alexis MacAskill Date: Wed, 24 May 2023 18:37:38 +0000 Subject: [PATCH] Use errors.As so we can detect wrapped errors, and check for existing error codes in CodesForError --- pkg/gce-pd-csi-driver/utils.go | 85 +++++++++++++++++++++ pkg/gce-pd-csi-driver/utils_test.go | 113 ++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+) diff --git a/pkg/gce-pd-csi-driver/utils.go b/pkg/gce-pd-csi-driver/utils.go index d8ef8ce0b..d1802fd82 100644 --- a/pkg/gce-pd-csi-driver/utils.go +++ b/pkg/gce-pd-csi-driver/utils.go @@ -20,9 +20,14 @@ import ( "context" "errors" "fmt" + "net/http" + "strings" csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/api/googleapi" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "k8s.io/klog/v2" ) @@ -212,3 +217,83 @@ func containsZone(zones []string, zone string) bool { return false } + +// CodeForError returns a pointer to the grpc error code that maps to the http +// error code for the passed in user googleapi error or context error. Returns +// codes.Internal if the given error is not a googleapi error caused by the user. +// The following http error codes are considered user errors: +// (1) http 400 Bad Request, returns grpc InvalidArgument, +// (2) http 403 Forbidden, returns grpc PermissionDenied, +// (3) http 404 Not Found, returns grpc NotFound +// (4) http 429 Too Many Requests, returns grpc ResourceExhausted +// The following errors are considered context errors: +// (1) "context deadline exceeded", returns grpc DeadlineExceeded, +// (2) "context canceled", returns grpc Canceled +func CodeForError(err error) *codes.Code { + if err == nil { + return nil + } + + if errCode := existingErrorCode(err); errCode != nil { + return errCode + } + if code := isContextError(err); code != nil { + return code + } + + internalErrorCode := codes.Internal + // Upwrap the error + var apiErr *googleapi.Error + if !errors.As(err, &apiErr) { + return &internalErrorCode + } + + userErrors := map[int]codes.Code{ + http.StatusForbidden: codes.PermissionDenied, + http.StatusBadRequest: codes.InvalidArgument, + http.StatusTooManyRequests: codes.ResourceExhausted, + http.StatusNotFound: codes.NotFound, + } + if code, ok := userErrors[apiErr.Code]; ok { + return &code + } + + return &internalErrorCode +} + +func existingErrorCode(err error) *codes.Code { + if err == nil { + return nil + } + if status, ok := status.FromError(err); ok { + return errCodePtr(status.Code()) + } + return nil +} + +// isContextError returns a pointer to the grpc error code DeadlineExceeded +// if the passed in error contains the "context deadline exceeded" string and returns +// the grpc error code Canceled if the error contains the "context canceled" string. +func isContextError(err error) *codes.Code { + if err == nil { + return nil + } + + errStr := err.Error() + if strings.Contains(errStr, context.DeadlineExceeded.Error()) { + return errCodePtr(codes.DeadlineExceeded) + } + if strings.Contains(errStr, context.Canceled.Error()) { + return errCodePtr(codes.Canceled) + } + return nil +} + +func errCodePtr(code codes.Code) *codes.Code { + return &code +} + +func LoggedError(msg string, err error) error { + klog.Errorf(msg+"%v", err.Error()) + return status.Errorf(*CodeForError(err), msg+"%v", err.Error()) +} diff --git a/pkg/gce-pd-csi-driver/utils_test.go b/pkg/gce-pd-csi-driver/utils_test.go index 2a0fd5f45..048810f2f 100644 --- a/pkg/gce-pd-csi-driver/utils_test.go +++ b/pkg/gce-pd-csi-driver/utils_test.go @@ -18,9 +18,16 @@ limitations under the License. package gceGCEDriver import ( + "context" + "errors" + "fmt" + "net/http" "testing" csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/api/googleapi" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -291,3 +298,109 @@ func TestGetReadOnlyFromCapabilities(t *testing.T) { } } } + +func TestCodeForError(t *testing.T) { + internalErrorCode := codes.Internal + userErrorCode := codes.InvalidArgument + testCases := []struct { + name string + inputErr error + expCode *codes.Code + }{ + { + name: "Not googleapi.Error", + inputErr: errors.New("I am not a googleapi.Error"), + expCode: &internalErrorCode, + }, + { + name: "User error", + inputErr: &googleapi.Error{Code: http.StatusBadRequest, Message: "User error with bad request"}, + expCode: &userErrorCode, + }, + { + name: "googleapi.Error but not a user error", + inputErr: &googleapi.Error{Code: http.StatusInternalServerError, Message: "Internal error"}, + expCode: &internalErrorCode, + }, + { + name: "context canceled error", + inputErr: context.Canceled, + expCode: errCodePtr(codes.Canceled), + }, + { + name: "context deadline exceeded error", + inputErr: context.DeadlineExceeded, + expCode: errCodePtr(codes.DeadlineExceeded), + }, + { + name: "status error with Aborted error code", + inputErr: status.Error(codes.Aborted, "aborted error"), + expCode: errCodePtr(codes.Aborted), + }, + { + name: "nil error", + inputErr: nil, + expCode: nil, + }, + } + + for _, tc := range testCases { + t.Logf("Running test: %v", tc.name) + errCode := CodeForError(tc.inputErr) + if (tc.expCode == nil) != (errCode == nil) { + t.Errorf("test %v failed: got %v, expected %v", tc.name, errCode, tc.expCode) + } + if tc.expCode != nil && *errCode != *tc.expCode { + t.Errorf("test %v failed: got %v, expected %v", tc.name, errCode, tc.expCode) + } + } +} + +func TestIsContextError(t *testing.T) { + cases := []struct { + name string + err error + expectedErrCode *codes.Code + }{ + { + name: "deadline exceeded error", + err: context.DeadlineExceeded, + expectedErrCode: errCodePtr(codes.DeadlineExceeded), + }, + { + name: "contains 'context deadline exceeded'", + err: fmt.Errorf("got error: %w", context.DeadlineExceeded), + expectedErrCode: errCodePtr(codes.DeadlineExceeded), + }, + { + name: "context canceled error", + err: context.Canceled, + expectedErrCode: errCodePtr(codes.Canceled), + }, + { + name: "contains 'context canceled'", + err: fmt.Errorf("got error: %w", context.Canceled), + expectedErrCode: errCodePtr(codes.Canceled), + }, + { + name: "does not contain 'context canceled' or 'context deadline exceeded'", + err: fmt.Errorf("unknown error"), + expectedErrCode: nil, + }, + { + name: "nil error", + err: nil, + expectedErrCode: nil, + }, + } + + for _, test := range cases { + errCode := isContextError(test.err) + if (test.expectedErrCode == nil) != (errCode == nil) { + t.Errorf("test %v failed: got %v, expected %v", test.name, errCode, test.expectedErrCode) + } + if test.expectedErrCode != nil && *errCode != *test.expectedErrCode { + t.Errorf("test %v failed: got %v, expected %v", test.name, errCode, test.expectedErrCode) + } + } +}