Skip to content

Commit cc6957f

Browse files
committed
Add support for multi-zone provisioning
1 parent 4d6be71 commit cc6957f

16 files changed

+2076
-80
lines changed

pkg/common/parameters.go

+27-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ const (
3333
ParameterKeyEnableConfidentialCompute = "enable-confidential-storage"
3434
ParameterKeyStoragePools = "storage-pools"
3535
ParameterKeyResourceTags = "resource-tags"
36+
ParameterKeyEnableMultiZoneProvisioning = "enable-multi-zone-provisioning"
3637

3738
// Parameters for VolumeSnapshotClass
3839
ParameterKeyStorageLocations = "storage-locations"
@@ -102,6 +103,9 @@ type DiskParameters struct {
102103
// Values: {map[string]string}
103104
// Default: ""
104105
ResourceTags map[string]string
106+
// Values: {bool}
107+
// Default: false
108+
MultiZoneProvisioning bool
105109
}
106110

107111
// SnapshotParameters contains normalized and defaulted parameters for snapshots
@@ -121,11 +125,17 @@ type StoragePool struct {
121125
ResourceName string
122126
}
123127

128+
type ParameterProcessor struct {
129+
DriverName string
130+
EnableStoragePools bool
131+
EnableMultiZone bool
132+
}
133+
124134
// ExtractAndDefaultParameters will take the relevant parameters from a map and
125135
// put them into a well defined struct making sure to default unspecified fields.
126136
// extraVolumeLabels are added as labels; if there are also labels specified in
127137
// parameters, any matching extraVolumeLabels will be overridden.
128-
func ExtractAndDefaultParameters(parameters map[string]string, driverName string, extraVolumeLabels map[string]string, enableStoragePools bool, extraTags map[string]string) (DiskParameters, error) {
138+
func (pp *ParameterProcessor) ExtractAndDefaultParameters(parameters map[string]string, extraVolumeLabels map[string]string, extraTags map[string]string) (DiskParameters, error) {
129139
p := DiskParameters{
130140
DiskType: "pd-standard", // Default
131141
ReplicationType: replicationTypeNone, // Default
@@ -210,24 +220,37 @@ func ExtractAndDefaultParameters(parameters map[string]string, driverName string
210220

211221
p.EnableConfidentialCompute = paramEnableConfidentialCompute
212222
case ParameterKeyStoragePools:
213-
if !enableStoragePools {
223+
if !pp.EnableStoragePools {
214224
return p, fmt.Errorf("parameters contains invalid option %q", ParameterKeyStoragePools)
215225
}
216226
storagePools, err := ParseStoragePools(v)
217227
if err != nil {
218-
return p, fmt.Errorf("parameters contain invalid value for %s parameter: %w", ParameterKeyStoragePools, err)
228+
return p, fmt.Errorf("parameters contains invalid value for %s parameter %q: %w", ParameterKeyStoragePools, v, err)
219229
}
220230
p.StoragePools = storagePools
221231
case ParameterKeyResourceTags:
222232
if err := extractResourceTagsParameter(v, p.ResourceTags); err != nil {
223233
return p, err
224234
}
235+
case ParameterKeyEnableMultiZoneProvisioning:
236+
if !pp.EnableMultiZone {
237+
return p, fmt.Errorf("parameters contains invalid option %q", ParameterKeyEnableMultiZoneProvisioning)
238+
}
239+
paramEnableMultiZoneProvisioning, err := ConvertStringToBool(v)
240+
if err != nil {
241+
return p, fmt.Errorf("parameters contain invalid value for %s parameter: %w", ParameterKeyEnableMultiZoneProvisioning, err)
242+
}
243+
244+
p.MultiZoneProvisioning = paramEnableMultiZoneProvisioning
245+
if paramEnableMultiZoneProvisioning {
246+
p.Labels[MultiZoneLabel] = "true"
247+
}
225248
default:
226249
return p, fmt.Errorf("parameters contains invalid option %q", k)
227250
}
228251
}
229252
if len(p.Tags) > 0 {
230-
p.Tags[tagKeyCreatedBy] = driverName
253+
p.Tags[tagKeyCreatedBy] = pp.DriverName
231254
}
232255
return p, nil
233256
}

pkg/common/parameters_test.go

+45-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func TestExtractAndDefaultParameters(t *testing.T) {
2929
parameters map[string]string
3030
labels map[string]string
3131
enableStoragePools bool
32+
enableMultiZone bool
3233
extraTags map[string]string
3334
expectParams DiskParameters
3435
expectErr bool
@@ -350,19 +351,61 @@ func TestExtractAndDefaultParameters(t *testing.T) {
350351
labels: map[string]string{},
351352
expectErr: true,
352353
},
354+
{
355+
name: "multi-zone-enable parameters, multi-zone label is set, multi-zone feature enabled",
356+
parameters: map[string]string{ParameterKeyType: "hyperdisk-ml", ParameterKeyEnableMultiZoneProvisioning: "true"},
357+
labels: map[string]string{MultiZoneLabel: "true"},
358+
enableMultiZone: true,
359+
expectParams: DiskParameters{
360+
DiskType: "hyperdisk-ml",
361+
ReplicationType: "none",
362+
Tags: map[string]string{},
363+
Labels: map[string]string{MultiZoneLabel: "true"},
364+
ResourceTags: map[string]string{},
365+
MultiZoneProvisioning: true,
366+
},
367+
},
368+
{
369+
name: "multi-zone-enable parameters, multi-zone label is false, multi-zone feature enabled",
370+
parameters: map[string]string{ParameterKeyType: "hyperdisk-ml", ParameterKeyEnableMultiZoneProvisioning: "false"},
371+
enableMultiZone: true,
372+
expectParams: DiskParameters{
373+
DiskType: "hyperdisk-ml",
374+
ReplicationType: "none",
375+
Tags: map[string]string{},
376+
ResourceTags: map[string]string{},
377+
Labels: map[string]string{},
378+
},
379+
},
380+
{
381+
name: "multi-zone-enable parameters, invalid value, multi-zone feature enabled",
382+
parameters: map[string]string{ParameterKeyType: "hyperdisk-ml", ParameterKeyEnableMultiZoneProvisioning: "unknown"},
383+
enableMultiZone: true,
384+
expectErr: true,
385+
},
386+
{
387+
name: "multi-zone-enable parameters, multi-zone label is set, multi-zone feature disabled",
388+
parameters: map[string]string{ParameterKeyType: "hyperdisk-ml", ParameterKeyEnableMultiZoneProvisioning: "true"},
389+
expectErr: true,
390+
},
353391
}
354392

355393
for _, tc := range tests {
356394
t.Run(tc.name, func(t *testing.T) {
357-
p, err := ExtractAndDefaultParameters(tc.parameters, "testDriver", tc.labels, tc.enableStoragePools, tc.extraTags)
395+
pp := ParameterProcessor{
396+
DriverName: "testDriver",
397+
EnableStoragePools: tc.enableStoragePools,
398+
EnableMultiZone: tc.enableMultiZone,
399+
}
400+
p, err := pp.ExtractAndDefaultParameters(tc.parameters, tc.labels, tc.extraTags)
358401
if gotErr := err != nil; gotErr != tc.expectErr {
359402
t.Fatalf("ExtractAndDefaultParameters(%+v) = %v; expectedErr: %v", tc.parameters, err, tc.expectErr)
360403
}
361404
if err != nil {
362405
return
363406
}
364407

365-
if diff := cmp.Diff(p, tc.expectParams); diff != "" {
408+
if diff := cmp.Diff(tc.expectParams, p); diff != "" {
366409
t.Errorf("ExtractAndDefaultParameters(%+v): -want, +got \n%s", tc.parameters, diff)
367410
}
368411
})

pkg/common/utils.go

+47-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"net/http"
2424
"regexp"
25+
"slices"
2526
"strings"
2627
"time"
2728

@@ -73,6 +74,10 @@ const (
7374
// Full or partial URL of the machine type resource, in the format:
7475
// zones/zone/machineTypes/machine-type
7576
machineTypePattern = "zones/[^/]+/machineTypes/([^/]+)$"
77+
78+
// Full or partial URL of the zone resource, in the format:
79+
// projects/{project}/zones/{zone}
80+
zoneURIPattern = "projects/[^/]+/zones/([^/]+)$"
7681
)
7782

7883
var (
@@ -85,6 +90,8 @@ var (
8590

8691
storagePoolFieldsRegex = regexp.MustCompile(`^projects/([^/]+)/zones/([^/]+)/storagePools/([^/]+)$`)
8792

93+
zoneURIRegex = regexp.MustCompile(zoneURIPattern)
94+
8895
// userErrorCodeMap tells how API error types are translated to error codes.
8996
userErrorCodeMap = map[int]codes.Code{
9097
http.StatusForbidden: codes.PermissionDenied,
@@ -97,6 +104,8 @@ var (
97104
regexParent = regexp.MustCompile(`(^[1-9][0-9]{0,31}$)|(^[a-z][a-z0-9-]{4,28}[a-z0-9]$)`)
98105
regexKey = regexp.MustCompile(`^[a-zA-Z0-9]([0-9A-Za-z_.-]{0,61}[a-zA-Z0-9])?$`)
99106
regexValue = regexp.MustCompile(`^[a-zA-Z0-9]([0-9A-Za-z_.@%=+:,*#&()\[\]{}\-\s]{0,61}[a-zA-Z0-9])?$`)
107+
108+
csiRetryableErrorCodes = []codes.Code{codes.Canceled, codes.DeadlineExceeded, codes.Unavailable, codes.Aborted, codes.ResourceExhausted}
100109
)
101110

102111
func BytesToGbRoundDown(bytes int64) int64 {
@@ -545,9 +554,37 @@ func isGoogleAPIError(err error) (codes.Code, error) {
545554
return codes.Unknown, fmt.Errorf("googleapi.Error %w does not map to any known errors", err)
546555
}
547556

548-
func LoggedError(msg string, err error) error {
557+
func loggedErrorForCode(msg string, code codes.Code, err error) error {
549558
klog.Errorf(msg+"%v", err.Error())
550-
return status.Errorf(CodeForError(err), msg+"%v", err.Error())
559+
return status.Errorf(code, msg+"%v", err.Error())
560+
}
561+
562+
func LoggedError(msg string, err error) error {
563+
return loggedErrorForCode(msg, CodeForError(err), err)
564+
}
565+
566+
// NewCombinedError tries to return an appropriate wrapped error that captures
567+
// useful information as an error code
568+
// If there are multiple errors, it extracts the first "retryable" error
569+
// as interpreted by the CSI sidecar.
570+
func NewCombinedError(msg string, errs []error) error {
571+
// If there is only one error, return it as the single error code
572+
if len(errs) == 1 {
573+
LoggedError(msg, errs[0])
574+
}
575+
576+
for _, err := range errs {
577+
code := CodeForError(err)
578+
if slices.Contains(csiRetryableErrorCodes, code) {
579+
// Return this as a TemporaryError to lock-in the retryable code
580+
// This will invoke the "existing" error code check in CodeForError
581+
return NewTemporaryError(code, fmt.Errorf("%s: %w", msg, err))
582+
}
583+
}
584+
585+
// None of these error codes were retryable. Just return a combined error
586+
// The first matching error (based on our CodeForError) logic will be returned.
587+
return LoggedError(msg, errors.Join(errs...))
551588
}
552589

553590
func isValidDiskEncryptionKmsKey(DiskEncryptionKmsKey string) bool {
@@ -556,6 +593,14 @@ func isValidDiskEncryptionKmsKey(DiskEncryptionKmsKey string) bool {
556593
return kmsKeyPattern.MatchString(DiskEncryptionKmsKey)
557594
}
558595

596+
func ParseZoneFromURI(zoneURI string) (string, error) {
597+
zoneMatch := zoneURIRegex.FindStringSubmatch(zoneURI)
598+
if zoneMatch == nil {
599+
return "", fmt.Errorf("failed to parse zone URI. Expected projects/{project}/zones/{zone}. Got: %s", zoneURI)
600+
}
601+
return zoneMatch[1], nil
602+
}
603+
559604
// ParseStoragePools returns an error if none of the given storagePools
560605
// (delimited by a comma) are in the format
561606
// projects/project/zones/zone/storagePools/storagePool.

pkg/common/utils_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -1648,3 +1648,89 @@ func TestUnorderedSlicesEqual(t *testing.T) {
16481648
})
16491649
}
16501650
}
1651+
1652+
func TestParseZoneFromURI(t *testing.T) {
1653+
testcases := []struct {
1654+
name string
1655+
zoneURI string
1656+
wantZone string
1657+
expectErr bool
1658+
}{
1659+
{
1660+
name: "ParseZoneFromURI_FullURI",
1661+
zoneURI: "https://www.googleapis.com/compute/v1/projects/psch-gke-dev/zones/us-east4-a",
1662+
wantZone: "us-east4-a",
1663+
},
1664+
{
1665+
name: "ParseZoneFromURI_ProjectZoneString",
1666+
zoneURI: "projects/psch-gke-dev/zones/us-east4-a",
1667+
wantZone: "us-east4-a",
1668+
},
1669+
{
1670+
name: "ParseZoneFromURI_Malformed",
1671+
zoneURI: "projects/psch-gke-dev/regions/us-east4",
1672+
expectErr: true,
1673+
},
1674+
}
1675+
for _, tc := range testcases {
1676+
t.Run(tc.name, func(t *testing.T) {
1677+
gotZone, err := ParseZoneFromURI(tc.zoneURI)
1678+
if err != nil && !tc.expectErr {
1679+
t.Fatalf("Unexpected error: %v", err)
1680+
}
1681+
if err == nil && tc.expectErr {
1682+
t.Fatalf("Expected err, but none was returned. Zone result: %v", gotZone)
1683+
}
1684+
if gotZone != tc.wantZone {
1685+
t.Errorf("ParseZoneFromURI(%v): got %v, want %v", tc.zoneURI, gotZone, tc.wantZone)
1686+
}
1687+
})
1688+
}
1689+
}
1690+
1691+
func TestNewCombinedError(t *testing.T) {
1692+
testcases := []struct {
1693+
name string
1694+
errors []error
1695+
wantCode codes.Code
1696+
}{
1697+
{
1698+
name: "single generic error",
1699+
errors: []error{fmt.Errorf("my internal error")},
1700+
wantCode: codes.Internal,
1701+
},
1702+
{
1703+
name: "single retryable error",
1704+
errors: []error{&googleapi.Error{Code: http.StatusTooManyRequests, Message: "Resource Exhausted"}},
1705+
wantCode: codes.ResourceExhausted,
1706+
},
1707+
{
1708+
name: "multi generic error",
1709+
errors: []error{fmt.Errorf("my internal error"), fmt.Errorf("my other internal error")},
1710+
wantCode: codes.Internal,
1711+
},
1712+
{
1713+
name: "multi retryable error",
1714+
errors: []error{fmt.Errorf("my internal error"), &googleapi.Error{Code: http.StatusTooManyRequests, Message: "Resource Exhausted"}},
1715+
wantCode: codes.ResourceExhausted,
1716+
},
1717+
{
1718+
name: "multi retryable error",
1719+
errors: []error{fmt.Errorf("my internal error"), &googleapi.Error{Code: http.StatusGatewayTimeout, Message: "connection reset by peer"}, fmt.Errorf("my other internal error")},
1720+
wantCode: codes.Unavailable,
1721+
},
1722+
{
1723+
name: "multi retryable error",
1724+
errors: []error{fmt.Errorf("The disk resource is already being used"), &googleapi.Error{Code: http.StatusGatewayTimeout, Message: "connection reset by peer"}},
1725+
wantCode: codes.Unavailable,
1726+
},
1727+
}
1728+
for _, tc := range testcases {
1729+
t.Run(tc.name, func(t *testing.T) {
1730+
gotCode := CodeForError(NewCombinedError("message", tc.errors))
1731+
if gotCode != tc.wantCode {
1732+
t.Errorf("NewCombinedError(%v): got %v, want %v", tc.errors, gotCode, tc.wantCode)
1733+
}
1734+
})
1735+
}
1736+
}

pkg/gce-cloud-provider/compute/cloud-disk.go

+11
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,14 @@ func (d *CloudDisk) GetLabels() map[string]string {
256256
return nil
257257
}
258258
}
259+
260+
func (d *CloudDisk) GetAccessMode() string {
261+
switch {
262+
case d.disk != nil:
263+
return d.disk.AccessMode
264+
case d.betaDisk != nil:
265+
return d.betaDisk.AccessMode
266+
default:
267+
return ""
268+
}
269+
}

pkg/gce-cloud-provider/compute/cloud-disk_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,50 @@ func TestGetLabels(t *testing.T) {
148148
}
149149
}
150150
}
151+
152+
func TestGetAccessMode(t *testing.T) {
153+
testCases := []struct {
154+
name string
155+
cloudDisk *CloudDisk
156+
wantAccessMode string
157+
}{
158+
{
159+
name: "v1 disk accessMode",
160+
cloudDisk: &CloudDisk{
161+
disk: &computev1.Disk{
162+
AccessMode: "READ_WRITE_SINGLE",
163+
},
164+
},
165+
wantAccessMode: "READ_WRITE_SINGLE",
166+
},
167+
{
168+
name: "beta disk accessMode",
169+
cloudDisk: &CloudDisk{
170+
betaDisk: &computebeta.Disk{
171+
AccessMode: "READ_ONLY_MANY",
172+
},
173+
},
174+
wantAccessMode: "READ_ONLY_MANY",
175+
},
176+
{
177+
name: "unset disk accessMode",
178+
cloudDisk: &CloudDisk{
179+
betaDisk: &computebeta.Disk{},
180+
},
181+
wantAccessMode: "",
182+
},
183+
{
184+
name: "unset disk",
185+
cloudDisk: &CloudDisk{},
186+
wantAccessMode: "",
187+
},
188+
}
189+
190+
for _, tc := range testCases {
191+
t.Logf("Running test: %v", tc.name)
192+
gotAccessMode := tc.cloudDisk.GetAccessMode()
193+
if gotAccessMode != tc.wantAccessMode {
194+
t.Errorf("GetAccessMode() got %v, want %v", gotAccessMode, tc.wantAccessMode)
195+
}
196+
}
197+
}

0 commit comments

Comments
 (0)