Skip to content

Commit a79e693

Browse files
authored
feat(auth): add idtoken package (#8580)
1 parent 5feb3ea commit a79e693

12 files changed

+1543
-58
lines changed

auth/idtoken/cache.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package idtoken
16+
17+
import (
18+
"context"
19+
"encoding/json"
20+
"fmt"
21+
"net/http"
22+
"strconv"
23+
"strings"
24+
"sync"
25+
"time"
26+
)
27+
28+
type cachingClient struct {
29+
client *http.Client
30+
31+
// clock optionally specifies a func to return the current time.
32+
// If nil, time.Now is used.
33+
clock func() time.Time
34+
35+
mu sync.Mutex
36+
certs map[string]*cachedResponse
37+
}
38+
39+
func newCachingClient(client *http.Client) *cachingClient {
40+
return &cachingClient{
41+
client: client,
42+
certs: make(map[string]*cachedResponse, 2),
43+
}
44+
}
45+
46+
type cachedResponse struct {
47+
resp *certResponse
48+
exp time.Time
49+
}
50+
51+
func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
52+
if response, ok := c.get(url); ok {
53+
return response, nil
54+
}
55+
req, err := http.NewRequest(http.MethodGet, url, nil)
56+
if err != nil {
57+
return nil, err
58+
}
59+
req = req.WithContext(ctx)
60+
resp, err := c.client.Do(req)
61+
if err != nil {
62+
return nil, err
63+
}
64+
defer resp.Body.Close()
65+
if resp.StatusCode != http.StatusOK {
66+
return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
67+
}
68+
69+
certResp := &certResponse{}
70+
if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
71+
return nil, err
72+
73+
}
74+
c.set(url, certResp, resp.Header)
75+
return certResp, nil
76+
}
77+
78+
func (c *cachingClient) now() time.Time {
79+
if c.clock != nil {
80+
return c.clock()
81+
}
82+
return time.Now()
83+
}
84+
85+
func (c *cachingClient) get(url string) (*certResponse, bool) {
86+
c.mu.Lock()
87+
defer c.mu.Unlock()
88+
cachedResp, ok := c.certs[url]
89+
if !ok {
90+
return nil, false
91+
}
92+
if c.now().After(cachedResp.exp) {
93+
return nil, false
94+
}
95+
return cachedResp.resp, true
96+
}
97+
98+
func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
99+
exp := c.calculateExpireTime(headers)
100+
c.mu.Lock()
101+
c.certs[url] = &cachedResponse{resp: resp, exp: exp}
102+
c.mu.Unlock()
103+
}
104+
105+
// calculateExpireTime will determine the expire time for the cache based on
106+
// HTTP headers. If there is any difficulty reading the headers the fallback is
107+
// to set the cache to expire now.
108+
func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
109+
var maxAge int
110+
cc := strings.Split(headers.Get("cache-control"), ",")
111+
for _, v := range cc {
112+
if strings.Contains(v, "max-age") {
113+
ss := strings.Split(v, "=")
114+
if len(ss) < 2 {
115+
return c.now()
116+
}
117+
ma, err := strconv.Atoi(ss[1])
118+
if err != nil {
119+
return c.now()
120+
}
121+
maxAge = ma
122+
}
123+
}
124+
a := headers.Get("age")
125+
if a == "" {
126+
return c.now().Add(time.Duration(maxAge) * time.Second)
127+
}
128+
age, err := strconv.Atoi(a)
129+
if err != nil {
130+
return c.now()
131+
}
132+
return c.now().Add(time.Duration(maxAge-age) * time.Second)
133+
}

auth/idtoken/cache_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package idtoken
16+
17+
import (
18+
"net/http"
19+
"sync"
20+
"testing"
21+
"time"
22+
)
23+
24+
type fakeClock struct {
25+
mu sync.Mutex
26+
t time.Time
27+
}
28+
29+
func (c *fakeClock) Now() time.Time {
30+
c.mu.Lock()
31+
defer c.mu.Unlock()
32+
return c.t
33+
}
34+
35+
func (c *fakeClock) Sleep(d time.Duration) {
36+
c.mu.Lock()
37+
defer c.mu.Unlock()
38+
c.t = c.t.Add(d)
39+
}
40+
41+
func TestCacheHit(t *testing.T) {
42+
clock := &fakeClock{t: time.Now()}
43+
fakeResp := &certResponse{
44+
Keys: []jwk{
45+
{
46+
Kid: "123",
47+
},
48+
},
49+
}
50+
cache := newCachingClient(nil)
51+
cache.clock = clock.Now
52+
53+
// Cache should be empty
54+
cert, ok := cache.get(googleSACertsURL)
55+
if ok || cert != nil {
56+
t.Fatal("cache for SA certs should be empty")
57+
}
58+
59+
// Add an item, but make it expire now
60+
cache.set(googleSACertsURL, fakeResp, make(http.Header))
61+
clock.Sleep(time.Nanosecond) // it expires when current time is > expiration, not >=
62+
cert, ok = cache.get(googleSACertsURL)
63+
if ok || cert != nil {
64+
t.Fatal("cache for SA certs should be expired")
65+
}
66+
67+
// Add an item that expires in 1 seconds
68+
h := make(http.Header)
69+
h.Set("age", "0")
70+
h.Set("cache-control", "public, max-age=1, must-revalidate, no-transform")
71+
cache.set(googleSACertsURL, fakeResp, h)
72+
cert, ok = cache.get(googleSACertsURL)
73+
if !ok || cert == nil || cert.Keys[0].Kid != "123" {
74+
t.Fatal("cache for SA certs have a resp")
75+
}
76+
// Wait
77+
clock.Sleep(2 * time.Second)
78+
cert, ok = cache.get(googleSACertsURL)
79+
if ok || cert != nil {
80+
t.Fatal("cache for SA certs should be expired")
81+
}
82+
}

