Skip to content

Commit b3103f2

Browse files
authored
add support for token file and eks container endpoint in general HTTP provider
1 parent f300f13 commit b3103f2

File tree

4 files changed

+225
-8
lines changed

4 files changed

+225
-8
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"id": "0593bfc1-f008-41fe-bcb7-3c7957839f01",
3+
"type": "feature",
4+
"collapse": true,
5+
"description": "Add support for dynamic auth token from file and EKS container host in absolute/relative URIs in the HTTP credential provider.",
6+
"modules": [
7+
"config",
8+
"credentials"
9+
]
10+
}

config/resolve_credentials.go

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package config
33
import (
44
"context"
55
"fmt"
6+
"io/ioutil"
7+
"net"
68
"net/url"
9+
"os"
710
"time"
811

912
"github.com/aws/aws-sdk-go-v2/aws"
@@ -21,11 +24,33 @@ import (
2124

2225
const (
2326
// valid credential source values
24-
credSourceEc2Metadata = "Ec2InstanceMetadata"
25-
credSourceEnvironment = "Environment"
26-
credSourceECSContainer = "EcsContainer"
27+
credSourceEc2Metadata = "Ec2InstanceMetadata"
28+
credSourceEnvironment = "Environment"
29+
credSourceECSContainer = "EcsContainer"
30+
httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE"
2731
)
2832

33+
// direct representation of the IPv4 address for the ECS container
34+
// "169.254.170.2"
35+
var ecsContainerIPv4 net.IP = []byte{
36+
169, 254, 170, 2,
37+
}
38+
39+
// direct representation of the IPv4 address for the EKS container
40+
// "169.254.170.23"
41+
var eksContainerIPv4 net.IP = []byte{
42+
169, 254, 170, 23,
43+
}
44+
45+
// direct representation of the IPv6 address for the EKS container
46+
// "fd00:ec2::23"
47+
var eksContainerIPv6 net.IP = []byte{
48+
0xFD, 0, 0xE, 0xC2,
49+
0, 0, 0, 0,
50+
0, 0, 0, 0,
51+
0, 0, 0, 0x23,
52+
}
53+
2954
var (
3055
ecsContainerEndpoint = "http://169.254.170.2" // not constant to allow for swapping during unit-testing
3156
)
@@ -222,6 +247,36 @@ func processCredentials(ctx context.Context, cfg *aws.Config, sharedConfig *Shar
222247
return nil
223248
}
224249

250+
// isAllowedHost allows host to be loopback or known ECS/EKS container IPs
251+
//
252+
// host can either be an IP address OR an unresolved hostname - resolution will
253+
// be automatically performed in the latter case
254+
func isAllowedHost(host string) (bool, error) {
255+
if ip := net.ParseIP(host); ip != nil {
256+
return isIPAllowed(ip), nil
257+
}
258+
259+
addrs, err := lookupHostFn(host)
260+
if err != nil {
261+
return false, err
262+
}
263+
264+
for _, addr := range addrs {
265+
if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) {
266+
return false, nil
267+
}
268+
}
269+
270+
return true, nil
271+
}
272+
273+
func isIPAllowed(ip net.IP) bool {
274+
return ip.IsLoopback() ||
275+
ip.Equal(ecsContainerIPv4) ||
276+
ip.Equal(eksContainerIPv4) ||
277+
ip.Equal(eksContainerIPv6)
278+
}
279+
225280
func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpointURL, authToken string, configs configs) error {
226281
var resolveErr error
227282

@@ -232,10 +287,12 @@ func resolveLocalHTTPCredProvider(ctx context.Context, cfg *aws.Config, endpoint
232287
host := parsed.Hostname()
233288
if len(host) == 0 {
234289
resolveErr = fmt.Errorf("unable to parse host from local HTTP cred provider URL")
235-
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
236-
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, loopbackErr)
237-
} else if !isLoopback {
238-
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback hosts are allowed", host)
290+
} else if parsed.Scheme == "http" {
291+
if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil {
292+
resolveErr = fmt.Errorf("failed to resolve host %q, %v", host, allowHostErr)
293+
} else if !isAllowedHost {
294+
resolveErr = fmt.Errorf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed", host)
295+
}
239296
}
240297
}
241298

