Skip to content

Commit b3c7fbf

Browse files
authored
update express cache key (#2414)
1 parent 9b90af4 commit b3c7fbf

File tree

6 files changed

+248
-30
lines changed

6 files changed

+248
-30
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "8e6a0119-7da8-48c8-8aaa-b5adb296abc1",
3+
"type": "bugfix",
4+
"description": "Improve uniqueness of default S3Express sesssion credentials cache keying to prevent collision in multi-credential scenarios.",
5+
"modules": [
6+
"service/s3"
7+
]
8+
}

codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/auth/S3ExpressAuthScheme.java

+10
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
4545
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
46+
import static software.amazon.smithy.go.codegen.SymbolUtils.buildPackageSymbol;
4647

4748
public class S3ExpressAuthScheme implements GoIntegration {
4849
private static final ConfigField s3ExpressCredentials =
@@ -67,6 +68,14 @@ public class S3ExpressAuthScheme implements GoIntegration {
6768
.withClientInput(true)
6869
.build();
6970

71+
private static final ConfigFieldResolver s3ExpressCredentialsOperationFinalizer =
72+
ConfigFieldResolver.builder()
73+
.location(ConfigFieldResolver.Location.OPERATION)
74+
.target(ConfigFieldResolver.Target.FINALIZATION)
75+
.resolver(buildPackageSymbol("finalizeOperationExpressCredentials"))
76+
.withClientInput(true)
77+
.build();
78+
7079
@Override
7180
public void writeAdditionalFiles(
7281
GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator
@@ -84,6 +93,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
8493
.addConfigField(s3ExpressCredentials)
8594
.addConfigFieldResolver(s3ExpressCredentialsResolver)
8695
.addConfigFieldResolver(s3ExpressCredentialsClientFinalizer)
96+
.addConfigFieldResolver(s3ExpressCredentialsOperationFinalizer)
8797
.addAuthSchemeDefinition(SigV4S3ExpressTrait.ID, new SigV4S3Express())
8898
.build()
8999
);

service/s3/api_client.go

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

service/s3/express_default.go

+75-28
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package s3
22

33
import (
44
"context"
5+
"crypto/hmac"
6+
"crypto/sha256"
57
"errors"
8+
"fmt"
69
"sync"
710
"time"
811

@@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100
1720

1821
const s3ExpressRefreshWindow = 1 * time.Minute
1922

23+
type cacheKey struct {
24+
CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
25+
Bucket string
26+
}
27+
28+
func (c cacheKey) Slug() string {
29+
return fmt.Sprintf("%s%s", c.CredentialsHash, c.Bucket)
30+
}
31+
32+
type sessionCredsCache struct {
33+
mu sync.Mutex
34+
cache cache.Cache
35+
}
36+
37+
func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
38+
c.mu.Lock()
39+
defer c.mu.Unlock()
40+
41+
if v, ok := c.cache.Get(key); ok {
42+
return v.(*aws.Credentials), true
43+
}
44+
return nil, false
45+
}
46+
47+
func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
48+
c.mu.Lock()
49+
defer c.mu.Unlock()
50+
51+
c.cache.Put(key, creds)
52+
}
53+
2054
// The default S3Express provider uses an LRU cache with a capacity of 100.
2155
//
2256
// Credentials will be refreshed asynchronously when a Retrieve() call is made
2357
// for cached credentials within an expiry window (1 minute, currently
2458
// non-configurable).
2559
type defaultS3ExpressCredentialsProvider struct {
26-
mu sync.Mutex
2760
sf singleflight.Group
2861

2962
client createSessionAPIClient
30-
credsCache cache.Cache
63+
cache *sessionCredsCache
3164
refreshWindow time.Duration
65+
v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
3266
}
3367

3468
type createSessionAPIClient interface {
@@ -37,35 +71,54 @@ type createSessionAPIClient interface {
3771

3872
func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
3973
return &defaultS3ExpressCredentialsProvider{
40-
credsCache: lru.New(s3ExpressCacheCap),
74+
cache: &sessionCredsCache{
75+
cache: lru.New(s3ExpressCacheCap),
76+
},
4177
refreshWindow: s3ExpressRefreshWindow,
4278
}
4379
}
4480

81+
// returns a cloned provider using new base credentials, used when per-op
82+
// config mutations change the credentials provider
83+
func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
84+
return &defaultS3ExpressCredentialsProvider{
85+
client: p.client,
86+
cache: p.cache,
87+
refreshWindow: p.refreshWindow,
88+
v4creds: v4creds,
89+
}
90+
}
91+
4592
func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
46-
p.mu.Lock()
47-
defer p.mu.Unlock()
93+
v4creds, err := p.v4creds.Retrieve(ctx)
94+
if err != nil {
95+
return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
96+
}
4897

49-
creds, ok := p.getCacheCredentials(bucket)
98+
key := cacheKey{
99+
CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
100+
Bucket: bucket,
101+
}
102+
creds, ok := p.cache.Get(key)
50103
if !ok || creds.Expired() {
51-
return p.awaitDoChanRetrieve(ctx, bucket)
104+
return p.awaitDoChanRetrieve(ctx, key)
52105
}
53106

54107
if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
55-
p.doChanRetrieve(ctx, bucket)
108+
p.doChanRetrieve(ctx, key)
56109
}
57110

