Skip to content

Commit 4ce7bbb

Browse files
quartzmogopherbot
authored andcommitted
google: add Credentials.GetUniverseDomain with GCE MDS support
* Deprecate Credentials.UniverseDomain Change-Id: I1cbc842fbfce35540c8dff99fec09e036b9e2cdf Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554215 TryBot-Result: Gopher Robot <[email protected]> Run-TryBot: Cody Oss <[email protected]> Auto-Submit: Cody Oss <[email protected]> Reviewed-by: Cody Oss <[email protected]> Reviewed-by: Viacheslav Rostovtsev <[email protected]>
1 parent 1e6999b commit 4ce7bbb

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

google/default.go

+58
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"os"
1313
"path/filepath"
1414
"runtime"
15+
"sync"
1516
"time"
1617

1718
"cloud.google.com/go/compute/metadata"
@@ -41,19 +42,76 @@ type Credentials struct {
4142
// running on Google Cloud Platform.
4243
JSON []byte
4344

45+
udMu sync.Mutex // guards universeDomain
4446
// universeDomain is the default service domain for a given Cloud universe.
4547
universeDomain string
4648
}
4749

4850
// UniverseDomain returns the default service domain for a given Cloud universe.
51+
//
4952
// The default value is "googleapis.com".
53+
//
54+
// Deprecated: Use instead (*Credentials).GetUniverseDomain(), which supports
55+
// obtaining the universe domain when authenticating via the GCE metadata server.
56+
// Unlike GetUniverseDomain, this method, UniverseDomain, will always return the
57+
// default value when authenticating via the GCE metadata server.
58+
// See also [The attached service account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
5059
func (c *Credentials) UniverseDomain() string {
5160
if c.universeDomain == "" {
5261
return universeDomainDefault
5362
}
5463
return c.universeDomain
5564
}
5665

66+
// GetUniverseDomain returns the default service domain for a given Cloud
67+
// universe.
68+
//
69+
// The default value is "googleapis.com".
70+
//
71+
// It obtains the universe domain from the attached service account on GCE when
72+
// authenticating via the GCE metadata server. See also [The attached service
73+
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
74+
// If the GCE metadata server returns a 404 error, the default value is
75+
// returned. If the GCE metadata server returns an error other than 404, the
76+
// error is returned.
77+
func (c *Credentials) GetUniverseDomain() (string, error) {
78+
c.udMu.Lock()
79+
defer c.udMu.Unlock()
80+
if c.universeDomain == "" && metadata.OnGCE() {
81+
// If we're on Google Compute Engine, an App Engine standard second
82+
// generation runtime, or App Engine flexible, use the metadata server.
83+
err := c.computeUniverseDomain()
84+
if err != nil {
85+
return "", err
86+
}
87+
}
88+
// If not on Google Compute Engine, or in case of any non-error path in
89+
// computeUniverseDomain that did not set universeDomain, set the default
90+
// universe domain.
91+
if c.universeDomain == "" {
92+
c.universeDomain = universeDomainDefault
93+
}
94+
return c.universeDomain, nil
95+
}
96+
97+
// computeUniverseDomain fetches the default service domain for a given Cloud
98+
// universe from Google Compute Engine (GCE)'s metadata server. It's only valid
99+
// to use this method if your program is running on a GCE instance.
100+
func (c *Credentials) computeUniverseDomain() error {
101+
var err error
102+
c.universeDomain, err = metadata.Get("universe/universe_domain")
103+
if err != nil {
104+
if _, ok := err.(metadata.NotDefinedError); ok {
105+
// http.StatusNotFound (404)
106+
c.universeDomain = universeDomainDefault
107+
return nil
108+
} else {
109+
return err
110+
}
111+
}
112+
return nil
113+
}
114+
57115
// DefaultCredentials is the old name of Credentials.
58116
//
59117
// Deprecated: use Credentials instead.

google/default_test.go

+95
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ package google
66

77
import (
88
"context"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
912
"testing"
1013
)
1114

@@ -74,6 +77,9 @@ func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
7477
if want := "googleapis.com"; creds.UniverseDomain() != want {
7578
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
7679
}
80+
if want := "googleapis.com"; creds.UniverseDomain() != want {
81+
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
82+
}
7783
}
7884

7985
func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
@@ -94,6 +100,9 @@ func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
94100
if creds.UniverseDomain() != universeDomain2 {
95101
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
96102
}
103+
if creds.UniverseDomain() != universeDomain2 {
104+
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
105+
}
97106
}
98107

99108
func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
@@ -113,6 +122,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
113122
if creds.UniverseDomain() != universeDomain {
114123
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain)
115124
}
125+
got, err := creds.GetUniverseDomain()
126+
if err != nil {
127+
t.Fatal(err)
128+
}
129+
if got != universeDomain {
130+
t.Fatalf("got %q, want %q", got, universeDomain)
131+
}
116132
}
117133

