diff --git a/pkg/common/constants.go b/pkg/common/constants.go index eb64d8930..f0b5f510c 100644 --- a/pkg/common/constants.go +++ b/pkg/common/constants.go @@ -56,6 +56,9 @@ const ( // Node label for Data Cache (only applicable to GKE nodes) NodeLabelPrefix = "cloud.google.com/%s" DataCacheLssdCountLabel = "gke-data-cache-disk" + // Node label for attach limit override + NodeRestrictionLabelPrefix = "node-restriction.kubernetes.io/%s" + AttachLimitOverrideLabel = "gke-volume-attach-limit-override" ) // doc https://cloud.google.com/compute/docs/disks/hyperdisks#max-total-disks-per-vm diff --git a/pkg/common/utils.go b/pkg/common/utils.go index e2bd766a4..b58657e7c 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -24,6 +24,7 @@ import ( "net/http" "regexp" "slices" + "strconv" "strings" "time" @@ -772,6 +773,25 @@ func MapNumber(num int64) int64 { return 0 } +func ExtractCPUFromMachineType(input string) (int64, error) { + // Regex to find the number at the end of the string, + // it allows optional -lssd suffix. + re := regexp.MustCompile(`(\d+)(?:-lssd|-metal)?$`) + + match := re.FindStringSubmatch(input) + if len(match) < 2 { + return 0, fmt.Errorf("no number found at the end of the input string: %s", input) + } + + numberStr := match[1] + number, err := strconv.ParseInt(numberStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to convert string '%s' to integer: %w", numberStr, err) + } + + return number, nil +} + // IsUpdateIopsThroughputValuesAllowed checks if a disk type is hyperdisk, // which implies that IOPS and throughput values can be updated. func IsUpdateIopsThroughputValuesAllowed(disk *computev1.Disk) bool { diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go index 0d1ffc268..ef30d1c62 100644 --- a/pkg/gce-pd-csi-driver/node.go +++ b/pkg/gce-pd-csi-driver/node.go @@ -31,6 +31,8 @@ import ( csi "github.com/container-storage-interface/spec/lib/go/csi" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" "k8s.io/klog/v2" "k8s.io/mount-utils" @@ -571,7 +573,7 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName()) - volumeLimits, err := ns.GetVolumeLimits() + volumeLimits, err := ns.GetVolumeLimits(ctx) if err != nil { klog.Errorf("GetVolumeLimits failed: %v. The error is ignored so that the driver can register", err.Error()) // No error should be returned from NodeGetInfo, otherwise the driver will not register @@ -733,7 +735,7 @@ func (ns *GCENodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpa }, nil } -func (ns *GCENodeServer) GetVolumeLimits() (int64, error) { +func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { // Machine-type format: n1-type-CPUS or custom-CPUS-RAM or f1/g1-type machineType := ns.MetadataService.GetMachineType() @@ -743,6 +745,22 @@ func (ns *GCENodeServer) GetVolumeLimits() (int64, error) { return volumeLimitSmall, nil } } + + // Get attach limit override from label + attachLimitOverride, err := GetAttachLimitsOverrideFromNodeLabel(ctx, ns.MetadataService.GetName()) + if err == nil && attachLimitOverride > 0 && attachLimitOverride < 128 { + return attachLimitOverride, nil + } else { + // If there is an error or the range is not valid, still proceed to get defaults for the machine type + if err != nil { + klog.Warningf("using default value due to err getting node-restriction.kubernetes.io/gke-volume-attach-limit-override: %v", err) + } + if attachLimitOverride != 0 { + klog.Warningf("using default value due to invalid node-restriction.kubernetes.io/gke-volume-attach-limit-override: %d", attachLimitOverride) + } + } + + // Process gen4 machine attach limits gen4MachineTypesPrefix := []string{"c4a-", "c4-", "n4-"} for _, gen4Prefix := range gen4MachineTypesPrefix { if strings.HasPrefix(machineType, gen4Prefix) { @@ -768,3 +786,27 @@ func (ns *GCENodeServer) GetVolumeLimits() (int64, error) { return volumeLimitBig, nil } + +func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) (int64, error) { + cfg, err := rest.InClusterConfig() + if err != nil { + return 0, err + } + kubeClient, err := kubernetes.NewForConfig(cfg) + if err != nil { + return 0, err + } + node, err := getNodeWithRetry(ctx, kubeClient, nodeName) + if err != nil { + return 0, err + } + if val, found := node.GetLabels()[fmt.Sprintf(common.NodeRestrictionLabelPrefix, common.AttachLimitOverrideLabel)]; found { + attachLimitOverrideForNode, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return 0, fmt.Errorf("error getting attach limit override from node label: %v", err) + } + klog.V(4).Infof("attach limit override for the node: %v", attachLimitOverrideForNode) + return attachLimitOverrideForNode, nil + } + return 0, nil +} diff --git a/pkg/gce-pd-csi-driver/node_test.go b/pkg/gce-pd-csi-driver/node_test.go index ddfd50aa5..8d6449de6 100644 --- a/pkg/gce-pd-csi-driver/node_test.go +++ b/pkg/gce-pd-csi-driver/node_test.go @@ -313,6 +313,21 @@ func TestNodeGetVolumeLimits(t *testing.T) { machineType: "a4-highgpu-8g", expVolumeLimit: a4HyperdiskLimit, }, + { + name: "c3-standard-4", + machineType: "c3-standard-4", + expVolumeLimit: volumeLimitBig, + }, + { + name: "c3d-highmem-8-lssd", + machineType: "c3d-highmem-8-lssd", + expVolumeLimit: volumeLimitBig, + }, + { + name: "c4a-standard-32-lssd", + machineType: "c4a-standard-32-lssd", + expVolumeLimit: 49, + }, } for _, tc := range testCases {