58111
return *creds, nil
59112
}
60113

61-
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, bucket string) <-chan singleflight.Result {
62-
return p.sf.DoChan(bucket, func() (interface{}, error) {
63-
return p.retrieve(ctx, bucket)
114+
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
115+
return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
116+
return p.retrieve(ctx, key)
64117
})
65118
}
66119

67-
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
68-
ch := p.doChanRetrieve(ctx, bucket)
120+
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
121+
ch := p.doChanRetrieve(ctx, key)
69122

70123
select {
71124
case r := <-ch:
@@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co
75128
}
76129
}
77130

78-
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
131+
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
79132
resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
80-
Bucket: aws.String(bucket),
133+
Bucket: aws.String(key.Bucket),
81134
})
82135
if err != nil {
83136
return aws.Credentials{}, err
@@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck
88141
return aws.Credentials{}, err
89142
}
90143

91-
p.putCacheCredentials(bucket, creds)
144+
p.cache.Put(key, creds)
92145
return *creds, nil
93146
}
94147

95-
func (p *defaultS3ExpressCredentialsProvider) getCacheCredentials(bucket string) (*aws.Credentials, bool) {
96-
if v, ok := p.credsCache.Get(bucket); ok {
97-
return v.(*aws.Credentials), true
98-
}
99-
100-
return nil, false
101-
}
102-
103-
func (p *defaultS3ExpressCredentialsProvider) putCacheCredentials(bucket string, creds *aws.Credentials) {
104-
p.credsCache.Put(bucket, creds)
105-
}
106-
107148
func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
108149
if o.Credentials == nil {
109150
return nil, errors.New("s3express session credentials unset")
@@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
121162
Expires: *o.Credentials.Expiration,
122163
}, nil
123164
}
165+
166+
func gethmac(p, key string) string {
167+
hash := hmac.New(sha256.New, []byte(key))
168+
hash.Write([]byte(p))
169+
return string(hash.Sum(nil))
170+
}

service/s3/express_resolve.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,26 @@ func resolveExpressCredentials(o *Options) {
1313
}
1414
}
1515

16-
// Config finalizer: if we're using the default S3Express implementation,
17-
// grab a reference to the client for its CreateSession API.
16+
// Config finalizer: if we're using the default S3Express implementation, grab
17+
// a reference to the client for its CreateSession API, and the underlying
18+
// sigv4 credentials provider for cache keying.
1819
func finalizeExpressCredentials(o *Options, c *Client) {
1920
if p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider); ok {
2021
p.client = c
22+
p.v4creds = o.Credentials
23+
}
24+
}
25+
26+
// Operation config finalizer: update the sigv4 credentials on the default
27+
// express provider if it changed to ensure different cache keys
28+
func finalizeOperationExpressCredentials(o *Options, c Client) {
29+
p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider)
30+
if !ok {
31+
return
32+
}
33+
34+
if c.options.Credentials != o.Credentials {
35+
o.ExpressCredentials = p.CloneWithBaseCredentials(o.Credentials)
2136
}
2237
}
2338

0 commit comments

Comments
 (0)