118134
func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t *testing.T) {
@@ -133,6 +149,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t
133149
if creds.UniverseDomain() != universeDomain2 {
134150
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
135151
}
152+
got, err := creds.GetUniverseDomain()
153+
if err != nil {
154+
t.Fatal(err)
155+
}
156+
if got != universeDomain2 {
157+
t.Fatalf("got %q, want %q", got, universeDomain2)
158+
}
136159
}
137160

138161
func TestCredentialsFromJSONWithParams_User(t *testing.T) {
@@ -149,6 +172,13 @@ func TestCredentialsFromJSONWithParams_User(t *testing.T) {
149172
if want := "googleapis.com"; creds.UniverseDomain() != want {
150173
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
151174
}
175+
got, err := creds.GetUniverseDomain()
176+
if err != nil {
177+
t.Fatal(err)
178+
}
179+
if want := "googleapis.com"; got != want {
180+
t.Fatalf("got %q, want %q", got, want)
181+
}
152182
}
153183

154184
func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) {
@@ -166,6 +196,13 @@ func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T)
166196
if want := "googleapis.com"; creds.UniverseDomain() != want {
167197
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
168198
}
199+
got, err := creds.GetUniverseDomain()
200+
if err != nil {
201+
t.Fatal(err)
202+
}
203+
if want := "googleapis.com"; got != want {
204+
t.Fatalf("got %q, want %q", got, want)
205+
}
169206
}
170207

171208
func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
@@ -182,6 +219,13 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
182219
if want := "googleapis.com"; creds.UniverseDomain() != want {
183220
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
184221
}
222+
got, err := creds.GetUniverseDomain()
223+
if err != nil {
224+
t.Fatal(err)
225+
}
226+
if want := "googleapis.com"; got != want {
227+
t.Fatalf("got %q, want %q", got, want)
228+
}
185229
}
186230

187231
func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain(t *testing.T) {
@@ -199,4 +243,55 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain
199243
if want := "googleapis.com"; creds.UniverseDomain() != want {
200244
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
201245
}
246+
got, err := creds.GetUniverseDomain()
247+
if err != nil {
248+
t.Fatal(err)
249+
}
250+
if want := "googleapis.com"; got != want {
251+
t.Fatalf("got %q, want %q", got, want)
252+
}
253+
}
254+
255+
func TestComputeUniverseDomain(t *testing.T) {
256+
universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
257+
universeDomainResponseBody := "example.com"
258+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
259+
if r.URL.Path != universeDomainPath {
260+
t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath)
261+
}
262+
w.Write([]byte(universeDomainResponseBody))
263+
}))
264+
defer s.Close()
265+
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
266+
267+
scope := "https://www.googleapis.com/auth/cloud-platform"
268+
params := CredentialsParams{
269+
Scopes: []string{scope},
270+
}
271+
// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
272+
creds := &Credentials{
273+
ProjectID: "fake_project",
274+
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
275+
universeDomain: params.UniverseDomain, // empty
276+
}
277+
c := make(chan bool)
278+
go func() {
279+
got, err := creds.GetUniverseDomain() // First conflicting access.
280+
if err != nil {
281+
t.Error(err)
282+
}
283+
if want := universeDomainResponseBody; got != want {
284+
t.Errorf("got %q, want %q", got, want)
285+
}
286+
c <- true
287+
}()
288+
got, err := creds.GetUniverseDomain() // Second conflicting access.
289+
<-c
290+
if err != nil {
291+
t.Error(err)
292+
}
293+
if want := universeDomainResponseBody; got != want {
294+
t.Errorf("got %q, want %q", got, want)
295+
}
296+
202297
}

google/google_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package google
66

77
import (
8+
"net/http"
9+
"net/http/httptest"
810
"strings"
911
"testing"
1012
)
@@ -137,3 +139,21 @@ func TestJWTConfigFromJSONNoAudience(t *testing.T) {
137139
t.Errorf("Audience = %q; want %q", got, want)
138140
}
139141
}
142+
143+
func TestComputeTokenSource(t *testing.T) {
144+
tokenPath := "/computeMetadata/v1/instance/service-accounts/default/token"
145+
tokenResponseBody := `{"access_token":"Sample.Access.Token","token_type":"Bearer","expires_in":3600}`
146+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
147+
if r.URL.Path != tokenPath {
148+
t.Errorf("got %s, want %s", r.URL.Path, tokenPath)
149+
}
150+
w.Write([]byte(tokenResponseBody))
151+
}))
152+
defer s.Close()
153+
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
154+
ts := ComputeTokenSource("")
155+
_, err := ts.Token()
156+
if err != nil {
157+
t.Errorf("ts.Token() = %v", err)
158+
}
159+
}

0 commit comments

Comments
 (0)