From 3a4debe61687856931402375b91d62d31d24a1df Mon Sep 17 00:00:00 2001 From: Rob Rati Date: Tue, 12 Jan 2021 19:47:09 -0500 Subject: [PATCH] Fix handling of nil parameters in B and BA --- pkg/controller/bucket/bucket_controller.go | 12 ++- .../bucket/bucket_controller_test.go | 53 ++++++++++++- .../bucketaccess/bucket_access_controller.go | 12 ++- .../bucket_access_controller_test.go | 79 +++++++++++++++++++ 4 files changed, 150 insertions(+), 6 deletions(-) diff --git a/pkg/controller/bucket/bucket_controller.go b/pkg/controller/bucket/bucket_controller.go index 21ded69..a55328c 100644 --- a/pkg/controller/bucket/bucket_controller.go +++ b/pkg/controller/bucket/bucket_controller.go @@ -102,7 +102,7 @@ func (bl *bucketListener) Add(ctx context.Context, obj *v1alpha1.Bucket) error { req := osspec.ProvisionerCreateBucketRequest{ BucketName: obj.Name, - BucketContext: obj.Spec.Parameters, + BucketContext: bl.getParams(obj), } req.BucketContext["ProtocolVersion"] = obj.Spec.Protocol.Version @@ -147,7 +147,7 @@ func (bl *bucketListener) Delete(ctx context.Context, obj *v1alpha1.Bucket) erro } req := osspec.ProvisionerDeleteBucketRequest{ - BucketContext: obj.Spec.Parameters, + BucketContext: bl.getParams(obj), } switch obj.Spec.Protocol.Name { @@ -198,3 +198,11 @@ func (bl *bucketListener) updateStatus(ctx context.Context, name, msg string, st }) return err } + +func (bl *bucketListener) getParams(obj *v1alpha1.Bucket) map[string]string { + params := map[string]string{} + if obj.Spec.Parameters != nil { + params = obj.Spec.Parameters + } + return params +} diff --git a/pkg/controller/bucket/bucket_controller_test.go b/pkg/controller/bucket/bucket_controller_test.go index 219ed1b..5df8ca4 100644 --- a/pkg/controller/bucket/bucket_controller_test.go +++ b/pkg/controller/bucket/bucket_controller_test.go @@ -212,6 +212,20 @@ func TestAddValidProtocols(t *testing.T) { "AnonymousAccessMode": anonAccess, }, }, + { + name: "Empty parameters", + protocolName: v1alpha1.ProtocolNameS3, + createFunc: func(ctx context.Context, in *osspec.ProvisionerCreateBucketRequest, opts ...grpc.CallOption) (*osspec.ProvisionerCreateBucketResponse, error) { + if in.BucketName != bucketName { + t.Errorf("expected %s, got %s", bucketName, in.BucketName) + } + if in.BucketContext["ProtocolVersion"] != protocolVersion { + t.Errorf("expected %s, got %s", protocolVersion, in.BucketContext["ProtocolVersion"]) + } + return &osspec.ProvisionerCreateBucketResponse{}, nil + }, + params: nil, + }, } for _, tc := range testCases { @@ -242,7 +256,7 @@ func TestAddValidProtocols(t *testing.T) { kubeClient: kubeClient, } - t.Logf("Testing protocol %s", tc.name) + t.Logf(tc.name) err := bl.Add(ctx, &b) if err != nil { t.Errorf("add returned: %+v", err) @@ -405,6 +419,41 @@ func TestDeleteValidProtocols(t *testing.T) { extraParamName: extraParamValue, }, }, + { + name: "Empty parameters", + setProtocol: func(b *v1alpha1.Bucket) { + b.Spec.Protocol.S3 = &v1alpha1.S3Protocol{ + Region: region, + Version: protocolVersion, + SignatureVersion: sigVersion, + BucketName: bucketName, + Endpoint: endpoint, + } + }, + protocolName: v1alpha1.ProtocolNameS3, + deleteFunc: func(ctx context.Context, in *osspec.ProvisionerDeleteBucketRequest, opts ...grpc.CallOption) (*osspec.ProvisionerDeleteBucketResponse, error) { + if in.BucketName != bucketName { + t.Errorf("expected %s, got %s", bucketName, in.BucketName) + } + if in.BucketContext["Region"] != region { + t.Errorf("expected %s, got %s", region, in.BucketContext["Region"]) + } + if in.BucketContext["ProtocolVersion"] != protocolVersion { + t.Errorf("expected %s, got %s", protocolVersion, in.BucketContext["ProtocolVersion"]) + } + if in.BucketContext["SignatureVersion"] != string(sigVersion) { + t.Errorf("expected %s, got %s", sigVersion, in.BucketContext["SignatureVersion"]) + } + if in.BucketContext["Endpoint"] != endpoint { + t.Errorf("expected %s, got %s", endpoint, in.BucketContext["Endpoint"]) + } + if in.BucketContext["ProtocolVersion"] != protocolVersion { + t.Errorf("expected %s, got %s", protocolVersion, in.BucketContext["ProtocolVersion"]) + } + return &osspec.ProvisionerDeleteBucketResponse{}, nil + }, + params: nil, + }, } for _, tc := range testCases { @@ -434,7 +483,7 @@ func TestDeleteValidProtocols(t *testing.T) { } tc.setProtocol(&b) - t.Logf("Testing protocol %s", tc.name) + t.Logf(tc.name) err := bl.Delete(ctx, &b) if err != nil { t.Errorf("delete returned: %+v", err) diff --git a/pkg/controller/bucketaccess/bucket_access_controller.go b/pkg/controller/bucketaccess/bucket_access_controller.go index 885d092..f0375d3 100644 --- a/pkg/controller/bucketaccess/bucket_access_controller.go +++ b/pkg/controller/bucketaccess/bucket_access_controller.go @@ -116,7 +116,7 @@ func (bal *bucketAccessListener) Add(ctx context.Context, obj *v1alpha1.BucketAc req := osspec.ProvisionerGrantBucketAccessRequest{ Principal: obj.Spec.Principal, AccessPolicy: obj.Spec.PolicyActionsConfigMapData, - BucketContext: obj.Spec.Parameters, + BucketContext: bal.getParams(obj), } switch bucket.Spec.Protocol.Name { @@ -206,7 +206,7 @@ func (bal *bucketAccessListener) Delete(ctx context.Context, obj *v1alpha1.Bucke req := osspec.ProvisionerRevokeBucketAccessRequest{ Principal: obj.Spec.Principal, - BucketContext: obj.Spec.Parameters, + BucketContext: bal.getParams(obj), } switch bucket.Spec.Protocol.Name { @@ -277,3 +277,11 @@ func (bal *bucketAccessListener) updatePrincipal(ctx context.Context, name strin }) return err } + +func (bal *bucketAccessListener) getParams(obj *v1alpha1.BucketAccess) map[string]string { + params := map[string]string{} + if obj.Spec.Parameters != nil { + params = obj.Spec.Parameters + } + return params +} diff --git a/pkg/controller/bucketaccess/bucket_access_controller_test.go b/pkg/controller/bucketaccess/bucket_access_controller_test.go index 4188410..0bb9870 100644 --- a/pkg/controller/bucketaccess/bucket_access_controller_test.go +++ b/pkg/controller/bucketaccess/bucket_access_controller_test.go @@ -303,6 +303,47 @@ func TestAdd(t *testing.T) { extraParamName: extraParamValue, }, }, + { + name: "Empty parameters", + setProtocol: func(b *v1alpha1.Bucket) { + b.Spec.Protocol.S3 = &v1alpha1.S3Protocol{ + Region: region, + Version: protocolVersion, + SignatureVersion: sigVersion, + BucketName: bucketName, + Endpoint: endpoint, + } + }, + protocolName: v1alpha1.ProtocolNameS3, + grantFunc: func(ctx context.Context, in *osspec.ProvisionerGrantBucketAccessRequest, opts ...grpc.CallOption) (*osspec.ProvisionerGrantBucketAccessResponse, error) { + if in.BucketName != bucketName { + t.Errorf("expected %s, got %s", bucketName, in.BucketName) + } + if in.BucketContext["Region"] != region { + t.Errorf("expected %s, got %s", region, in.BucketContext["Region"]) + } + if in.Principal != principal { + t.Errorf("expected %s, got %s", principal, in.Principal) + } + if in.BucketContext["Version"] != protocolVersion { + t.Errorf("expected %s, got %s", protocolVersion, in.BucketContext["Version"]) + } + if in.BucketContext["SignatureVersion"] != string(sigVersion) { + t.Errorf("expected %s, got %s", sigVersion, in.BucketContext["SignatureVersion"]) + } + if in.BucketContext["Endpoint"] != endpoint { + t.Errorf("expected %s, got %s", endpoint, in.BucketContext["Endpoint"]) + } + return &osspec.ProvisionerGrantBucketAccessResponse{ + Principal: principal, + CredentialsFileContents: credsContents, + CredentialsFilePath: credsFile, + }, nil + }, + principal: principal, + serviceAccount: "", + params: nil, + }, } for _, tc := range testCases { @@ -347,6 +388,7 @@ func TestAdd(t *testing.T) { kubeClient: kubeClient, } + t.Logf(tc.name) err := bal.Add(ctx, &ba) if err != nil { t.Errorf("add returned: %+v", err) @@ -576,6 +618,42 @@ func TestDelete(t *testing.T) { extraParamName: extraParamValue, }, }, + { + name: "Empty parameters", + setProtocol: func(b *v1alpha1.Bucket) { + b.Spec.Protocol.S3 = &v1alpha1.S3Protocol{ + Region: region, + Version: protocolVersion, + SignatureVersion: sigVersion, + BucketName: bucketName, + Endpoint: endpoint, + } + }, + protocolName: v1alpha1.ProtocolNameS3, + revokeFunc: func(ctx context.Context, in *osspec.ProvisionerRevokeBucketAccessRequest, opts ...grpc.CallOption) (*osspec.ProvisionerRevokeBucketAccessResponse, error) { + if in.BucketName != bucketName { + t.Errorf("expected %s, got %s", bucketName, in.BucketName) + } + if in.BucketContext["Region"] != region { + t.Errorf("expected %s, got %s", region, in.BucketContext["Region"]) + } + if in.Principal != principal { + t.Errorf("expected %s, got %s", principal, in.Principal) + } + if in.BucketContext["Version"] != protocolVersion { + t.Errorf("expected %s, got %s", protocolVersion, in.BucketContext["Version"]) + } + if in.BucketContext["SignatureVersion"] != string(sigVersion) { + t.Errorf("expected %s, got %s", sigVersion, in.BucketContext["SignatureVersion"]) + } + if in.BucketContext["Endpoint"] != endpoint { + t.Errorf("expected %s, got %s", endpoint, in.BucketContext["Endpoint"]) + } + return &osspec.ProvisionerRevokeBucketAccessResponse{}, nil + }, + serviceAccount: "", + params: nil, + }, } for _, tc := range testCases { @@ -619,6 +697,7 @@ func TestDelete(t *testing.T) { Type: v1.SecretTypeOpaque, } + t.Logf(tc.name) ctx := context.TODO() tc.setProtocol(&b) client := fakebucketclientset.NewSimpleClientset(&ba, &b)