diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 06c1ab7ca..75ee28e88 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -245,11 +245,15 @@ func handle() { if err != nil { klog.Fatalf("Failed to set up metadata service: %v", err.Error()) } + isDataCacheEnabledNodePool, err := isDataCacheEnabledNodePool(ctx, *nodeName) + if err != nil { + klog.Fatalf("Failed to get node info from API server: %v", err.Error()) + } nsArgs := driver.NodeServerArgs{ EnableDeviceInUseCheck: *enableDeviceInUseCheck, DeviceInUseTimeout: *deviceInUseTimeout, EnableDataCache: *enableDataCacheFlag, - DataCacheEnabledNodePool: isDataCacheEnabledNodePool(ctx, *nodeName), + DataCacheEnabledNodePool: isDataCacheEnabledNodePool, } nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs) if *maxConcurrentFormatAndMount > 0 { @@ -347,14 +351,15 @@ func urlFlag(target **url.URL, name string, usage string) { }) } -func isDataCacheEnabledNodePool(ctx context.Context, nodeName string) bool { - if nodeName != common.TestNode { // disregard logic below when E2E testing. +func isDataCacheEnabledNodePool(ctx context.Context, nodeName string) (bool, error) { + if !*enableDataCacheFlag { + return false, nil + } + if len(nodeName) > 0 && nodeName != common.TestNode { // disregard logic below when E2E testing. dataCacheLSSDCount, err := driver.GetDataCacheCountFromNodeLabel(ctx, nodeName) - if err != nil || dataCacheLSSDCount == 0 { - return false - } + return dataCacheLSSDCount != 0, err } - return true + return true, nil } func fetchLssdsForRaiding(lssdCount int) ([]string, error) { diff --git a/pkg/gce-pd-csi-driver/cache.go b/pkg/gce-pd-csi-driver/cache.go index 691eb4f07..03b5dfe0f 100644 --- a/pkg/gce-pd-csi-driver/cache.go +++ b/pkg/gce-pd-csi-driver/cache.go @@ -7,10 +7,13 @@ import ( "regexp" "strconv" "strings" + "time" csi "github.com/container-storage-interface/spec/lib/go/csi" fsnotify "github.com/fsnotify/fsnotify" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/klog/v2" @@ -242,8 +245,6 @@ func ValidateDataCacheConfig(dataCacheMode string, dataCacheSize string, ctx con func GetDataCacheCountFromNodeLabel(ctx context.Context, nodeName string) (int, error) { cfg, err := rest.InClusterConfig() - // We want to capture API errors with node label fetching, so return -1 - // in those cases instead of 0. if err != nil { return 0, err } @@ -251,9 +252,8 @@ func GetDataCacheCountFromNodeLabel(ctx context.Context, nodeName string) (int, if err != nil { return 0, err } - node, err := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + node, err := getNodeWithRetry(ctx, kubeClient, nodeName) if err != nil { - // We could retry, but this error will also crashloop the driver which may be as good a way to retry as any. return 0, err } if val, found := node.GetLabels()[fmt.Sprintf(common.NodeLabelPrefix, common.DataCacheLssdCountLabel)]; found { @@ -264,10 +264,33 @@ func GetDataCacheCountFromNodeLabel(ctx context.Context, nodeName string) (int, klog.V(4).Infof("Number of local SSDs requested for Data Cache: %v", dataCacheCount) return dataCacheCount, nil } - // This will be returned for a non-Data-Cache node pool return 0, nil } +func getNodeWithRetry(ctx context.Context, kubeClient *kubernetes.Clientset, nodeName string) (*v1.Node, error) { + var nodeObj *v1.Node + backoff := wait.Backoff{ + Duration: 1 * time.Second, + Factor: 2.0, + Steps: 5, + } + err := wait.ExponentialBackoffWithContext(ctx, backoff, func() (bool, error) { + node, err := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + if err != nil { + klog.Warningf("Error getting node %s: %v, retrying...\n", nodeName, err) + return false, nil + } + nodeObj = node + klog.V(4).Infof("Successfully retrieved node info %s\n", nodeName) + return true, nil + }) + + if err != nil { + klog.Errorf("Failed to get node %s after retries: %v\n", nodeName, err) + } + return nodeObj, err +} + func FetchRaidedLssdCountForDatacache() (int, error) { raidedPath, err := fetchRAIDedLocalSsdPath() if err != nil { diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 949c62e40..0f9f7dd86 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -73,9 +73,9 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, driverConfig DriverC fmt.Sprintf("--fallback-requisite-zones=%s", strings.Join(driverConfig.Zones, ",")), } + extra_flags = append(extra_flags, fmt.Sprintf("--node-name=%s", utilcommon.TestNode)) if instance.GetLocalSSD() > 0 { extra_flags = append(extra_flags, "--enable-data-cache") - extra_flags = append(extra_flags, fmt.Sprintf("--node-name=%s", utilcommon.TestNode)) } extra_flags = append(extra_flags, fmt.Sprintf("--compute-endpoint=%s", driverConfig.ComputeEndpoint)) extra_flags = append(extra_flags, driverConfig.ExtraFlags...)