@@ -2,7 +2,10 @@ package s3
2
2
3
3
import (
4
4
"context"
5
+ "crypto/hmac"
6
+ "crypto/sha256"
5
7
"errors"
8
+ "fmt"
6
9
"sync"
7
10
"time"
8
11
@@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100
17
20
18
21
const s3ExpressRefreshWindow = 1 * time .Minute
19
22
23
+ type cacheKey struct {
24
+ CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
25
+ Bucket string
26
+ }
27
+
28
+ func (c cacheKey ) Slug () string {
29
+ return fmt .Sprintf ("%s%s" , c .CredentialsHash , c .Bucket )
30
+ }
31
+
32
+ type sessionCredsCache struct {
33
+ mu sync.Mutex
34
+ cache cache.Cache
35
+ }
36
+
37
+ func (c * sessionCredsCache ) Get (key cacheKey ) (* aws.Credentials , bool ) {
38
+ c .mu .Lock ()
39
+ defer c .mu .Unlock ()
40
+
41
+ if v , ok := c .cache .Get (key ); ok {
42
+ return v .(* aws.Credentials ), true
43
+ }
44
+ return nil , false
45
+ }
46
+
47
+ func (c * sessionCredsCache ) Put (key cacheKey , creds * aws.Credentials ) {
48
+ c .mu .Lock ()
49
+ defer c .mu .Unlock ()
50
+
51
+ c .cache .Put (key , creds )
52
+ }
53
+
20
54
// The default S3Express provider uses an LRU cache with a capacity of 100.
21
55
//
22
56
// Credentials will be refreshed asynchronously when a Retrieve() call is made
23
57
// for cached credentials within an expiry window (1 minute, currently
24
58
// non-configurable).
25
59
type defaultS3ExpressCredentialsProvider struct {
26
- mu sync.Mutex
27
60
sf singleflight.Group
28
61
29
62
client createSessionAPIClient
30
- credsCache cache. Cache
63
+ cache * sessionCredsCache
31
64
refreshWindow time.Duration
65
+ v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
32
66
}
33
67
34
68
type createSessionAPIClient interface {
@@ -37,35 +71,54 @@ type createSessionAPIClient interface {
37
71
38
72
func newDefaultS3ExpressCredentialsProvider () * defaultS3ExpressCredentialsProvider {
39
73
return & defaultS3ExpressCredentialsProvider {
40
- credsCache : lru .New (s3ExpressCacheCap ),
74
+ cache : & sessionCredsCache {
75
+ cache : lru .New (s3ExpressCacheCap ),
76
+ },
41
77
refreshWindow : s3ExpressRefreshWindow ,
42
78
}
43
79
}
44
80
81
+ // returns a cloned provider using new base credentials, used when per-op
82
+ // config mutations change the credentials provider
83
+ func (p * defaultS3ExpressCredentialsProvider ) CloneWithBaseCredentials (v4creds aws.CredentialsProvider ) * defaultS3ExpressCredentialsProvider {
84
+ return & defaultS3ExpressCredentialsProvider {
85
+ client : p .client ,
86
+ cache : p .cache ,
87
+ refreshWindow : p .refreshWindow ,
88
+ v4creds : v4creds ,
89
+ }
90
+ }
91
+
45
92
func (p * defaultS3ExpressCredentialsProvider ) Retrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
46
- p .mu .Lock ()
47
- defer p .mu .Unlock ()
93
+ v4creds , err := p .v4creds .Retrieve (ctx )
94
+ if err != nil {
95
+ return aws.Credentials {}, fmt .Errorf ("get sigv4 creds: %w" , err )
96
+ }
48
97
49
- creds , ok := p .getCacheCredentials (bucket )
98
+ key := cacheKey {
99
+ CredentialsHash : gethmac (v4creds .AccessKeyID , v4creds .SecretAccessKey ),
100
+ Bucket : bucket ,
101
+ }
102
+ creds , ok := p .cache .Get (key )
50
103
if ! ok || creds .Expired () {
51
- return p .awaitDoChanRetrieve (ctx , bucket )
104
+ return p .awaitDoChanRetrieve (ctx , key )
52
105
}
53
106
54
107
if creds .Expires .Sub (sdk .NowTime ()) <= p .refreshWindow {
55
- p .doChanRetrieve (ctx , bucket )
108
+ p .doChanRetrieve (ctx , key )
56
109
}
57
110
58
111
return * creds , nil
59
112
}
60
113
61
- func (p * defaultS3ExpressCredentialsProvider ) doChanRetrieve (ctx context.Context , bucket string ) <- chan singleflight.Result {
62
- return p .sf .DoChan (bucket , func () (interface {}, error ) {
63
- return p .retrieve (ctx , bucket )
114
+ func (p * defaultS3ExpressCredentialsProvider ) doChanRetrieve (ctx context.Context , key cacheKey ) <- chan singleflight.Result {
115
+ return p .sf .DoChan (key . Slug () , func () (interface {}, error ) {
116
+ return p .retrieve (ctx , key )
64
117
})
65
118
}
66
119
67
- func (p * defaultS3ExpressCredentialsProvider ) awaitDoChanRetrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
68
- ch := p .doChanRetrieve (ctx , bucket )
120
+ func (p * defaultS3ExpressCredentialsProvider ) awaitDoChanRetrieve (ctx context.Context , key cacheKey ) (aws.Credentials , error ) {
121
+ ch := p .doChanRetrieve (ctx , key )
69
122
70
123
select {
71
124
case r := <- ch :
@@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co
75
128
}
76
129
}
77
130
78
- func (p * defaultS3ExpressCredentialsProvider ) retrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
131
+ func (p * defaultS3ExpressCredentialsProvider ) retrieve (ctx context.Context , key cacheKey ) (aws.Credentials , error ) {
79
132
resp , err := p .client .CreateSession (ctx , & CreateSessionInput {
80
- Bucket : aws .String (bucket ),
133
+ Bucket : aws .String (key . Bucket ),
81
134
})
82
135
if err != nil {
83
136
return aws.Credentials {}, err
@@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck
88
141
return aws.Credentials {}, err
89
142
}
90
143
91
- p .putCacheCredentials ( bucket , creds )
144
+ p .cache . Put ( key , creds )
92
145
return * creds , nil
93
146
}
94
147
95
- func (p * defaultS3ExpressCredentialsProvider ) getCacheCredentials (bucket string ) (* aws.Credentials , bool ) {
96
- if v , ok := p .credsCache .Get (bucket ); ok {
97
- return v .(* aws.Credentials ), true
98
- }
99
-
100
- return nil , false
101
- }
102
-
103
- func (p * defaultS3ExpressCredentialsProvider ) putCacheCredentials (bucket string , creds * aws.Credentials ) {
104
- p .credsCache .Put (bucket , creds )
105
- }
106
-
107
148
func credentialsFromResponse (o * CreateSessionOutput ) (* aws.Credentials , error ) {
108
149
if o .Credentials == nil {
109
150
return nil , errors .New ("s3express session credentials unset" )
@@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
121
162
Expires : * o .Credentials .Expiration ,
122
163
}, nil
123
164
}
165
+
166
+ func gethmac (p , key string ) string {
167
+ hash := hmac .New (sha256 .New , []byte (key ))
168
+ hash .Write ([]byte (p ))
169
+ return string (hash .Sum (nil ))
170
+ }
0 commit comments