@@ -252,6 +309,16 @@ func resolveHTTPCredProvider(ctx context.Context, cfg *aws.Config, url, authToke
252309
if len(authToken) != 0 {
253310
options.AuthorizationToken = authToken
254311
}
312+
if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" {
313+
options.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) {
314+
var contents []byte
315+
var err error
316+
if contents, err = ioutil.ReadFile(authFilePath); err != nil {
317+
return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err)
318+
}
319+
return string(contents), nil
320+
})
321+
}
255322
options.APIOptions = cfg.APIOptions
256323
if cfg.Retryer != nil {
257324
options.Retryer = cfg.Retryer()

credentials/endpointcreds/provider.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"context"
3737
"fmt"
3838
"net/http"
39+
"strings"
3940

4041
"github.com/aws/aws-sdk-go-v2/aws"
4142
"github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client"
@@ -81,7 +82,37 @@ type Options struct {
8182

8283
// Optional authorization token value if set will be used as the value of
8384
// the Authorization header of the endpoint credential request.
85+
//
86+
// When constructed from environment, the provider will use the value of
87+
// AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token
88+
//
89+
// Will be overridden if AuthorizationTokenProvider is configured
8490
AuthorizationToken string
91+
92+
// Optional auth provider func to dynamically load the auth token from a file
93+
// everytime a credential is retrieved
94+
//
95+
// When constructed from environment, the provider will read and use the content
96+
// of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable
97+
// as the auth token everytime credentials are retrieved
98+
//
99+
// Will override AuthorizationToken if configured
100+
AuthorizationTokenProvider AuthTokenProvider
101+
}
102+
103+
// AuthTokenProvider defines an interface to dynamically load a value to be passed
104+
// for the Authorization header of a credentials request.
105+
type AuthTokenProvider interface {
106+
GetToken() (string, error)
107+
}
108+
109+
// TokenProviderFunc is a func type implementing AuthTokenProvider interface
110+
// and enables customizing token provider behavior
111+
type TokenProviderFunc func() (string, error)
112+
113+
// GetToken func retrieves auth token according to TokenProviderFunc implementation
114+
func (p TokenProviderFunc) GetToken() (string, error) {
115+
return p()
85116
}
86117

87118
// New returns a credentials Provider for retrieving AWS credentials
@@ -132,5 +163,30 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
132163
}
133164

134165
func (p *Provider) getCredentials(ctx context.Context) (*client.GetCredentialsOutput, error) {
135-
return p.client.GetCredentials(ctx, &client.GetCredentialsInput{AuthorizationToken: p.options.AuthorizationToken})
166+
authToken, err := p.resolveAuthToken()
167+
if err != nil {
168+
return nil, fmt.Errorf("resolve auth token: %v", err)
169+
}
170+
171+
return p.client.GetCredentials(ctx, &client.GetCredentialsInput{
172+
AuthorizationToken: authToken,
173+
})
174+
}
175+
176+
func (p *Provider) resolveAuthToken() (string, error) {
177+
authToken := p.options.AuthorizationToken
178+
179+
var err error
180+
if p.options.AuthorizationTokenProvider != nil {
181+
authToken, err = p.options.AuthorizationTokenProvider.GetToken()
182+
if err != nil {
183+
return "", err
184+
}
185+
}
186+
187+
if strings.ContainsAny(authToken, "\r\n") {
188+
return "", fmt.Errorf("authorization token contains invalid newline sequence")
189+
}
190+
191+
return authToken, nil
136192
}

credentials/endpointcreds/provider_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,90 @@ func TestRetrieveStaticCredentials(t *testing.T) {
108108
}
109109
}
110110

111+
func TestAuthTokenProvider(t *testing.T) {
112+
cases := map[string]struct {
113+
AuthToken string
114+
AuthTokenProvider endpointcreds.AuthTokenProvider
115+
ExpectAuthToken string
116+
ExpectError bool
117+
}{
118+
"AuthToken": {
119+
AuthToken: "Basic abc123",
120+
ExpectAuthToken: "Basic abc123",
121+
},
122+
"AuthFileToken": {
123+
AuthToken: "Basic abc123",
124+
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
125+
return "Hello %20world", nil
126+
}),
127+
ExpectAuthToken: "Hello %20world",
128+
},
129+
"RetrieveFileTokenError": {
130+
AuthToken: "Basic abc123",
131+
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
132+
return "", fmt.Errorf("test error")
133+
}),
134+
ExpectAuthToken: "Hello %20world",
135+
ExpectError: true,
136+
},
137+
}
138+
139+
for name, c := range cases {
140+
t.Run(name, func(t *testing.T) {
141+
orig := sdk.NowTime
142+
defer func() { sdk.NowTime = orig }()
143+
144+
var actualToken string
145+
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
146+
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
147+
actualToken = r.Header["Authorization"][0]
148+
return &http.Response{
149+
StatusCode: 200,
150+
Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
151+
"AccessKeyID": "AKID",
152+
"SecretAccessKey": "SECRET"
153+
}`))),
154+
}, nil
155+
})
156+
o.AuthorizationToken = c.AuthToken
157+
o.AuthorizationTokenProvider = c.AuthTokenProvider
158+
})
159+
creds, err := p.Retrieve(context.Background())
160+
161+
if err != nil && !c.ExpectError {
162+
t.Errorf("expect no error, got %v", err)
163+
} else if err == nil && c.ExpectError {
164+
t.Errorf("expect error, got nil")
165+
}
166+
167+
if c.ExpectError {
168+
return
169+
}
170+
171+
if e, a := "AKID", creds.AccessKeyID; e != a {
172+
t.Errorf("expect %v, got %v", e, a)
173+
}
174+
if e, a := "SECRET", creds.SecretAccessKey; e != a {
175+
t.Errorf("expect %v, got %v", e, a)
176+
}
177+
if v := creds.SessionToken; len(v) != 0 {
178+
t.Errorf("expect empty, got %v", v)
179+
}
180+
if e, a := c.ExpectAuthToken, actualToken; e != a {
181+
t.Errorf("Expect %v, got %v", e, a)
182+
}
183+
184+
sdk.NowTime = func() time.Time {
185+
return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
186+
}
187+
188+
if creds.Expired() {
189+
t.Errorf("expect not to be expired")
190+
}
191+
})
192+
}
193+
}
194+
111195
func TestFailedRetrieveCredentials(t *testing.T) {
112196
p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
113197
o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {

0 commit comments

Comments
 (0)