diff --git a/pkg/gce-cloud-provider/compute/gce.go b/pkg/gce-cloud-provider/compute/gce.go index 57a2b569b..1d9cb051d 100644 --- a/pkg/gce-cloud-provider/compute/gce.go +++ b/pkg/gce-cloud-provider/compute/gce.go @@ -16,6 +16,7 @@ package gcecloudprovider import ( "context" + "errors" "fmt" "net/http" "os" @@ -248,8 +249,8 @@ func getProjectAndZone(config *ConfigFile) (string, string, error) { // isGCEError returns true if given error is a googleapi.Error with given // reason (e.g. "resourceInUseByAnotherResource") func IsGCEError(err error, reason string) bool { - apiErr, ok := err.(*googleapi.Error) - if !ok { + var apiErr *googleapi.Error + if !errors.As(err, &apiErr) { return false } diff --git a/pkg/gce-cloud-provider/compute/gce_test.go b/pkg/gce-cloud-provider/compute/gce_test.go new file mode 100644 index 000000000..5bb2aed89 --- /dev/null +++ b/pkg/gce-cloud-provider/compute/gce_test.go @@ -0,0 +1,86 @@ +/* +Copyright 2023 The Kubernetes Authors. + + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gcecloudprovider + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "google.golang.org/api/googleapi" +) + +func TestIsGCEError(t *testing.T) { + testCases := []struct { + name string + inputErr error + reason string + expIsGCEError bool + }{ + { + name: "Not googleapi.Error", + inputErr: errors.New("I am not a googleapi.Error"), + reason: "notFound", + expIsGCEError: false, + }, + { + name: "googleapi.Error not found error", + inputErr: &googleapi.Error{ + Code: http.StatusNotFound, + Errors: []googleapi.ErrorItem{ + { + Reason: "notFound", + }, + }, + Message: "Not found", + }, + reason: "notFound", + expIsGCEError: true, + }, + { + name: "wrapped googleapi.Error", + inputErr: fmt.Errorf("encountered not found: %w", &googleapi.Error{ + Code: http.StatusNotFound, + Errors: []googleapi.ErrorItem{ + { + Reason: "notFound", + }, + }, + Message: "Not found", + }, + ), + reason: "notFound", + expIsGCEError: true, + }, + { + name: "nil error", + inputErr: nil, + reason: "notFound", + expIsGCEError: false, + }, + } + + for _, tc := range testCases { + t.Logf("Running test: %v", tc.name) + isGCEError := IsGCEError(tc.inputErr, tc.reason) + if tc.expIsGCEError != isGCEError { + t.Fatalf("Got isGCEError '%t', expected '%t'", isGCEError, tc.expIsGCEError) + } + } +} diff --git a/pkg/gce-pd-csi-driver/utils.go b/pkg/gce-pd-csi-driver/utils.go index e2f4f70ce..cfea8e0b4 100644 --- a/pkg/gce-pd-csi-driver/utils.go +++ b/pkg/gce-pd-csi-driver/utils.go @@ -231,6 +231,13 @@ func containsZone(zones []string, zone string) bool { // (1) "context deadline exceeded", returns grpc DeadlineExceeded, // (2) "context canceled", returns grpc Canceled func CodeForError(err error) *codes.Code { + if err == nil { + return nil + } + + if errCode := existingErrorCode(err); errCode != nil { + return errCode + } if code := isContextError(err); code != nil { return code } @@ -255,6 +262,16 @@ func CodeForError(err error) *codes.Code { return &internalErrorCode } +func existingErrorCode(err error) *codes.Code { + if err == nil { + return nil + } + if status, ok := status.FromError(err); ok { + return errCodePtr(status.Code()) + } + return nil +} + // isContextError returns a pointer to the grpc error code DeadlineExceeded // if the passed in error contains the "context deadline exceeded" string and returns // the grpc error code Canceled if the error contains the "context canceled" string. diff --git a/pkg/gce-pd-csi-driver/utils_test.go b/pkg/gce-pd-csi-driver/utils_test.go index 406012c5d..048810f2f 100644 --- a/pkg/gce-pd-csi-driver/utils_test.go +++ b/pkg/gce-pd-csi-driver/utils_test.go @@ -27,6 +27,7 @@ import ( csi "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/api/googleapi" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -331,13 +332,26 @@ func TestCodeForError(t *testing.T) { inputErr: context.DeadlineExceeded, expCode: errCodePtr(codes.DeadlineExceeded), }, + { + name: "status error with Aborted error code", + inputErr: status.Error(codes.Aborted, "aborted error"), + expCode: errCodePtr(codes.Aborted), + }, + { + name: "nil error", + inputErr: nil, + expCode: nil, + }, } for _, tc := range testCases { t.Logf("Running test: %v", tc.name) - actualCode := *CodeForError(tc.inputErr) - if *tc.expCode != actualCode { - t.Fatalf("Expected error code '%v' but got '%v'", tc.expCode, actualCode) + errCode := CodeForError(tc.inputErr) + if (tc.expCode == nil) != (errCode == nil) { + t.Errorf("test %v failed: got %v, expected %v", tc.name, errCode, tc.expCode) + } + if tc.expCode != nil && *errCode != *tc.expCode { + t.Errorf("test %v failed: got %v, expected %v", tc.name, errCode, tc.expCode) } } } diff --git a/test/remote/instance.go b/test/remote/instance.go index c716dcef3..8aa5b2fd2 100644 --- a/test/remote/instance.go +++ b/test/remote/instance.go @@ -415,8 +415,8 @@ func generateMetadataWithPublicKey(pubKeyFile string) (*compute.Metadata, error) // isGCEError returns true if given error is a googleapi.Error with given // reason (e.g. "resourceInUseByAnotherResource") func isGCEError(err error, reason string) bool { - apiErr, ok := err.(*googleapi.Error) - if !ok { + var apiErr *googleapi.Error + if !errors.As(err, &apiErr) { return false }