Skip to content

Commit 04e403a

Browse files
authored
Merge pull request #465 from davidz627/fix/context
Thread contexts into all appropriate functions in non-test code
2 parents e26e538 + 814def5 commit 04e403a

File tree

3 files changed

+23
-19
lines changed

3 files changed

+23
-19
lines changed

cmd/main.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
package main
1616

1717
import (
18+
"context"
1819
"flag"
1920
"math/rand"
2021
"os"
@@ -63,8 +64,11 @@ func handle() {
6364

6465
gceDriver := driver.GetGCEDriver()
6566

66-
//Initialize GCE Driver (Move setup to main?)
67-
cloudProvider, err := gce.CreateCloudProvider(vendorVersion, *gceConfigFilePath)
67+
//Initialize GCE Driver
68+
ctx, cancel := context.WithCancel(context.Background())
69+
defer cancel()
70+
71+
cloudProvider, err := gce.CreateCloudProvider(ctx, vendorVersion, *gceConfigFilePath)
6872
if err != nil {
6973
klog.Fatalf("Failed to get cloud provider: %v", err)
7074
}

pkg/gce-cloud-provider/compute/gce.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ type ConfigGlobal struct {
6868
ProjectId string `gcfg:"project-id"`
6969
}
7070

71-
func CreateCloudProvider(vendorVersion string, configPath string) (*CloudProvider, error) {
71+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string) (*CloudProvider, error) {
7272
configFile, err := readConfig(configPath)
7373
if err != nil {
7474
return nil, err
@@ -78,12 +78,12 @@ func CreateCloudProvider(vendorVersion string, configPath string) (*CloudProvide
7878

7979
klog.V(2).Infof("Using GCE provider config %+v", configFile)
8080

81-
tokenSource, err := generateTokenSource(configFile)
81+
tokenSource, err := generateTokenSource(ctx, configFile)
8282
if err != nil {
8383
return nil, err
8484
}
8585

86-
svc, err := createCloudService(vendorVersion, tokenSource)
86+
svc, err := createCloudService(ctx, vendorVersion, tokenSource)
8787
if err != nil {
8888
return nil, err
8989
}
@@ -102,7 +102,7 @@ func CreateCloudProvider(vendorVersion string, configPath string) (*CloudProvide
102102

103103
}
104104

105-
func generateTokenSource(configFile *ConfigFile) (oauth2.TokenSource, error) {
105+
func generateTokenSource(ctx context.Context, configFile *ConfigFile) (oauth2.TokenSource, error) {
106106

107107
if configFile != nil && configFile.Global.TokenURL != "" && configFile.Global.TokenURL != "nil" {
108108
// configFile.Global.TokenURL is defined
@@ -116,7 +116,7 @@ func generateTokenSource(configFile *ConfigFile) (oauth2.TokenSource, error) {
116116
// Use DefaultTokenSource
117117

118118
tokenSource, err := google.DefaultTokenSource(
119-
context.Background(),
119+
ctx,
120120
compute.CloudPlatformScope,
121121
compute.ComputeScope)
122122

@@ -149,13 +149,13 @@ func readConfig(configPath string) (*ConfigFile, error) {
149149
return cfg, nil
150150
}
151151

152-
func createCloudService(vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) {
153-
svc, err := createCloudServiceWithDefaultServiceAccount(vendorVersion, tokenSource)
152+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) {
153+
svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource)
154154
return svc, err
155155
}
156156

157-
func createCloudServiceWithDefaultServiceAccount(vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) {
158-
client, err := newOauthClient(tokenSource)
157+
func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) {
158+
client, err := newOauthClient(ctx, tokenSource)
159159
if err != nil {
160160
return nil, err
161161
}
@@ -167,7 +167,7 @@ func createCloudServiceWithDefaultServiceAccount(vendorVersion string, tokenSour
167167
return service, nil
168168
}
169169

170-
func newOauthClient(tokenSource oauth2.TokenSource) (*http.Client, error) {
170+
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
171171
if err := wait.PollImmediate(5*time.Second, 30*time.Second, func() (bool, error) {
172172
if _, err := tokenSource.Token(); err != nil {
173173
klog.Errorf("error fetching initial token: %v", err)
@@ -178,7 +178,7 @@ func newOauthClient(tokenSource oauth2.TokenSource) (*http.Client, error) {
178178
return nil, err
179179
}
180180

181-
return oauth2.NewClient(context.Background(), tokenSource), nil
181+
return oauth2.NewClient(ctx, tokenSource), nil
182182
}
183183

184184
func getProjectAndZone(config *ConfigFile) (string, string, error) {

pkg/gce-pd-csi-driver/controller.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre
118118
var volKey *meta.Key
119119
switch replicationType {
120120
case replicationTypeNone:
121-
zones, err = pickZones(gceCS, req.GetAccessibilityRequirements(), 1)
121+
zones, err = pickZones(ctx, gceCS, req.GetAccessibilityRequirements(), 1)
122122
if err != nil {
123123
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("CreateVolume failed to pick zones for disk: %v", err))
124124
}
@@ -128,7 +128,7 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre
128128
volKey = meta.ZonalKey(name, zones[0])
129129

130130
case replicationTypeRegionalPD:
131-
zones, err = pickZones(gceCS, req.GetAccessibilityRequirements(), 2)
131+
zones, err = pickZones(ctx, gceCS, req.GetAccessibilityRequirements(), 2)
132132
if err != nil {
133133
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("CreateVolume failed to pick zones for disk: %v", err))
134134
}
@@ -941,7 +941,7 @@ func getZoneFromSegment(seg map[string]string) (string, error) {
941941
return zone, nil
942942
}
943943

944-
func pickZones(gceCS *GCEControllerServer, top *csi.TopologyRequirement, numZones int) ([]string, error) {
944+
func pickZones(ctx context.Context, gceCS *GCEControllerServer, top *csi.TopologyRequirement, numZones int) ([]string, error) {
945945
var zones []string
946946
var err error
947947
if top != nil {
@@ -950,7 +950,7 @@ func pickZones(gceCS *GCEControllerServer, top *csi.TopologyRequirement, numZone
950950
return nil, fmt.Errorf("failed to pick zones from topology: %v", err)
951951
}
952952
} else {
953-
zones, err = getDefaultZonesInRegion(gceCS, []string{gceCS.MetadataService.GetZone()}, numZones)
953+
zones, err = getDefaultZonesInRegion(ctx, gceCS, []string{gceCS.MetadataService.GetZone()}, numZones)
954954
if err != nil {
955955
return nil, fmt.Errorf("failed to get default %v zones in region: %v", numZones, err)
956956
}
@@ -960,13 +960,13 @@ func pickZones(gceCS *GCEControllerServer, top *csi.TopologyRequirement, numZone
960960
return zones, nil
961961
}
962962

963-
func getDefaultZonesInRegion(gceCS *GCEControllerServer, existingZones []string, numZones int) ([]string, error) {
963+
func getDefaultZonesInRegion(ctx context.Context, gceCS *GCEControllerServer, existingZones []string, numZones int) ([]string, error) {
964964
region, err := common.GetRegionFromZones(existingZones)
965965
if err != nil {
966966
return nil, fmt.Errorf("failed to get region from zones: %v", err)
967967
}
968968
needToGet := numZones - len(existingZones)
969-
totZones, err := gceCS.CloudProvider.ListZones(context.Background(), region)
969+
totZones, err := gceCS.CloudProvider.ListZones(ctx, region)
970970
if err != nil {
971971
return nil, fmt.Errorf("failed to list zones from cloud provider: %v", err)
972972
}

0 commit comments

Comments
 (0)