auth/idtoken/compute.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package idtoken
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"net/url"
21+
"time"
22+
23+
"cloud.google.com/go/auth"
24+
"cloud.google.com/go/auth/internal"
25+
"cloud.google.com/go/compute/metadata"
26+
)
27+
28+
const identitySuffix = "instance/service-accounts/default/identity"
29+
30+
// computeTokenProvider checks if this code is being run on GCE. If it is, it
31+
// will use the metadata service to build a TokenProvider that fetches ID
32+
// tokens.
33+
func computeTokenProvider(opts *Options) (auth.TokenProvider, error) {
34+
if opts.CustomClaims != nil {
35+
return nil, fmt.Errorf("idtoken: Options.CustomClaims can't be used with the metadata service, please provide a service account if you would like to use this feature")
36+
}
37+
tp := computeIDTokenProvider{
38+
audience: opts.Audience,
39+
format: opts.ComputeTokenFormat,
40+
client: *metadata.NewClient(opts.client()),
41+
}
42+
return auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
43+
ExpireEarly: 5 * time.Minute,
44+
}), nil
45+
}
46+
47+
type computeIDTokenProvider struct {
48+
audience string
49+
format ComputeTokenFormat
50+
client metadata.Client
51+
}
52+
53+
func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
54+
v := url.Values{}
55+
v.Set("audience", c.audience)
56+
if c.format != ComputeTokenFormatStandard {
57+
v.Set("format", "full")
58+
}
59+
if c.format == ComputeTokenFormatFullWithLicense {
60+
v.Set("licenses", "TRUE")
61+
}
62+
urlSuffix := identitySuffix + "?" + v.Encode()
63+
res, err := c.client.Get(urlSuffix)
64+
if err != nil {
65+
return nil, err
66+
}
67+
if res == "" {
68+
return nil, fmt.Errorf("idtoken: invalid empty response from metadata service")
69+
}
70+
return &auth.Token{
71+
Value: res,
72+
Type: internal.TokenTypeBearer,
73+
// Compute tokens are valid for one hour:
74+
// https://cloud.google.com/iam/docs/create-short-lived-credentials-direct#create-id
75+
Expiry: time.Now().Add(1 * time.Hour),
76+
}, nil
77+
}

auth/idtoken/compute_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package idtoken
16+
17+
import (
18+
"context"
19+
"net/http"
20+
"net/http/httptest"
21+
"strings"
22+
"testing"
23+
)
24+
25+
const metadataHostEnv = "GCE_METADATA_HOST"
26+
27+
func TestComputeTokenSource(t *testing.T) {
28+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29+
if !strings.Contains(r.URL.Path, identitySuffix) {
30+
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
31+
}
32+
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
33+
t.Errorf("got %q, want %q", got, want)
34+
}
35+
if got, want := r.URL.Query().Get("format"), "full"; got != want {
36+
t.Errorf("got %q, want %q", got, want)
37+
}
38+
if got, want := r.URL.Query().Get("licenses"), "TRUE"; got != want {
39+
t.Errorf("got %q, want %q", got, want)
40+
}
41+
w.Write([]byte(`fake_token`))
42+
}))
43+
defer ts.Close()
44+
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
45+
tp, err := computeTokenProvider(&Options{
46+
Audience: "aud",
47+
ComputeTokenFormat: ComputeTokenFormatFullWithLicense,
48+
})
49+
if err != nil {
50+
t.Fatalf("computeTokenProvider() = %v", err)
51+
}
52+
tok, err := tp.Token(context.Background())
53+
if err != nil {
54+
t.Fatalf("tp.Token() = %v", err)
55+
}
56+
if want := "fake_token"; tok.Value != want {
57+
t.Errorf("got %q, want %q", tok.Value, want)
58+
}
59+
}
60+
61+
func TestComputeTokenSource_Standard(t *testing.T) {
62+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63+
if !strings.Contains(r.URL.Path, identitySuffix) {
64+
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
65+
}
66+
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
67+
t.Errorf("got %q, want %q", got, want)
68+
}
69+
if got, want := r.URL.Query().Get("format"), ""; got != want {
70+
t.Errorf("got %q, want %q", got, want)
71+
}
72+
if got, want := r.URL.Query().Get("licenses"), ""; got != want {
73+
t.Errorf("got %q, want %q", got, want)
74+
}
75+
w.Write([]byte(`fake_token`))
76+
}))
77+
defer ts.Close()
78+
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
79+
tp, err := computeTokenProvider(&Options{
80+
Audience: "aud",
81+
ComputeTokenFormat: ComputeTokenFormatStandard,
82+
})
83+
if err != nil {
84+
t.Fatalf("computeTokenProvider() = %v", err)
85+
}
86+
tok, err := tp.Token(context.Background())
87+
if err != nil {
88+
t.Fatalf("tp.Token() = %v", err)
89+
}
90+
if want := "fake_token"; tok.Value != want {
91+
t.Errorf("got %q, want %q", tok.Value, want)
92+
}
93+
}
94+
95+
func TestComputeTokenSource_Invalid(t *testing.T) {
96+
if _, err := computeTokenProvider(&Options{
97+
Audience: "aud",
98+
CustomClaims: map[string]interface{}{"foo": "bar"},
99+
}); err == nil {
100+
t.Fatal("computeTokenProvider() = nil, expected non-nil error", err)
101+
}
102+
}

0 commit comments

Comments
 (0)