diff --git a/pkg/common/constants.go b/pkg/common/constants.go
index a9f070f82..905c8fcda 100644
--- a/pkg/common/constants.go
+++ b/pkg/common/constants.go
@@ -33,6 +33,11 @@ const (
 	// Label that is set on a disk when it is used by a 'multi-zone' VolumeHandle
 	MultiZoneLabel = "goog-gke-multi-zone"
 
+	// GCE Access Modes that are valid for hyperdisks only.
+	GCEReadOnlyManyAccessMode = "READ_ONLY_MANY"
+	GCEReadWriteManyAccessMode = "READ_WRITE_MANY"
+	GCEReadWriteOnceAccessMode = "READ_WRITE_SINGLE"
+
 	// Data cache mode
 	DataCacheModeWriteBack = "writeback"
 	DataCacheModeWriteThrough = "writethrough"
diff --git a/pkg/common/parameters.go b/pkg/common/parameters.go
index 7f59e69d4..0572a0515 100644
--- a/pkg/common/parameters.go
+++ b/pkg/common/parameters.go
@@ -42,7 +42,6 @@ const (
 	ParameterKeyDataCacheMode = "data-cache-mode"
 	ParameterKeyResourceTags = "resource-tags"
 	ParameterKeyEnableMultiZoneProvisioning = "enable-multi-zone-provisioning"
-	ParameterHdHADiskType = "hyperdisk-balanced-high-availability"
 
 	// Parameters for VolumeSnapshotClass
 	ParameterKeyStorageLocations = "storage-locations"
@@ -76,6 +75,11 @@ const (
 	tagKeyCreatedForSnapshotName = "kubernetes.io/created-for/volumesnapshot/name"
 	tagKeyCreatedForSnapshotNamespace = "kubernetes.io/created-for/volumesnapshot/namespace"
 	tagKeyCreatedForSnapshotContentName = "kubernetes.io/created-for/volumesnapshotcontent/name"
+
+	// Hyperdisk disk types
+	DiskTypeHdHA = "hyperdisk-balanced-high-availability"
+	DiskTypeHdT = "hyperdisk-throughput"
+	DiskTypeHdE = "hyperdisk-extreme"
 )
 
 type DataCacheParameters struct {
@@ -130,7 +134,7 @@ type DiskParameters struct {
 }
 
 func (dp *DiskParameters) IsRegional() bool {
-	return dp.ReplicationType == "regional-pd" || dp.DiskType == ParameterHdHADiskType
+	return dp.ReplicationType == "regional-pd" || dp.DiskType == DiskTypeHdHA
 }
 
 // SnapshotParameters contains normalized and defaulted parameters for snapshots
@@ -200,8 +204,8 @@ func (pp *ParameterProcessor) ExtractAndDefaultParameters(parameters map[string]
 		case ParameterKeyType:
 			if v != "" {
 				p.DiskType = strings.ToLower(v)
-				if !pp.EnableHdHA && p.DiskType == ParameterHdHADiskType {
-					return p, d, fmt.Errorf("parameters contain invalid disk type %s", ParameterHdHADiskType)
+				if !pp.EnableHdHA && p.DiskType == DiskTypeHdHA {
+					return p, d, fmt.Errorf("parameters contain invalid disk type %s", DiskTypeHdHA)
 				}
 			}
 		case ParameterKeyReplicationType:
diff --git a/pkg/common/utils.go b/pkg/common/utils.go
index 222207456..44f736d7f 100644
--- a/pkg/common/utils.go
+++ b/pkg/common/utils.go
@@ -726,6 +726,10 @@ func NewLimiter(limit, burst int, emptyBucket bool) *rate.Limiter {
 	return limiter
 }
 
+func IsHyperdisk(diskType string) bool {
+	return strings.HasPrefix(diskType, "hyperdisk-")
+}
+
 // shortString is inspired by k8s.io/apimachinery/pkg/util/rand.SafeEncodeString, but takes data from a hash.
func ShortString(s string) string { hasher := fnv.New128a() diff --git a/pkg/gce-cloud-provider/compute/fake-gce.go b/pkg/gce-cloud-provider/compute/fake-gce.go index ad517fca0..c2b931670 100644 --- a/pkg/gce-cloud-provider/compute/fake-gce.go +++ b/pkg/gce-cloud-provider/compute/fake-gce.go @@ -23,6 +23,7 @@ import ( "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" csi "github.com/container-storage-interface/spec/lib/go/csi" + computebeta "google.golang.org/api/compute/v0.beta" computev1 "google.golang.org/api/compute/v1" "google.golang.org/api/googleapi" "google.golang.org/grpc/codes" @@ -196,55 +197,30 @@ func (cloud *FakeCloudProvider) GetDisk(ctx context.Context, project string, vol return disk, nil } -func (cloud *FakeCloudProvider) ValidateExistingDisk(ctx context.Context, resp *CloudDisk, params common.DiskParameters, reqBytes, limBytes int64, multiWriter bool) error { - if resp == nil { - return fmt.Errorf("disk does not exist") - } - requestValid := common.GbToBytes(resp.GetSizeGb()) >= reqBytes || reqBytes == 0 - responseValid := common.GbToBytes(resp.GetSizeGb()) <= limBytes || limBytes == 0 - if !requestValid || !responseValid { - return fmt.Errorf( - "disk already exists with incompatible capacity. Need %v (Required) < %v (Existing) < %v (Limit)", - reqBytes, common.GbToBytes(resp.GetSizeGb()), limBytes) - } - - respType := strings.Split(resp.GetPDType(), "/") - typeMatch := strings.TrimSpace(respType[len(respType)-1]) == strings.TrimSpace(params.DiskType) - typeDefault := params.DiskType == "" && strings.TrimSpace(respType[len(respType)-1]) == "pd-standard" - if !typeMatch && !typeDefault { - return fmt.Errorf("disk already exists with incompatible type. Need %v. Got %v", - params.DiskType, respType[len(respType)-1]) - } - - // We are assuming here that a multiWriter disk could be used as non-multiWriter - if multiWriter && !resp.GetMultiWriter() { - return fmt.Errorf("disk already exists with incompatible capability. Need MultiWriter. 
Got non-MultiWriter") - } - klog.V(4).Infof("Compatible disk already exists") - return ValidateDiskParameters(resp, params) -} - func (cloud *FakeCloudProvider) InsertDisk(ctx context.Context, project string, volKey *meta.Key, params common.DiskParameters, capBytes int64, capacityRange *csi.CapacityRange, replicaZones []string, snapshotID string, volumeContentSourceVolumeID string, multiWriter bool, accessMode string) error { if disk, ok := cloud.disks[volKey.String()]; ok { - err := cloud.ValidateExistingDisk(ctx, disk, params, + err := ValidateExistingDisk(ctx, disk, params, int64(capacityRange.GetRequiredBytes()), int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return err } } - computeDisk := &computev1.Disk{ - Name: volKey.Name, - SizeGb: common.BytesToGbRoundUp(capBytes), - Description: "Disk created by GCE-PD CSI Driver", - Type: cloud.GetDiskTypeURI(project, volKey, params.DiskType), - SourceDiskId: volumeContentSourceVolumeID, - Status: cloud.mockDiskStatus, - Labels: params.Labels, - ProvisionedIops: params.ProvisionedIOPSOnCreate, - ProvisionedThroughput: params.ProvisionedThroughputOnCreate, + computeDisk := &computebeta.Disk{ + Name: volKey.Name, + SizeGb: common.BytesToGbRoundUp(capBytes), + Description: "Disk created by GCE-PD CSI Driver", + Type: cloud.GetDiskTypeURI(project, volKey, params.DiskType), + SourceDiskId: volumeContentSourceVolumeID, + Status: cloud.mockDiskStatus, + Labels: params.Labels, + ProvisionedIops: params.ProvisionedIOPSOnCreate, + ProvisionedThroughput: params.ProvisionedThroughputOnCreate, + AccessMode: accessMode, + MultiWriter: multiWriter, + EnableConfidentialCompute: params.EnableConfidentialCompute, } if snapshotID != "" { @@ -263,7 +239,7 @@ func (cloud *FakeCloudProvider) InsertDisk(ctx context.Context, project string, } if params.DiskEncryptionKMSKey != "" { - computeDisk.DiskEncryptionKey = &computev1.CustomerEncryptionKey{ + computeDisk.DiskEncryptionKey = &computebeta.CustomerEncryptionKey{ KmsKeyName: params.DiskEncryptionKMSKey, } } @@ -278,13 +254,7 @@ func (cloud *FakeCloudProvider) InsertDisk(ctx context.Context, project string, return fmt.Errorf("could not create disk, key was neither zonal nor regional, instead got: %v", volKey.String()) } - if containsBetaDiskType(hyperdiskTypes, params.DiskType) { - betaDisk := convertV1DiskToBetaDisk(computeDisk) - betaDisk.EnableConfidentialCompute = params.EnableConfidentialCompute - cloud.disks[volKey.String()] = CloudDiskFromBeta(betaDisk) - } else { - cloud.disks[volKey.String()] = CloudDiskFromV1(computeDisk) - } + cloud.disks[volKey.String()] = CloudDiskFromBeta(computeDisk) return nil } diff --git a/pkg/gce-cloud-provider/compute/gce-compute.go b/pkg/gce-cloud-provider/compute/gce-compute.go index 16b2997ce..947ba72ac 100644 --- a/pkg/gce-cloud-provider/compute/gce-compute.go +++ b/pkg/gce-cloud-provider/compute/gce-compute.go @@ -53,7 +53,6 @@ const ( pdDiskTypeUnsupportedPattern = `\[([a-z-]+)\] features are not compatible for creating instance` ) -var hyperdiskTypes = []string{"hyperdisk-extreme", "hyperdisk-throughput", "hyperdisk-balanced"} var pdDiskTypeUnsupportedRegex = regexp.MustCompile(pdDiskTypeUnsupportedPattern) type GCEAPIVersion string @@ -101,7 +100,6 @@ type GCECompute interface { // Disk Methods GetDisk(ctx context.Context, project string, volumeKey *meta.Key) (*CloudDisk, error) RepairUnderspecifiedVolumeKey(ctx context.Context, project string, volumeKey *meta.Key) (string, *meta.Key, error) - ValidateExistingDisk(ctx 
context.Context, disk *CloudDisk, params common.DiskParameters, reqBytes, limBytes int64, multiWriter bool) error InsertDisk(ctx context.Context, project string, volKey *meta.Key, params common.DiskParameters, capBytes int64, capacityRange *csi.CapacityRange, replicaZones []string, snapshotID string, volumeContentSourceVolumeID string, multiWriter bool, accessMode string) error DeleteDisk(ctx context.Context, project string, volumeKey *meta.Key) error UpdateDisk(ctx context.Context, project string, volKey *meta.Key, existingDisk *CloudDisk, params common.ModifyVolumeParameters) error @@ -382,7 +380,7 @@ func (cloud *CloudProvider) getRegionURI(project, region string) string { region) } -func (cloud *CloudProvider) ValidateExistingDisk(ctx context.Context, resp *CloudDisk, params common.DiskParameters, reqBytes, limBytes int64, multiWriter bool) error { +func ValidateExistingDisk(ctx context.Context, resp *CloudDisk, params common.DiskParameters, reqBytes, limBytes int64, multiWriter bool, accessMode string) error { klog.V(5).Infof("Validating existing disk %v with diskType: %s, reqested bytes: %v, limit bytes: %v", resp, params.DiskType, reqBytes, limBytes) if resp == nil { return fmt.Errorf("disk does not exist") @@ -395,14 +393,31 @@ func (cloud *CloudProvider) ValidateExistingDisk(ctx context.Context, resp *Clou reqBytes, common.GbToBytes(resp.GetSizeGb()), limBytes) } - // We are assuming here that a multiWriter disk could be used as non-multiWriter - if multiWriter && !resp.GetMultiWriter() { + if common.IsHyperdisk(params.DiskType) { + if !validAccessMode(accessMode, resp.GetAccessMode()) { + return fmt.Errorf("disk already exists with incompatible capability. Need %s. Got %s", accessMode, resp.GetAccessMode()) + } + } else if multiWriter && !resp.GetMultiWriter() { + // We are assuming here that a multiWriter PD could be used as non-multiWriter return fmt.Errorf("disk already exists with incompatible capability. Need MultiWriter. Got non-MultiWriter") } return ValidateDiskParameters(resp, params) } +func validAccessMode(want, got string) bool { + if want == got { + return true + } + switch want { + case common.GCEReadOnlyManyAccessMode, common.GCEReadWriteOnceAccessMode: + return got == common.GCEReadWriteManyAccessMode + // For RWX, no other access mode is valid. + default: + return false + } +} + // ValidateDiskParameters takes a CloudDisk and returns true if the parameters // specified validly describe the disk provided, and false otherwise. 
func ValidateDiskParameters(disk *CloudDisk, params common.DiskParameters) error { @@ -442,7 +457,7 @@ func (cloud *CloudProvider) InsertDisk(ctx context.Context, project string, volK if description == "" { description = "Regional disk created by GCE-PD CSI Driver" } - return cloud.insertRegionalDisk(ctx, project, volKey, params, capBytes, capacityRange, replicaZones, snapshotID, volumeContentSourceVolumeID, description, multiWriter) + return cloud.insertRegionalDisk(ctx, project, volKey, params, capBytes, capacityRange, replicaZones, snapshotID, volumeContentSourceVolumeID, description, multiWriter, accessMode) default: return fmt.Errorf("could not insert disk, key was neither zonal nor regional, instead got: %v", volKey.String()) } @@ -626,7 +641,8 @@ func (cloud *CloudProvider) insertRegionalDisk( snapshotID string, volumeContentSourceVolumeID string, description string, - multiWriter bool) error { + multiWriter bool, + accessMode string) error { var ( err error opName string @@ -676,8 +692,13 @@ func (cloud *CloudProvider) insertRegionalDisk( } } + if common.IsHyperdisk(params.DiskType) { + diskToCreate.AccessMode = accessMode + } else { + diskToCreate.MultiWriter = multiWriter + } + var insertOp *computebeta.Operation - diskToCreate.MultiWriter = multiWriter insertOp, err = cloud.betaService.RegionDisks.Insert(project, volKey.Region, diskToCreate).Context(ctx).Do() if insertOp != nil { opName = insertOp.Name @@ -691,10 +712,10 @@ func (cloud *CloudProvider) insertRegionalDisk( // the error code should be non-Final return common.NewTemporaryError(codes.Unavailable, fmt.Errorf("error when getting disk: %w", err)) } - err = cloud.ValidateExistingDisk(ctx, disk, params, + err = ValidateExistingDisk(ctx, disk, params, int64(capacityRange.GetRequiredBytes()), int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return err } @@ -715,10 +736,10 @@ func (cloud *CloudProvider) insertRegionalDisk( if err != nil { return common.NewTemporaryError(codes.Unavailable, fmt.Errorf("error when getting disk: %w", err)) } - err = cloud.ValidateExistingDisk(ctx, disk, params, + err = ValidateExistingDisk(ctx, disk, params, int64(capacityRange.GetRequiredBytes()), int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return err } @@ -806,7 +827,12 @@ func (cloud *CloudProvider) insertZonalDisk( } } - diskToCreate.AccessMode = accessMode + if common.IsHyperdisk(params.DiskType) { + diskToCreate.AccessMode = accessMode + } else { + diskToCreate.MultiWriter = multiWriter + } + var insertOp *computebeta.Operation insertOp, err = cloud.betaService.Disks.Insert(project, volKey.Zone, diskToCreate).Context(ctx).Do() if insertOp != nil { @@ -821,10 +847,10 @@ func (cloud *CloudProvider) insertZonalDisk( // the error code should be non-Final return common.NewTemporaryError(codes.Unavailable, fmt.Errorf("error when getting disk: %w", err)) } - err = cloud.ValidateExistingDisk(ctx, disk, params, + err = ValidateExistingDisk(ctx, disk, params, int64(capacityRange.GetRequiredBytes()), int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return err } @@ -846,10 +872,10 @@ func (cloud *CloudProvider) insertZonalDisk( if err != nil { return common.NewTemporaryError(codes.Unavailable, fmt.Errorf("error when getting disk: %w", err)) } - err = cloud.ValidateExistingDisk(ctx, disk, params, + err = ValidateExistingDisk(ctx, disk, params, int64(capacityRange.GetRequiredBytes()), 
int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return err } @@ -1687,13 +1713,3 @@ func encodeTags(tags map[string]string) (string, error) { } return string(enc), nil } - -func containsBetaDiskType(betaDiskTypes []string, diskType string) bool { - for _, betaDiskType := range betaDiskTypes { - if betaDiskType == diskType { - return true - } - } - - return false -} diff --git a/pkg/gce-cloud-provider/compute/gce-compute_test.go b/pkg/gce-cloud-provider/compute/gce-compute_test.go index cd5fc9c23..eb2661e40 100644 --- a/pkg/gce-cloud-provider/compute/gce-compute_test.go +++ b/pkg/gce-cloud-provider/compute/gce-compute_test.go @@ -15,8 +15,10 @@ limitations under the License. package gcecloudprovider import ( + "context" "testing" + computebeta "google.golang.org/api/compute/v0.beta" computev1 "google.golang.org/api/compute/v1" "google.golang.org/grpc/codes" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" @@ -103,6 +105,154 @@ func TestValidateDiskParameters(t *testing.T) { } } +func TestValidateExistingDisk(t *testing.T) { + hyperdisk := "hyperdisk-balanced" + pd := "pd-balanced" + for _, tc := range []struct { + name string + reqBytes int64 + limBytes int64 + multiWriter bool + accessMode string + disk *computebeta.Disk + diskType string + wantErr bool + }{ + { + name: "invalid reqbytes - too big", + reqBytes: common.GbToBytes(10), + disk: &computebeta.Disk{ + SizeGb: 5, + }, + wantErr: true, + }, + { + name: "valid reqbytes", + reqBytes: common.GbToBytes(5), + disk: &computebeta.Disk{ + SizeGb: 8, + }, + }, + { + name: "invalid limbytes", + limBytes: common.GbToBytes(5), + disk: &computebeta.Disk{ + SizeGb: 10, + }, + wantErr: true, + }, + { + name: "valid limbytes", + limBytes: common.GbToBytes(5), + disk: &computebeta.Disk{ + SizeGb: 3, + }, + }, + { + name: "valid pd with same multi-writer config", + multiWriter: true, + disk: &computebeta.Disk{ + MultiWriter: true, + }, + diskType: pd, + }, + { + name: "valid pd with compatible multi-writer config", + multiWriter: false, + disk: &computebeta.Disk{ + MultiWriter: true, + }, + diskType: pd, + }, + { + name: "invalid pd with incompatible multi-writer config", + multiWriter: true, + disk: &computebeta.Disk{ + MultiWriter: false, + }, + diskType: pd, + wantErr: true, + }, + { + name: "valid hyperdisk with same access mode config", + accessMode: common.GCEReadWriteManyAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadWriteManyAccessMode, + }, + diskType: hyperdisk, + }, + { + name: "valid hyperdisk with compatible access mode config - ROX can use RWX", + accessMode: common.GCEReadOnlyManyAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadWriteManyAccessMode, + }, + diskType: hyperdisk, + }, + { + name: "valid hyperdisk with compatible access mode config - RWO can use RWX", + accessMode: common.GCEReadWriteOnceAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadWriteManyAccessMode, + }, + diskType: hyperdisk, + }, + { + name: "invalid hyperdisk with incompatible access mode config - ROX cannot use RWO", + accessMode: common.GCEReadOnlyManyAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadWriteOnceAccessMode, + }, + diskType: hyperdisk, + wantErr: true, + }, + { + name: "invalid hyperdisk with incompatible access mode config - RWO cannot use ROX", + accessMode: common.GCEReadWriteOnceAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadOnlyManyAccessMode, + }, + diskType: hyperdisk, + wantErr: true, + 
}, + { + name: "invalid hyperdisk with incompatible access mode config - RWX cannot use ROX", + accessMode: common.GCEReadWriteManyAccessMode, + disk: &computebeta.Disk{ + AccessMode: common.GCEReadOnlyManyAccessMode, + }, + diskType: hyperdisk, + wantErr: true, + }, + { + name: "invalid access mode", + accessMode: "RANDOM_ERROR", + disk: &computebeta.Disk{ + AccessMode: common.GCEReadOnlyManyAccessMode, + }, + diskType: hyperdisk, + wantErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // Bootstrap correct disk + d := tc.disk + d.Type = tc.diskType + d.Zone = "zone" + + // Bootstrap params. We don't care about these as they are already tested in previous unit test. + params := common.DiskParameters{ + DiskType: tc.diskType, + } + + err := ValidateExistingDisk(context.Background(), CloudDiskFromBeta(tc.disk), params, tc.reqBytes, tc.limBytes, tc.multiWriter, tc.accessMode) + if gotErr := err != nil; gotErr != tc.wantErr { + t.Errorf("want error: %v, got error: %v", tc.wantErr, err) + } + }) + } +} + func TestCodeForGCEOpError(t *testing.T) { testCases := []struct { name string diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index 2688efa4d..a5a19d0bb 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -212,9 +212,6 @@ const ( resourceProject = "projects" listDisksUsersField = googleapi.Field("items/users") - - gceReadOnlyManyAccessMode = "READ_ONLY_MANY" - gceReadWriteManyAccessMode = "READ_WRITE_MANY" ) var ( @@ -236,6 +233,10 @@ var ( } listDisksFieldsWithUsers = append(listDisksFieldsWithoutUsers, "items/users") disksWithModifiableAccessMode = []string{"hyperdisk-ml"} + disksWithUnsettableAccessMode = map[string]bool{ + common.DiskTypeHdE: true, + common.DiskTypeHdT: true, + } ) func isDiskReady(disk *gce.CloudDisk) (bool, error) { @@ -377,8 +378,17 @@ func (gceCS *GCEControllerServer) createVolumeInternal(ctx context.Context, req return nil, status.Error(codes.InvalidArgument, "VolumeContentSource must be provided when AccessMode is set to read only") } - if readonly && params.DiskType == common.ParameterHdHADiskType { - return nil, status.Errorf(codes.InvalidArgument, "Invalid access mode for disk type %s", common.ParameterHdHADiskType) + if readonly && params.DiskType == common.DiskTypeHdHA { + return nil, status.Errorf(codes.InvalidArgument, "Invalid access mode for disk type %s", common.DiskTypeHdHA) + } + + // Hyperdisk-throughput and hyperdisk-extreme do not support attaching to multiple VMs. 
+	isMultiAttach, err := getMultiAttachementFromCapabilities(volumeCapabilities)
+	if err != nil {
+		return nil, status.Errorf(codes.InvalidArgument, "CreateVolume failed to parse volume capabilities: %v", err)
+	}
+	if isMultiAttach && disksWithUnsettableAccessMode[params.DiskType] {
+		return nil, status.Errorf(codes.InvalidArgument, "Invalid access mode for disk type %s", params.DiskType)
+	}
 
 	// Validate multi-zone provisioning configuration
@@ -535,12 +545,12 @@ func (gceCS *GCEControllerServer) updateAccessModeIfNecessary(ctx context.Contex
 		return nil
 	}
 	project := gceCS.CloudProvider.GetDefaultProject()
-	if disk.GetAccessMode() == gceReadOnlyManyAccessMode {
+	if disk.GetAccessMode() == common.GCEReadOnlyManyAccessMode {
 		// If the access mode is already readonly, return
 		return nil
 	}
 
-	return gceCS.CloudProvider.SetDiskAccessMode(ctx, project, volKey, gceReadOnlyManyAccessMode)
+	return gceCS.CloudProvider.SetDiskAccessMode(ctx, project, volKey, common.GCEReadOnlyManyAccessMode)
 }
 
 func (gceCS *GCEControllerServer) createSingleDeviceDisk(ctx context.Context, req *csi.CreateVolumeRequest, params common.DiskParameters, dataCacheParams common.DataCacheParameters, enableDataCache bool) (*csi.CreateVolumeResponse, error) {
@@ -599,11 +609,27 @@ func (gceCS *GCEControllerServer) createSingleDisk(ctx context.Context, req *csi.CreateVolumeRequest, params common.DiskParameters, volKey *meta.Key, zones []string) (*gce.CloudDisk, error) {
 	capacityRange := req.GetCapacityRange()
 	capBytes, _ := getRequestCapacity(capacityRange)
-	multiWriter, _ := getMultiWriterFromCapabilities(req.GetVolumeCapabilities())
 	readonly, _ := getReadOnlyFromCapabilities(req.GetVolumeCapabilities())
-	accessMode := params.AccessMode
+	accessMode := ""
+	multiWriter := false
+	if common.IsHyperdisk(params.DiskType) {
+		if am, err := getHyperdiskAccessModeFromCapabilities(req.GetVolumeCapabilities()); err != nil {
+			return nil, err
+		} else if disksWithUnsettableAccessMode[params.DiskType] {
+			// Disallow multi-attach for HdT and HdE. These checks were done in `createVolumeInternal`,
+			// but repeating them here future-proofs us against possible refactors.
+ if am != common.GCEReadWriteOnceAccessMode { + return nil, status.Errorf(codes.Internal, "") + } + } else { + accessMode = am + } + } else { + multiWriter, _ = getMultiWriterFromCapabilities(req.GetVolumeCapabilities()) + } + if readonly && slices.Contains(disksWithModifiableAccessMode, params.DiskType) { - accessMode = gceReadOnlyManyAccessMode + accessMode = common.GCEReadOnlyManyAccessMode } // Validate if disk already exists @@ -616,10 +642,10 @@ func (gceCS *GCEControllerServer) createSingleDisk(ctx context.Context, req *csi } if err == nil { // There was no error so we want to validate the disk that we find - err = gceCS.CloudProvider.ValidateExistingDisk(ctx, existingDisk, params, + err = gce.ValidateExistingDisk(ctx, existingDisk, params, int64(capacityRange.GetRequiredBytes()), int64(capacityRange.GetLimitBytes()), - multiWriter) + multiWriter, accessMode) if err != nil { return nil, status.Errorf(codes.AlreadyExists, "CreateVolume disk already exists with same name and is incompatible: %v", err.Error()) } @@ -1562,9 +1588,8 @@ func (gceCS *GCEControllerServer) CreateSnapshot(ctx context.Context, req *csi.C } return nil, common.LoggedError("CreateSnapshot, failed to getDisk: ", err) } - isHyperdisk := strings.HasPrefix(disk.GetPDType(), "hyperdisk-") - if isHyperdisk && disk.GetAccessMode() == gceReadWriteManyAccessMode { - return nil, status.Errorf(codes.InvalidArgument, "Cannot create snapshot for disk type %s with access mode %s", common.ParameterHdHADiskType, gceReadWriteManyAccessMode) + if common.IsHyperdisk(disk.GetPDType()) && disk.GetAccessMode() == common.GCEReadWriteManyAccessMode { + return nil, status.Errorf(codes.InvalidArgument, "Cannot create snapshot for disk type %s with access mode %s", common.DiskTypeHdHA, common.GCEReadWriteManyAccessMode) } snapshotParams, err := common.ExtractAndDefaultSnapshotParameters(req.GetParameters(), gceCS.Driver.name, gceCS.Driver.extraTags) diff --git a/pkg/gce-pd-csi-driver/controller_test.go b/pkg/gce-pd-csi-driver/controller_test.go index 80871c77e..30a508ae8 100644 --- a/pkg/gce-pd-csi-driver/controller_test.go +++ b/pkg/gce-pd-csi-driver/controller_test.go @@ -30,6 +30,7 @@ import ( "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/timestamppb" compute "google.golang.org/api/compute/v1" @@ -235,7 +236,7 @@ func TestCreateSnapshotArguments(t *testing.T) { gce.CloudDiskFromV1(&compute.Disk{ Name: name, SelfLink: fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/project/regions/country-region/name/%s", name), - Type: common.ParameterHdHADiskType, + Type: common.DiskTypeHdHA, Region: "country-region", }), }, @@ -252,7 +253,7 @@ func TestCreateSnapshotArguments(t *testing.T) { gce.CloudDiskFromV1(&compute.Disk{ Name: name, SelfLink: fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/project/regions/country-region/name/%s", name), - Type: common.ParameterHdHADiskType, + Type: common.DiskTypeHdHA, Region: "country-region", }), }, @@ -275,8 +276,8 @@ func TestCreateSnapshotArguments(t *testing.T) { gce.CloudDiskFromV1(&compute.Disk{ Name: name, SelfLink: fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/project/regions/country-region/name/%s", name), - Type: common.ParameterHdHADiskType, - AccessMode: gceReadWriteManyAccessMode, + Type: common.DiskTypeHdHA, + AccessMode: 
common.GCEReadWriteManyAccessMode, Region: "country-region", }), }, @@ -920,7 +921,7 @@ func TestCreateVolumeArguments(t *testing.T) { Name: name, CapacityRange: stdCapRange, VolumeCapabilities: stdVolCaps, - Parameters: map[string]string{common.ParameterKeyType: common.ParameterHdHADiskType}, + Parameters: map[string]string{common.ParameterKeyType: common.DiskTypeHdHA}, AccessibilityRequirements: &csi.TopologyRequirement{ Preferred: []*csi.Topology{ { @@ -953,7 +954,7 @@ func TestCreateVolumeArguments(t *testing.T) { CapacityRange: stdCapRange, VolumeCapabilities: stdVolCaps, Parameters: map[string]string{ - common.ParameterKeyType: common.ParameterHdHADiskType, + common.ParameterKeyType: common.DiskTypeHdHA, }, AccessibilityRequirements: &csi.TopologyRequirement{ Requisite: []*csi.Topology{ @@ -977,7 +978,7 @@ func TestCreateVolumeArguments(t *testing.T) { CapacityRange: stdCapRange, VolumeCapabilities: stdVolCaps, Parameters: map[string]string{ - common.ParameterKeyType: common.ParameterHdHADiskType, + common.ParameterKeyType: common.DiskTypeHdHA, }, }, expVol: &csi.Volume{ @@ -1229,6 +1230,74 @@ func TestCreateVolumeArguments(t *testing.T) { }, expErrCode: codes.InvalidArgument, }, + { + name: "fail with invalid hyperdisk access mode ROO", + req: &csi.CreateVolumeRequest{ + Name: name, + Parameters: map[string]string{"type": "hyperdisk-balanced"}, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY, + }, + }, + }, + }, + expErrCode: codes.InvalidArgument, + }, + { + name: "success with hyperdisk access mode RWO", + req: &csi.CreateVolumeRequest{ + Name: name, + Parameters: map[string]string{"type": "hyperdisk-balanced"}, + VolumeCapabilities: stdVolCaps, + }, + expVol: &csi.Volume{ + VolumeId: "projects/test-project/zones/country-region-zone/disks/test-name", + VolumeContext: nil, + AccessibleTopology: stdTopology, + CapacityBytes: MinimumVolumeSizeInBytes, + }, + }, + { + name: "fail with hyperdisk-throughput access mode ROX", + req: &csi.CreateVolumeRequest{ + Name: name, + Parameters: map[string]string{"type": "hyperdisk-throughput"}, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY, + }, + }, + }, + }, + expErrCode: codes.InvalidArgument, + }, + { + name: "fail with hyperdisk-extreme access mode RWX", + req: &csi.CreateVolumeRequest{ + Name: name, + Parameters: map[string]string{"type": "hyperdisk-extreme"}, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + }, + }, + }, + expErrCode: codes.InvalidArgument, + }, } // Run test cases @@ -1260,16 +1329,8 @@ func TestCreateVolumeArguments(t *testing.T) { t.Fatalf("Expected volume %v, got nil volume", tc.expVol) } - if !reflect.DeepEqual(vol, tc.expVol) { - errStr := fmt.Sprintf("Expected volume: %#v\nTopology %#v\n\n to equal volume: %#v\nTopology %#v\n\n", - vol, vol.GetAccessibleTopology()[0], tc.expVol, tc.expVol.GetAccessibleTopology()[0]) - if len(vol.GetAccessibleTopology()) != len(tc.expVol.GetAccessibleTopology()) 
{ - t.Errorf("Accessible topologies are not the same length, got %v, expected %v", len(vol.GetAccessibleTopology()), len(tc.expVol.GetAccessibleTopology())) - } - for i := 0; i < len(vol.GetAccessibleTopology()); i++ { - errStr += fmt.Sprintf("Got topology %#v\nExpected toplogy %#v\n\n", vol.GetAccessibleTopology()[i], tc.expVol.GetAccessibleTopology()[i]) - } - t.Error(errStr) + if diff := cmp.Diff(vol, tc.expVol, protocmp.Transform()); diff != "" { + t.Errorf("unexpected diff (-vol, +expVol): \n%s", diff) } } } @@ -1712,6 +1773,181 @@ func TestMultiZoneVolumeCreation(t *testing.T) { } } } +func TestCreateVolumeMultiWriterOrAccessMode(t *testing.T) { + testCases := []struct { + name string + req *csi.CreateVolumeRequest + existingDisk *gce.CloudDisk + expAccessMode string + expMultiWriter bool + expErrCode codes.Code + }{ + { + name: "success non-multi-writer PD", + req: &csi.CreateVolumeRequest{ + Name: name, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + Parameters: map[string]string{ + common.ParameterKeyType: "pd-balanced", + }, + }, + expMultiWriter: false, + }, + { + name: "success multi-writer PD", + req: &csi.CreateVolumeRequest{ + Name: name, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + }, + }, + Parameters: map[string]string{ + common.ParameterKeyType: "pd-balanced", + }, + }, + expMultiWriter: true, + }, + { + name: "success multi-writer Hyperdisk", + req: &csi.CreateVolumeRequest{ + Name: name, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + }, + }, + Parameters: map[string]string{ + common.ParameterKeyType: "hyperdisk-balanced", + }, + }, + expAccessMode: common.GCEReadWriteManyAccessMode, + }, + { + name: "success non-multi-writer Hyperdisk", + req: &csi.CreateVolumeRequest{ + Name: name, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + Parameters: map[string]string{ + common.ParameterKeyType: "hyperdisk-balanced", + }, + }, + expAccessMode: common.GCEReadWriteOnceAccessMode, + }, + { + name: "failure unsupported access mode for Hyperdisk", + req: &csi.CreateVolumeRequest{ + Name: name, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY, + }, + }, + }, + Parameters: map[string]string{ + common.ParameterKeyType: "hyperdisk-balanced", + }, + }, + expErrCode: codes.InvalidArgument, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + fcp, err := gce.CreateFakeCloudProvider(project, zone, nil) + if err != nil { + t.Fatalf("Failed to create fake 
cloud provider: %v", err) + } + // Setup new driver each time so no interference + gceDriver := initGCEDriverWithCloudProvider(t, fcp) + + // Start Test + resp, err := gceDriver.cs.CreateVolume(context.Background(), tc.req) + if err != nil { + serverError, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from err: %v", serverError) + } + if serverError.Code() != tc.expErrCode { + t.Fatalf("Expected error code: %v, got: %v. err : %v", tc.expErrCode, serverError.Code(), err) + } + return + } + if tc.expErrCode != codes.OK { + t.Fatalf("Expected error: %v, got no error", tc.expErrCode) + } + + expVol := &csi.Volume{ + CapacityBytes: MinimumVolumeSizeInBytes, + VolumeId: testVolumeID, + VolumeContext: nil, + AccessibleTopology: []*csi.Topology{ + { + Segments: map[string]string{ + "topology.gke.io/zone": zone, + }, + }, + }, + } + + // Make sure responses match + vol := resp.GetVolume() + if diff := cmp.Diff(expVol, vol, protocmp.Transform()); diff != "" { + t.Errorf("Accessible topologies mismatch (-want +got):\n%s", diff) + } + + // Now check the fake "disk" for multi-writer or access mode, depending on disk type. + _, volKey, err := common.VolumeIDToKey(vol.GetVolumeId()) + if err != nil { + t.Fatalf("unexptected error while parsing volume id %v", vol.GetVolumeId()) + } + disk, err := fcp.GetDisk(context.Background(), project, volKey) + if err != nil { + t.Fatalf("unexpected error while getting disk from fake cloud provider: %v", err) + } + if disk.GetAccessMode() != tc.expAccessMode { + t.Errorf("want access mode %q, got access mode %q", tc.expAccessMode, disk.GetAccessMode()) + } + if disk.GetMultiWriter() != tc.expMultiWriter { + t.Errorf("want multi writer = %v, got multi writer = %v.", tc.expMultiWriter, disk.GetMultiWriter()) + } + }) + } +} type FakeCloudProviderInsertDiskErr struct { *gce.FakeCloudProvider @@ -2254,8 +2490,9 @@ func TestVolumeModifyErrorHandling(t *testing.T) { resp, err := gceDriver.cs.CreateVolume(context.Background(), tc.createReq) if err != nil { t.Errorf("Expected no error, got %v", err) + } else { + volId = resp.GetVolume().VolumeId } - volId = resp.GetVolume().VolumeId } tc.modifyReq.VolumeId = volId @@ -2847,7 +3084,7 @@ func TestCloningLocationRequirements(t *testing.T) { sourceVolumeID: testZonalVolumeSourceID, requestCapacityRange: stdCapRange, reqParameters: map[string]string{ - common.ParameterKeyType: common.ParameterHdHADiskType, + common.ParameterKeyType: common.DiskTypeHdHA, }, cloneIsRegional: true, expectedLocationRequirements: &locationRequirements{srcVolRegion: region, srcVolZone: zone, srcIsRegional: false, cloneIsRegional: true}, @@ -4849,8 +5086,8 @@ func TestCreateVolumeDiskReady(t *testing.T) { } vol := resp.GetVolume() - if !reflect.DeepEqual(vol, tc.expVol) { - t.Fatalf("Mismatch in expected vol %v, current volume: %v\n", tc.expVol, vol) + if diff := cmp.Diff(vol, tc.expVol, protocmp.Transform()); diff != "" { + t.Errorf("unexpected diff (-vol, +expVol): \n%s", diff) } }) } @@ -5345,7 +5582,6 @@ func TestCreateConfidentialVolume(t *testing.T) { }, } for _, tc := range testCases { - t.Logf("test case: %s", tc.name) t.Run(tc.name, func(t *testing.T) { fcp, err := gce.CreateFakeCloudProvider(project, zone, nil) if err != nil { diff --git a/pkg/gce-pd-csi-driver/utils.go b/pkg/gce-pd-csi-driver/utils.go index 77ac640cb..9a94995bd 100644 --- a/pkg/gce-pd-csi-driver/utils.go +++ b/pkg/gce-pd-csi-driver/utils.go @@ -21,7 +21,7 @@ import ( "errors" "fmt" - csi 
"github.com/container-storage-interface/spec/lib/go/csi" + "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" @@ -33,7 +33,20 @@ const ( fsTypeXFS = "xfs" ) -var ProbeCSIFullMethod = "/csi.v1.Identity/Probe" +var ( + ProbeCSIFullMethod = "/csi.v1.Identity/Probe" + + csiAccessModeToHyperdiskMode = map[csi.VolumeCapability_AccessMode_Mode]string{ + csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER: common.GCEReadWriteOnceAccessMode, + csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY: common.GCEReadOnlyManyAccessMode, + csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER: common.GCEReadWriteManyAccessMode, + } + + supportedMultiAttachAccessModes = map[csi.VolumeCapability_AccessMode_Mode]bool{ + csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY: true, + csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER: true, + } +) func NewVolumeCapabilityAccessMode(mode csi.VolumeCapability_AccessMode_Mode) *csi.VolumeCapability_AccessMode { return &csi.VolumeCapability_AccessMode{Mode: mode} @@ -137,7 +150,7 @@ func validateAccessMode(am *csi.VolumeCapability_AccessMode) error { case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY: case csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER: default: - return fmt.Errorf("%v access mode is not supported for for PD", am.GetMode()) + return fmt.Errorf("%v access mode is not supported for GCE disks", am.GetMode()) } return nil } @@ -253,6 +266,42 @@ func getReadOnlyFromCapabilities(vcs []*csi.VolumeCapability) (bool, error) { return false, nil } +func getMultiAttachementFromCapabilities(vcs []*csi.VolumeCapability) (bool, error) { + if vcs == nil { + return false, errors.New("volume capabilities is nil") + } + for _, vc := range vcs { + if vc.GetAccessMode() == nil { + return false, errors.New("access mode is nil") + } + mode := vc.GetAccessMode().GetMode() + if isMultiAttach, ok := supportedMultiAttachAccessModes[mode]; !ok { + return false, nil + } else { + return isMultiAttach, nil + } + } + return false, nil +} + +func getHyperdiskAccessModeFromCapabilities(vcs []*csi.VolumeCapability) (string, error) { + if vcs == nil { + return "", errors.New("volume capabilities is nil") + } + for _, vc := range vcs { + if vc.GetAccessMode() == nil { + return "", errors.New("access mode is nil") + } + mode := vc.GetAccessMode().GetMode() + if am, ok := csiAccessModeToHyperdiskMode[mode]; !ok { + return "", errors.New("found unsupported access mode for hyperdisk") + } else { + return am, nil + } + } + return "", errors.New("volume capabilities is nil") +} + func collectMountOptions(fsType string, mntFlags []string) []string { var options []string diff --git a/pkg/gce-pd-csi-driver/utils_test.go b/pkg/gce-pd-csi-driver/utils_test.go index 9d65c355b..f187cc9d8 100644 --- a/pkg/gce-pd-csi-driver/utils_test.go +++ b/pkg/gce-pd-csi-driver/utils_test.go @@ -775,3 +775,85 @@ func TestValidateStoragePoolZones(t *testing.T) { } } } + +func TestGetHyperdiskAccessModeFromCapabilities(t *testing.T) { + for _, tc := range []struct { + name string + vcs []*csi.VolumeCapability + want string + wantErr bool + }{ + { + name: "error with nil vcs", + wantErr: true, + }, + { + name: "error with no vcs", + vcs: []*csi.VolumeCapability{}, + wantErr: true, + }, + { + name: "error with nil access mode", + vcs: []*csi.VolumeCapability{ + {}, + }, + wantErr: true, + }, + { + name: "error with unsupported CSI access mode", + vcs: []*csi.VolumeCapability{ + { + 
AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER, + }, + }, + }, + wantErr: true, + }, + { + name: "success getting ROX", + vcs: []*csi.VolumeCapability{ + { + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY, + }, + }, + }, + want: common.GCEReadOnlyManyAccessMode, + }, + { + name: "success getting RWO", + vcs: []*csi.VolumeCapability{ + { + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + want: common.GCEReadWriteOnceAccessMode, + }, + { + name: "success getting RWX", + vcs: []*csi.VolumeCapability{ + { + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + }, + }, + want: common.GCEReadWriteManyAccessMode, + }, + } { + t.Logf("Running test: %v", tc.name) + am, err := getHyperdiskAccessModeFromCapabilities(tc.vcs) + if err != nil { + if !tc.wantErr { + t.Errorf("unexpected error: %v", err) + } + continue + } + if am != tc.want { + t.Errorf("want %s, got %s", tc.want, am) + } + } +} diff --git a/test/e2e/tests/multi_zone_e2e_test.go b/test/e2e/tests/multi_zone_e2e_test.go index 09ad3ef6a..d7887841e 100644 --- a/test/e2e/tests/multi_zone_e2e_test.go +++ b/test/e2e/tests/multi_zone_e2e_test.go @@ -1065,7 +1065,7 @@ var _ = Describe("GCE PD CSI Driver Multi-Zone", func() { // Create Disk volName := testNamePrefix + string(uuid.NewUUID()) volume, err := controllerClient.CreateVolume(volName, map[string]string{ - common.ParameterKeyType: common.ParameterHdHADiskType, + common.ParameterKeyType: common.DiskTypeHdHA, }, defaultRepdSizeGb, &csi.TopologyRequirement{ Requisite: []*csi.Topology{ { @@ -1144,7 +1144,7 @@ var _ = Describe("GCE PD CSI Driver Multi-Zone", func() { // Create Disk volName := testNamePrefix + string(uuid.NewUUID()) volume, err := controllerClient.CreateVolume(volName, map[string]string{ - common.ParameterKeyType: common.ParameterHdHADiskType, + common.ParameterKeyType: common.DiskTypeHdHA, common.ParameterAvailabilityClass: common.ParameterRegionalHardFailoverClass, }, defaultRepdSizeGb, &csi.TopologyRequirement{ Requisite: []*csi.Topology{ diff --git a/vendor/google.golang.org/protobuf/internal/msgfmt/format.go b/vendor/google.golang.org/protobuf/internal/msgfmt/format.go new file mode 100644 index 000000000..a319550f6 --- /dev/null +++ b/vendor/google.golang.org/protobuf/internal/msgfmt/format.go @@ -0,0 +1,261 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package msgfmt implements a text marshaler combining the desirable features +// of both the JSON and proto text formats. +// It is optimized for human readability and has no associated deserializer. +package msgfmt + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "time" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/detrand" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/internal/order" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +// Format returns a formatted string for the message. 
+func Format(m proto.Message) string { + return string(appendMessage(nil, m.ProtoReflect())) +} + +// FormatValue returns a formatted string for an arbitrary value. +func FormatValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) string { + return string(appendValue(nil, v, fd)) +} + +func appendValue(b []byte, v protoreflect.Value, fd protoreflect.FieldDescriptor) []byte { + switch v := v.Interface().(type) { + case nil: + return append(b, ""...) + case bool, int32, int64, uint32, uint64, float32, float64: + return append(b, fmt.Sprint(v)...) + case string: + return append(b, strconv.Quote(string(v))...) + case []byte: + return append(b, strconv.Quote(string(v))...) + case protoreflect.EnumNumber: + return appendEnum(b, v, fd) + case protoreflect.Message: + return appendMessage(b, v) + case protoreflect.List: + return appendList(b, v, fd) + case protoreflect.Map: + return appendMap(b, v, fd) + default: + panic(fmt.Sprintf("invalid type: %T", v)) + } +} + +func appendEnum(b []byte, v protoreflect.EnumNumber, fd protoreflect.FieldDescriptor) []byte { + if fd != nil { + if ev := fd.Enum().Values().ByNumber(v); ev != nil { + return append(b, ev.Name()...) + } + } + return strconv.AppendInt(b, int64(v), 10) +} + +func appendMessage(b []byte, m protoreflect.Message) []byte { + if b2 := appendKnownMessage(b, m); b2 != nil { + return b2 + } + + b = append(b, '{') + order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + b = append(b, fd.TextName()...) + b = append(b, ':') + b = appendValue(b, v, fd) + b = append(b, delim()...) + return true + }) + b = appendUnknown(b, m.GetUnknown()) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + return b +} + +var protocmpMessageType = reflect.TypeOf(map[string]interface{}(nil)) + +func appendKnownMessage(b []byte, m protoreflect.Message) []byte { + md := m.Descriptor() + fds := md.Fields() + switch md.FullName() { + case genid.Any_message_fullname: + var msgVal protoreflect.Message + url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String() + if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) { + // For protocmp.Message, directly obtain the sub-message value + // which is stored in structured form, rather than as raw bytes. + m2 := v.Convert(protocmpMessageType).Interface().(map[string]interface{}) + v, ok := m2[string(genid.Any_Value_field_name)].(proto.Message) + if !ok { + return nil + } + msgVal = v.ProtoReflect() + } else { + val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes() + mt, err := protoregistry.GlobalTypes.FindMessageByURL(url) + if err != nil { + return nil + } + msgVal = mt.New() + err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(val, msgVal.Interface()) + if err != nil { + return nil + } + } + + b = append(b, '{') + b = append(b, "["+url+"]"...) + b = append(b, ':') + b = appendMessage(b, msgVal) + b = append(b, '}') + return b + + case genid.Timestamp_message_fullname: + secs := m.Get(fds.ByNumber(genid.Timestamp_Seconds_field_number)).Int() + nanos := m.Get(fds.ByNumber(genid.Timestamp_Nanos_field_number)).Int() + if nanos < 0 || nanos >= 1e9 { + return nil + } + t := time.Unix(secs, nanos).UTC() + x := t.Format("2006-01-02T15:04:05.000000000") // RFC 3339 + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + return append(b, x+"Z"...) 
+ + case genid.Duration_message_fullname: + sign := "" + secs := m.Get(fds.ByNumber(genid.Duration_Seconds_field_number)).Int() + nanos := m.Get(fds.ByNumber(genid.Duration_Nanos_field_number)).Int() + if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) { + return nil + } + if secs < 0 || nanos < 0 { + sign, secs, nanos = "-", -1*secs, -1*nanos + } + x := fmt.Sprintf("%s%d.%09d", sign, secs, nanos) + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + return append(b, x+"s"...) + + case genid.BoolValue_message_fullname, + genid.Int32Value_message_fullname, + genid.Int64Value_message_fullname, + genid.UInt32Value_message_fullname, + genid.UInt64Value_message_fullname, + genid.FloatValue_message_fullname, + genid.DoubleValue_message_fullname, + genid.StringValue_message_fullname, + genid.BytesValue_message_fullname: + fd := fds.ByNumber(genid.WrapperValue_Value_field_number) + return appendValue(b, m.Get(fd), fd) + } + + return nil +} + +func appendUnknown(b []byte, raw protoreflect.RawFields) []byte { + rs := make(map[protoreflect.FieldNumber][]protoreflect.RawFields) + for len(raw) > 0 { + num, _, n := protowire.ConsumeField(raw) + rs[num] = append(rs[num], raw[:n]) + raw = raw[n:] + } + + var ns []protoreflect.FieldNumber + for n := range rs { + ns = append(ns, n) + } + sort.Slice(ns, func(i, j int) bool { return ns[i] < ns[j] }) + + for _, n := range ns { + var leftBracket, rightBracket string + if len(rs[n]) > 1 { + leftBracket, rightBracket = "[", "]" + } + + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ':') + b = append(b, leftBracket...) + for _, r := range rs[n] { + num, typ, n := protowire.ConsumeTag(r) + r = r[n:] + switch typ { + case protowire.VarintType: + v, _ := protowire.ConsumeVarint(r) + b = strconv.AppendInt(b, int64(v), 10) + case protowire.Fixed32Type: + v, _ := protowire.ConsumeFixed32(r) + b = append(b, fmt.Sprintf("0x%08x", v)...) + case protowire.Fixed64Type: + v, _ := protowire.ConsumeFixed64(r) + b = append(b, fmt.Sprintf("0x%016x", v)...) + case protowire.BytesType: + v, _ := protowire.ConsumeBytes(r) + b = strconv.AppendQuote(b, string(v)) + case protowire.StartGroupType: + v, _ := protowire.ConsumeGroup(num, r) + b = append(b, '{') + b = appendUnknown(b, v) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + default: + panic(fmt.Sprintf("invalid type: %v", typ)) + } + b = append(b, delim()...) + } + b = bytes.TrimRight(b, delim()) + b = append(b, rightBracket...) + b = append(b, delim()...) + } + return b +} + +func appendList(b []byte, v protoreflect.List, fd protoreflect.FieldDescriptor) []byte { + b = append(b, '[') + for i := 0; i < v.Len(); i++ { + b = appendValue(b, v.Get(i), fd) + b = append(b, delim()...) + } + b = bytes.TrimRight(b, delim()) + b = append(b, ']') + return b +} + +func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte { + b = append(b, '{') + order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool { + b = appendValue(b, k.Value(), fd.MapKey()) + b = append(b, ':') + b = appendValue(b, v, fd.MapValue()) + b = append(b, delim()...) + return true + }) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + return b +} + +func delim() string { + // Deliberately introduce instability into the message string to + // discourage users from depending on it. 
+ if detrand.Bool() { + return " " + } + return ", " +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go b/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go new file mode 100644 index 000000000..0a5e47467 --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go @@ -0,0 +1,258 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protocmp + +import ( + "reflect" + "sort" + "strconv" + "strings" + + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoiface" +) + +func reflectValueOf(v interface{}) protoreflect.Value { + switch v := v.(type) { + case Enum: + return protoreflect.ValueOfEnum(v.Number()) + case Message: + return protoreflect.ValueOfMessage(v.ProtoReflect()) + case []byte: + return protoreflect.ValueOfBytes(v) // avoid overlap with reflect.Slice check below + default: + switch rv := reflect.ValueOf(v); { + case rv.Kind() == reflect.Slice: + return protoreflect.ValueOfList(reflectList{rv}) + case rv.Kind() == reflect.Map: + return protoreflect.ValueOfMap(reflectMap{rv}) + default: + return protoreflect.ValueOf(v) + } + } +} + +type reflectMessage Message + +func (m reflectMessage) stringKey(fd protoreflect.FieldDescriptor) string { + if m.Descriptor() != fd.ContainingMessage() { + panic("mismatching containing message") + } + return fd.TextName() +} + +func (m reflectMessage) Descriptor() protoreflect.MessageDescriptor { + return (Message)(m).Descriptor() +} +func (m reflectMessage) Type() protoreflect.MessageType { + return reflectMessageType{m.Descriptor()} +} +func (m reflectMessage) New() protoreflect.Message { + return m.Type().New() +} +func (m reflectMessage) Interface() protoreflect.ProtoMessage { + return Message(m) +} +func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool) { + // Range over populated known fields. + fds := m.Descriptor().Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + if m.Has(fd) && !f(fd, m.Get(fd)) { + return + } + } + + // Range over populated extension fields. + for _, xd := range m[messageTypeKey].(messageMeta).xds { + if m.Has(xd) && !f(xd, m.Get(xd)) { + return + } + } +} +func (m reflectMessage) Has(fd protoreflect.FieldDescriptor) bool { + _, ok := m[m.stringKey(fd)] + return ok +} +func (m reflectMessage) Clear(protoreflect.FieldDescriptor) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { + v, ok := m[m.stringKey(fd)] + if !ok { + switch { + case fd.IsList(): + return protoreflect.ValueOfList(reflectList{}) + case fd.IsMap(): + return protoreflect.ValueOfMap(reflectMap{}) + case fd.Message() != nil: + return protoreflect.ValueOfMessage(reflectMessage{ + messageTypeKey: messageMeta{md: fd.Message()}, + }) + default: + return fd.Default() + } + } + + // The transformation may leave Any messages in structured form. + // If so, convert them back to a raw-encoded form. 
+ if fd.FullName() == genid.Any_Value_field_fullname { + if m, ok := v.(Message); ok { + b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) + if err != nil { + panic("BUG: " + err.Error()) + } + return protoreflect.ValueOfBytes(b) + } + } + + return reflectValueOf(v) +} +func (m reflectMessage) Set(protoreflect.FieldDescriptor, protoreflect.Value) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) NewField(protoreflect.FieldDescriptor) protoreflect.Value { + panic("not implemented") +} +func (m reflectMessage) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { + if m.Descriptor().Oneofs().ByName(od.Name()) != od { + panic("oneof descriptor does not belong to this message") + } + fds := od.Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + if _, ok := m[m.stringKey(fd)]; ok { + return fd + } + } + return nil +} +func (m reflectMessage) GetUnknown() protoreflect.RawFields { + var nums []protoreflect.FieldNumber + for k := range m { + if len(strings.Trim(k, "0123456789")) == 0 { + n, _ := strconv.ParseUint(k, 10, 32) + nums = append(nums, protoreflect.FieldNumber(n)) + } + } + sort.Slice(nums, func(i, j int) bool { return nums[i] < nums[j] }) + + var raw protoreflect.RawFields + for _, num := range nums { + b, _ := m[strconv.FormatUint(uint64(num), 10)].(protoreflect.RawFields) + raw = append(raw, b...) + } + return raw +} +func (m reflectMessage) SetUnknown(protoreflect.RawFields) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) IsValid() bool { + invalid, _ := m[messageInvalidKey].(bool) + return !invalid +} +func (m reflectMessage) ProtoMethods() *protoiface.Methods { + return nil +} + +type reflectMessageType struct{ protoreflect.MessageDescriptor } + +func (t reflectMessageType) New() protoreflect.Message { + panic("not implemented") +} +func (t reflectMessageType) Zero() protoreflect.Message { + panic("not implemented") +} +func (t reflectMessageType) Descriptor() protoreflect.MessageDescriptor { + return t.MessageDescriptor +} + +type reflectList struct{ v reflect.Value } + +func (ls reflectList) Len() int { + if !ls.IsValid() { + return 0 + } + return ls.v.Len() +} +func (ls reflectList) Get(i int) protoreflect.Value { + return reflectValueOf(ls.v.Index(i).Interface()) +} +func (ls reflectList) Set(int, protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) Append(protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) AppendMutable() protoreflect.Value { + panic("invalid mutation of read-only list") +} +func (ls reflectList) Truncate(int) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) NewElement() protoreflect.Value { + panic("not implemented") +} +func (ls reflectList) IsValid() bool { + return ls.v.IsValid() +} + +type reflectMap struct{ v reflect.Value } + +func (ms reflectMap) Len() int { + if !ms.IsValid() { + return 0 + } + return ms.v.Len() +} +func (ms reflectMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { + if !ms.IsValid() { + return + } + ks := ms.v.MapKeys() + for _, k := range ks { + pk := reflectValueOf(k.Interface()).MapKey() + pv := reflectValueOf(ms.v.MapIndex(k).Interface()) + if !f(pk, pv) { + return + } + } +} +func (ms reflectMap) Has(k protoreflect.MapKey) bool { + if !ms.IsValid() { + return false 
+ } + return ms.v.MapIndex(reflect.ValueOf(k.Interface())).IsValid() +} +func (ms reflectMap) Clear(protoreflect.MapKey) { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) Get(k protoreflect.MapKey) protoreflect.Value { + if !ms.IsValid() { + return protoreflect.Value{} + } + v := ms.v.MapIndex(reflect.ValueOf(k.Interface())) + if !v.IsValid() { + return protoreflect.Value{} + } + return reflectValueOf(v.Interface()) +} +func (ms reflectMap) Set(protoreflect.MapKey, protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) Mutable(k protoreflect.MapKey) protoreflect.Value { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) NewValue() protoreflect.Value { + panic("not implemented") +} +func (ms reflectMap) IsValid() bool { + return ms.v.IsValid() +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/util.go b/vendor/google.golang.org/protobuf/testing/protocmp/util.go new file mode 100644 index 000000000..dec34f20c --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/util.go @@ -0,0 +1,684 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protocmp + +import ( + "bytes" + "fmt" + "math" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +var ( + enumReflectType = reflect.TypeOf(Enum{}) + messageReflectType = reflect.TypeOf(Message{}) +) + +// FilterEnum filters opt to only be applicable on a standalone [Enum], +// singular fields of enums, list fields of enums, or map fields of enum values, +// where the enum is the same type as the specified enum. +// +// The Go type of the last path step may be an: +// - [Enum] for singular fields, elements of a repeated field, +// values of a map field, or standalone [Enum] values +// - [][Enum] for list fields +// - map[K][Enum] for map fields +// - interface{} for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option { + return FilterDescriptor(enum.Descriptor(), opt) +} + +// FilterMessage filters opt to only be applicable on a standalone [Message] values, +// singular fields of messages, list fields of messages, or map fields of +// message values, where the message is the same type as the specified message. +// +// The Go type of the last path step may be an: +// - [Message] for singular fields, elements of a repeated field, +// values of a map field, or standalone [Message] values +// - [][Message] for list fields +// - map[K][Message] for map fields +// - interface{} for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option { + return FilterDescriptor(message.ProtoReflect().Descriptor(), opt) +} + +// FilterField filters opt to only be applicable on the specified field +// in the message. It panics if a field of the given name does not exist. +// +// The Go type of the last path step may be an: +// - T for singular fields +// - []T for list fields +// - map[K]T for map fields +// - interface{} for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. 
+func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindFieldDescriptor(md, name), opt) +} + +// FilterOneof filters opt to only be applicable on all fields within the +// specified oneof in the message. It panics if a oneof of the given name +// does not exist. +// +// The Go type of the last path step may be an: +// - T for singular fields +// - []T for list fields +// - map[K]T for map fields +// - interface{} for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindOneofDescriptor(md, name), opt) +} + +// FilterDescriptor ignores the specified descriptor. +// +// The following descriptor types may be specified: +// - [protoreflect.EnumDescriptor] +// - [protoreflect.MessageDescriptor] +// - [protoreflect.FieldDescriptor] +// - [protoreflect.OneofDescriptor] +// +// For the behavior of each, see the corresponding filter function. +// Since this filter accepts a [protoreflect.FieldDescriptor], it can be used +// to also filter for extension fields as a [protoreflect.ExtensionDescriptor] +// is just an alias to [protoreflect.FieldDescriptor]. +// +// This must be used in conjunction with [Transform]. +func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option { + f := newNameFilters(desc) + return cmp.FilterPath(f.Filter, opt) +} + +// IgnoreEnums ignores all enums of the specified types. +// It is equivalent to FilterEnum(enum, cmp.Ignore()) for each enum. +// +// This must be used in conjunction with [Transform]. +func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option { + var ds []protoreflect.Descriptor + for _, e := range enums { + ds = append(ds, e.Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreMessages ignores all messages of the specified types. +// It is equivalent to [FilterMessage](message, [cmp.Ignore]()) for each message. +// +// This must be used in conjunction with [Transform]. +func IgnoreMessages(messages ...proto.Message) cmp.Option { + var ds []protoreflect.Descriptor + for _, m := range messages { + ds = append(ds, m.ProtoReflect().Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreFields ignores the specified fields in the specified message. +// It is equivalent to [FilterField](message, name, [cmp.Ignore]()) for each field +// in the message. +// +// This must be used in conjunction with [Transform]. +func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindFieldDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreOneofs ignores fields of the specified oneofs in the specified message. +// It is equivalent to FilterOneof(message, name, cmp.Ignore()) for each oneof +// in the message. +// +// This must be used in conjunction with [Transform]. +func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindOneofDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreDescriptors ignores the specified set of descriptors. 
+// It is equivalent to [FilterDescriptor](desc, [cmp.Ignore]()) for each descriptor. +// +// This must be used in conjunction with [Transform]. +func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option { + return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore()) +} + +func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor { + d := findDescriptor(md, s) + if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) { + return fd + } + + var suggestion string + switch d := d.(type) { + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName()) + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name()) + } + panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion)) +} + +func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor { + d := findDescriptor(md, s) + if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s { + return od + } + + var suggestion string + switch d := d.(type) { + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name()) + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName()) + } + panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion)) +} + +func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor { + // Exact match. + if fd := md.Fields().ByTextName(string(s)); fd != nil { + return fd + } + if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() { + return od + } + + // Best-effort match. + // + // It's a common user mistake to use the CamelCased field name as it appears + // in the generated Go struct. Instead of complaining that it doesn't exist, + // suggest the real protobuf name that the user may have desired. + normalize := func(s protoreflect.Name) string { + return strings.Replace(strings.ToLower(string(s)), "_", "", -1) + } + for i := 0; i < md.Fields().Len(); i++ { + if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) { + return fd + } + } + for i := 0; i < md.Oneofs().Len(); i++ { + if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) { + return od + } + } + return nil +} + +type nameFilters struct { + names map[protoreflect.FullName]bool +} + +func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters { + f := &nameFilters{names: make(map[protoreflect.FullName]bool)} + for _, d := range descs { + switch d := d.(type) { + case protoreflect.EnumDescriptor: + f.names[d.FullName()] = true + case protoreflect.MessageDescriptor: + f.names[d.FullName()] = true + case protoreflect.FieldDescriptor: + f.names[d.FullName()] = true + case protoreflect.OneofDescriptor: + for i := 0; i < d.Fields().Len(); i++ { + f.names[d.Fields().Get(i).FullName()] = true + } + default: + panic("invalid descriptor type") + } + } + return f +} + +func (f *nameFilters) Filter(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p) +} + +func (f *nameFilters) filterFields(p cmp.Path) bool { + // Trim off trailing type-assertions so that the filter can match on the + // concrete value held within an interface value. 
+ if _, ok := p.Last().(cmp.TypeAssertion); ok { + p = p[:len(p)-1] + } + + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field name. + vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + if f.filterFieldName(mx, k) && f.filterFieldName(my, k) { + return true + } + + // Check field value. + vx, vy = mi.Values() + if f.filterFieldValue(vx) && f.filterFieldValue(vy) { + return true + } + + return false +} + +func (f *nameFilters) filterFieldName(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true // treat missing fields as already filtered + } + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd != nil { + return f.names[fd.FullName()] + } + return false +} + +func (f *nameFilters) filterFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == enumReflectType || t == messageReflectType: + // Check for singular message or enum field. + return f.filterValue(v) + case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for list field of enum or message type. + return f.filterValue(v.Index(0)) + case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for map field of enum or message type. + return f.filterValue(v.MapIndex(v.MapKeys()[0])) + } + return false +} + +func (f *nameFilters) filterValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + switch v := v.Interface().(type) { + case Enum: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + case Message: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + } + return false +} + +// IgnoreDefaultScalars ignores singular scalars that are unpopulated or +// explicitly set to the default value. +// This option does not effect elements in a list or entries in a map. +// +// This must be used in conjunction with [Transform]. +func IgnoreDefaultScalars() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check whether both fields are default or unpopulated scalars. 
+ vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + return isDefaultScalar(mx, k) && isDefaultScalar(my, k) + }, cmp.Ignore()) +} + +func isDefaultScalar(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true + } + + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd == nil || !fd.Default().IsValid() { + return false + } + switch fd.Kind() { + case protoreflect.BytesKind: + v, ok := m[k].([]byte) + return ok && bytes.Equal(fd.Default().Bytes(), v) + case protoreflect.FloatKind: + v, ok := m[k].(float32) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.DoubleKind: + v, ok := m[k].(float64) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.EnumKind: + v, ok := m[k].(Enum) + return ok && fd.Default().Enum() == v.Number() + default: + return reflect.DeepEqual(fd.Default().Interface(), m[k]) + } +} + +func equalFloat64(x, y float64) bool { + return x == y || (math.IsNaN(x) && math.IsNaN(y)) +} + +// IgnoreEmptyMessages ignores messages that are empty or unpopulated. +// It applies to standalone [Message] values, singular message fields, +// list fields of messages, and map fields of message values. +// +// This must be used in conjunction with [Transform]. +func IgnoreEmptyMessages() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p) + }, cmp.Ignore()) +} + +func isEmptyMessageFields(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field value. + vx, vy := mi.Values() + if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) { + return true + } + + return false +} + +func isEmptyMessageFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == messageReflectType: + // Check singular field for empty message. + if !isEmptyMessage(v) { + return false + } + case t.Kind() == reflect.Slice && t.Elem() == messageReflectType: + // Check list field for all empty message elements. + for i := 0; i < v.Len(); i++ { + if !isEmptyMessage(v.Index(i)) { + return false + } + } + case t.Kind() == reflect.Map && t.Elem() == messageReflectType: + // Check map field for all empty message values. + for _, k := range v.MapKeys() { + if !isEmptyMessage(v.MapIndex(k)) { + return false + } + } + default: + return false + } + return true +} + +func isEmptyMessage(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + if m, ok := v.Interface().(Message); ok { + for k := range m { + if k != messageTypeKey && k != messageInvalidKey { + return false + } + } + return true + } + return false +} + +// IgnoreUnknown ignores unknown fields in all messages. +// +// This must be used in conjunction with [Transform]. +func IgnoreUnknown() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. 
+ mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Filter for unknown fields (which always have a numeric map key). + return strings.Trim(mi.Key().String(), "0123456789") == "" + }, cmp.Ignore()) +} + +// SortRepeated sorts repeated fields of the specified element type. +// The less function must be of the form "func(T, T) bool" where T is the +// Go element type for the repeated field kind. +// +// The element type T can be one of the following: +// - Go type for a protobuf scalar kind except for an enum +// (i.e., bool, int32, int64, uint32, uint64, float32, float64, string, and []byte) +// - E where E is a concrete enum type that implements [protoreflect.Enum] +// - M where M is a concrete message type that implement [proto.Message] +// +// This option only applies to repeated fields within a protobuf message. +// It does not operate on higher-order Go types that seem like a repeated field. +// For example, a []T outside the context of a protobuf message will not be +// handled by this option. To sort Go slices that are not repeated fields, +// consider using [github.com/google/go-cmp/cmp/cmpopts.SortSlices] instead. +// +// This must be used in conjunction with [Transform]. +func SortRepeated(lessFunc interface{}) cmp.Option { + t, ok := checkTTBFunc(lessFunc) + if !ok { + panic(fmt.Sprintf("invalid less function: %T", lessFunc)) + } + + var opt cmp.Option + var sliceType reflect.Type + switch vf := reflect.ValueOf(lessFunc); { + case t.Implements(enumV2Type): + et := reflect.Zero(t).Interface().(protoreflect.Enum).Type() + lessFunc = func(x, y Enum) bool { + vx := reflect.ValueOf(et.New(x.Number())) + vy := reflect.ValueOf(et.New(y.Number())) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(enumReflectType) + case t.Implements(messageV2Type): + mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type() + lessFunc = func(x, y Message) bool { + mx := mt.New().Interface() + my := mt.New().Interface() + proto.Merge(mx, x) + proto.Merge(my, y) + vx := reflect.ValueOf(mx) + vy := reflect.ValueOf(my) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(messageReflectType) + default: + switch t { + case reflect.TypeOf(bool(false)): + case reflect.TypeOf(int32(0)): + case reflect.TypeOf(int64(0)): + case reflect.TypeOf(uint32(0)): + case reflect.TypeOf(uint64(0)): + case reflect.TypeOf(float32(0)): + case reflect.TypeOf(float64(0)): + case reflect.TypeOf(string("")): + case reflect.TypeOf([]byte(nil)): + default: + panic(fmt.Sprintf("invalid element type: %v", t)) + } + opt = cmpopts.SortSlices(lessFunc) + sliceType = reflect.SliceOf(t) + } + + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter to only apply to repeated fields within a message. 
+ if t := p.Index(-1).Type(); t == nil || t != sliceType { + return false + } + if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface { + return false + } + if t := p.Index(-3).Type(); t == nil || t != messageReflectType { + return false + } + return true + }, opt) +} + +func checkTTBFunc(lessFunc interface{}) (reflect.Type, bool) { + switch t := reflect.TypeOf(lessFunc); { + case t == nil: + return nil, false + case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic(): + return nil, false + case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false): + return nil, false + default: + return t.In(0), true + } +} + +// SortRepeatedFields sorts the specified repeated fields. +// Sorting a repeated field is useful for treating the list as a multiset +// (i.e., a set where each value can appear multiple times). +// It panics if the field does not exist or is not a repeated field. +// +// The sort ordering is as follows: +// - Booleans are sorted where false is sorted before true. +// - Integers are sorted in ascending order. +// - Floating-point numbers are sorted in ascending order according to +// the total ordering defined by IEEE-754 (section 5.10). +// - Strings and bytes are sorted lexicographically in ascending order. +// - [Enum] values are sorted in ascending order based on its numeric value. +// - [Message] values are sorted according to some arbitrary ordering +// which is undefined and may change in future implementations. +// +// The ordering chosen for repeated messages is unlikely to be aesthetically +// preferred by humans. Consider using a custom sort function: +// +// FilterField(m, "foo_field", SortRepeated(func(x, y *foopb.MyMessage) bool { +// ... // user-provided definition for less +// })) +// +// This must be used in conjunction with [Transform]. +func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var opts cmp.Options + md := message.ProtoReflect().Descriptor() + for _, name := range names { + fd := mustFindFieldDescriptor(md, name) + if !fd.IsList() { + panic(fmt.Sprintf("message field %q is not repeated", fd.FullName())) + } + + var lessFunc interface{} + switch fd.Kind() { + case protoreflect.BoolKind: + lessFunc = func(x, y bool) bool { return !x && y } + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + lessFunc = func(x, y int32) bool { return x < y } + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + lessFunc = func(x, y int64) bool { return x < y } + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + lessFunc = func(x, y uint32) bool { return x < y } + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + lessFunc = func(x, y uint64) bool { return x < y } + case protoreflect.FloatKind: + lessFunc = lessF32 + case protoreflect.DoubleKind: + lessFunc = lessF64 + case protoreflect.StringKind: + lessFunc = func(x, y string) bool { return x < y } + case protoreflect.BytesKind: + lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 } + case protoreflect.EnumKind: + lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() } + case protoreflect.MessageKind, protoreflect.GroupKind: + lessFunc = func(x, y Message) bool { return x.String() < y.String() } + default: + panic(fmt.Sprintf("invalid kind: %v", fd.Kind())) + } + opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc))) + } + return opts +} + +func lessF32(x, y float32) bool { + // Bit-wise implementation of IEEE-754, section 5.10. 
+ xi := int32(math.Float32bits(x)) + yi := int32(math.Float32bits(y)) + xi ^= int32(uint32(xi>>31) >> 1) + yi ^= int32(uint32(yi>>31) >> 1) + return xi < yi +} +func lessF64(x, y float64) bool { + // Bit-wise implementation of IEEE-754, section 5.10. + xi := int64(math.Float64bits(x)) + yi := int64(math.Float64bits(y)) + xi ^= int64(uint64(xi>>63) >> 1) + yi ^= int64(uint64(yi>>63) >> 1) + return xi < yi +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/xform.go b/vendor/google.golang.org/protobuf/testing/protocmp/xform.go new file mode 100644 index 000000000..0a1aef9b4 --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/xform.go @@ -0,0 +1,377 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package protocmp provides protobuf specific options for the +// [github.com/google/go-cmp/cmp] package. +// +// The primary feature is the [Transform] option, which transform [proto.Message] +// types into a [Message] map that is suitable for cmp to introspect upon. +// All other options in this package must be used in conjunction with [Transform]. +package protocmp + +import ( + "reflect" + "strconv" + + "github.com/google/go-cmp/cmp" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/internal/msgfmt" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/runtime/protoimpl" +) + +var ( + enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() + messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem() + messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem() +) + +// Enum is a dynamic representation of a protocol buffer enum that is +// suitable for [cmp.Equal] and [cmp.Diff] to compare upon. +type Enum struct { + num protoreflect.EnumNumber + ed protoreflect.EnumDescriptor +} + +// Descriptor returns the enum descriptor. +// It returns nil for a zero Enum value. +func (e Enum) Descriptor() protoreflect.EnumDescriptor { + return e.ed +} + +// Number returns the enum value as an integer. +func (e Enum) Number() protoreflect.EnumNumber { + return e.num +} + +// Equal reports whether e1 and e2 represent the same enum value. +func (e1 Enum) Equal(e2 Enum) bool { + if e1.ed.FullName() != e2.ed.FullName() { + return false + } + return e1.num == e2.num +} + +// String returns the name of the enum value if known (e.g., "ENUM_VALUE"), +// otherwise it returns the formatted decimal enum number (e.g., "14"). +func (e Enum) String() string { + if ev := e.ed.Values().ByNumber(e.num); ev != nil { + return string(ev.Name()) + } + return strconv.Itoa(int(e.num)) +} + +const ( + // messageTypeKey indicates the protobuf message type. + // The value type is always messageMeta. + // From the public API, it presents itself as only the type, but the + // underlying data structure holds arbitrary metadata about the message. + messageTypeKey = "@type" + + // messageInvalidKey indicates that the message is invalid. + // The value is always the boolean "true". 
+ messageInvalidKey = "@invalid" +) + +type messageMeta struct { + m proto.Message + md protoreflect.MessageDescriptor + xds map[string]protoreflect.ExtensionDescriptor +} + +func (t messageMeta) String() string { + return string(t.md.FullName()) +} + +func (t1 messageMeta) Equal(t2 messageMeta) bool { + return t1.md.FullName() == t2.md.FullName() +} + +// Message is a dynamic representation of a protocol buffer message that is +// suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon. +// +// Every populated known field (excluding extension fields) is stored in the map +// with the key being the short name of the field (e.g., "field_name") and +// the value determined by the kind and cardinality of the field. +// +// Singular scalars are represented by the same Go type as [protoreflect.Value], +// singular messages are represented by the [Message] type, +// singular enums are represented by the [Enum] type, +// list fields are represented as a Go slice, and +// map fields are represented as a Go map. +// +// Every populated extension field is stored in the map with the key being the +// full name of the field surrounded by brackets (e.g., "[extension.full.name]") +// and the value determined according to the same rules as known fields. +// +// Every unknown field is stored in the map with the key being the field number +// encoded as a decimal string (e.g., "132") and the value being the raw bytes +// of the encoded field (as the [protoreflect.RawFields] type). +// +// Message values must not be created by or mutated by users. +type Message map[string]interface{} + +// Unwrap returns the original message value. +// It returns nil if this Message was not constructed from another message. +func (m Message) Unwrap() proto.Message { + mm, _ := m[messageTypeKey].(messageMeta) + return mm.m +} + +// Descriptor return the message descriptor. +// It returns nil for a zero Message value. +func (m Message) Descriptor() protoreflect.MessageDescriptor { + mm, _ := m[messageTypeKey].(messageMeta) + return mm.md +} + +// ProtoReflect returns a reflective view of m. +// It only implements the read-only operations of [protoreflect.Message]. +// Calling any mutating operations on m panics. +func (m Message) ProtoReflect() protoreflect.Message { + return (reflectMessage)(m) +} + +// ProtoMessage is a marker method from the legacy message interface. +func (m Message) ProtoMessage() {} + +// Reset is the required Reset method from the legacy message interface. +func (m Message) Reset() { + panic("invalid mutation of a read-only message") +} + +// String returns a formatted string for the message. +// It is intended for human debugging and has no guarantees about its +// exact format or the stability of its output. +func (m Message) String() string { + switch { + case m == nil: + return "<nil>" + case !m.ProtoReflect().IsValid(): + return "<invalid>" + default: + return msgfmt.Format(m) + } +} + +type transformer struct { + resolver protoregistry.MessageTypeResolver +} + +func newTransformer(opts ...option) *transformer { + xf := &transformer{ + resolver: protoregistry.GlobalTypes, + } + for _, opt := range opts { + opt(xf) + } + return xf +} + +type option func(*transformer) + +// MessageTypeResolver overrides the resolver used for messages packed +// inside Any. The default is protoregistry.GlobalTypes, which is +// sufficient for all compiled-in Protobuf messages. Overriding the +// resolver is useful in tests that dynamically create Protobuf +// descriptors and messages, e.g. in proxies using dynamicpb.
+func MessageTypeResolver(r protoregistry.MessageTypeResolver) option { + return func(xf *transformer) { + xf.resolver = r + } +} + +// Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message]. +// The transformation does not mutate nor alias any converted messages. +// +// The google.protobuf.Any message is automatically unmarshaled such that the +// "value" field is a [Message] representing the underlying message value +// assuming it could be resolved and properly unmarshaled. +// +// This does not directly transform higher-order composite Go types. +// For example, []*foopb.Message is not transformed into []Message, +// but rather the individual message elements of the slice are transformed. +func Transform(opts ...option) cmp.Option { + xf := newTransformer(opts...) + + // addrType returns a pointer to t if t isn't a pointer or interface. + addrType := func(t reflect.Type) reflect.Type { + if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr { + return t + } + return reflect.PtrTo(t) + } + + // TODO: Should this transform protoreflect.Enum types to Enum as well? + return cmp.FilterPath(func(p cmp.Path) bool { + ps := p.Last() + if isMessageType(addrType(ps.Type())) { + return true + } + + // Check whether the concrete values of an interface both satisfy + // the Message interface. + if ps.Type().Kind() == reflect.Interface { + vx, vy := ps.Values() + if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { + return false + } + return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) + } + + return false + }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message { + // For user convenience, shallow copy the message value if necessary + // in order for it to implement the message interface. + if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) { + pv := reflect.New(rv.Type()) + pv.Elem().Set(rv) + v = pv.Interface() + } + + m := protoimpl.X.MessageOf(v) + switch { + case m == nil: + return nil + case !m.IsValid(): + return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true} + default: + return xf.transformMessage(m) + } + })) +} + +func isMessageType(t reflect.Type) bool { + // Avoid transforming the Message itself. + if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) { + return false + } + return t.Implements(messageV1Type) || t.Implements(messageV2Type) +} + +func (xf *transformer) transformMessage(m protoreflect.Message) Message { + mx := Message{} + mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)} + + // Handle known and extension fields. + m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + s := fd.TextName() + if fd.IsExtension() { + mt.xds[s] = fd + } + switch { + case fd.IsList(): + mx[s] = xf.transformList(fd, v.List()) + case fd.IsMap(): + mx[s] = xf.transformMap(fd, v.Map()) + default: + mx[s] = xf.transformSingular(fd, v) + } + return true + }) + + // Handle unknown fields. + for b := m.GetUnknown(); len(b) > 0; { + num, _, n := protowire.ConsumeField(b) + s := strconv.Itoa(int(num)) + b2, _ := mx[s].(protoreflect.RawFields) + mx[s] = append(b2, b[:n]...) + b = b[n:] + } + + // Expand Any messages. 
+ if mt.md.FullName() == genid.Any_message_fullname { + s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string) + b, _ := mx[string(genid.Any_Value_field_name)].([]byte) + mt, err := xf.resolver.FindMessageByURL(s) + if mt != nil && err == nil { + m2 := mt.New() + err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface()) + if err == nil { + mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2) + } + } + } + + mx[messageTypeKey] = mt + return mx +} + +func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} { + t := protoKindToGoType(fd.Kind()) + rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len()) + for i := 0; i < lv.Len(); i++ { + v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i))) + rv.Index(i).Set(v) + } + return rv.Interface() +} + +func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} { + kfd := fd.MapKey() + vfd := fd.MapValue() + kt := protoKindToGoType(kfd.Kind()) + vt := protoKindToGoType(vfd.Kind()) + rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len()) + mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value())) + vv := reflect.ValueOf(xf.transformSingular(vfd, v)) + rv.SetMapIndex(kv, vv) + return true + }) + return rv.Interface() +} + +func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} { + switch fd.Kind() { + case protoreflect.EnumKind: + return Enum{num: v.Enum(), ed: fd.Enum()} + case protoreflect.MessageKind, protoreflect.GroupKind: + return xf.transformMessage(v.Message()) + case protoreflect.BytesKind: + // The protoreflect API does not specify whether an empty bytes is + // guaranteed to be nil or not. Always return non-nil bytes to avoid + // leaking information about the concrete proto.Message implementation. 
+ if len(v.Bytes()) == 0 { + return []byte{} + } + return v.Bytes() + default: + return v.Interface() + } +} + +func protoKindToGoType(k protoreflect.Kind) reflect.Type { + switch k { + case protoreflect.BoolKind: + return reflect.TypeOf(bool(false)) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + return reflect.TypeOf(int32(0)) + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + return reflect.TypeOf(int64(0)) + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + return reflect.TypeOf(uint32(0)) + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + return reflect.TypeOf(uint64(0)) + case protoreflect.FloatKind: + return reflect.TypeOf(float32(0)) + case protoreflect.DoubleKind: + return reflect.TypeOf(float64(0)) + case protoreflect.StringKind: + return reflect.TypeOf(string("")) + case protoreflect.BytesKind: + return reflect.TypeOf([]byte(nil)) + case protoreflect.EnumKind: + return reflect.TypeOf(Enum{}) + case protoreflect.MessageKind, protoreflect.GroupKind: + return reflect.TypeOf(Message{}) + default: + panic("invalid kind") + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index b82c44864..88d79a70c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -614,6 +614,7 @@ google.golang.org/protobuf/internal/filetype google.golang.org/protobuf/internal/flags google.golang.org/protobuf/internal/genid google.golang.org/protobuf/internal/impl +google.golang.org/protobuf/internal/msgfmt google.golang.org/protobuf/internal/order google.golang.org/protobuf/internal/pragma google.golang.org/protobuf/internal/set @@ -626,6 +627,7 @@ google.golang.org/protobuf/reflect/protoreflect google.golang.org/protobuf/reflect/protoregistry google.golang.org/protobuf/runtime/protoiface google.golang.org/protobuf/runtime/protoimpl +google.golang.org/protobuf/testing/protocmp google.golang.org/protobuf/types/descriptorpb google.golang.org/protobuf/types/gofeaturespb google.golang.org/protobuf/types/known/anypb
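Usage sketch for the newly vendored protocmp package: generated proto.Message values (such as csi.Volume, which fake-gce.go already imports) contain unexported state, so a plain cmp.Diff on them panics; protocmp.Transform() converts each message into the Message map defined in xform.go, and the remaining protocmp options (IgnoreFields, IgnoreUnknown, SortRepeatedFields, ...) then operate on that transformed form. The test below is a hypothetical, minimal example; the file name, package placement, and field values are assumptions for illustration only.

// protocmp_usage_example_test.go (hypothetical; for illustration only)
package gcecloudprovider

import (
	"testing"

	csi "github.com/container-storage-interface/spec/lib/go/csi"
	"github.com/google/go-cmp/cmp"
	"google.golang.org/protobuf/testing/protocmp"
)

// TestVolumeDiff sketches the intended pattern: Transform is mandatory for
// every protocmp option, and IgnoreFields filters out fields the assertion
// does not care about.
func TestVolumeDiff(t *testing.T) {
	want := &csi.Volume{VolumeId: "projects/p/zones/us-central1-a/disks/d", CapacityBytes: 10 << 30}
	got := &csi.Volume{VolumeId: "projects/p/zones/us-central1-a/disks/d", CapacityBytes: 10 << 30}

	diff := cmp.Diff(want, got,
		protocmp.Transform(), // required; the other protocmp options only take effect alongside it
		protocmp.IgnoreFields(&csi.Volume{}, "volume_context"), // protobuf field name, not the Go struct field name
	)
	if diff != "" {
		t.Errorf("unexpected volume (-want +got):\n%s", diff)
	}
}

As documented in util.go above, the name arguments to IgnoreFields, FilterField, and SortRepeatedFields are protobuf field names; passing the CamelCase Go struct name instead trips the best-effort match in findDescriptor, which panics with a suggestion of the intended protobuf name.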