Skip to content

Commit 5284066

Browse files
authored
fix(auth): default gRPC token type to Bearer if not set (#9800)
As documented on auth.Token.Type, if the value of Type is "" it should be treated as a Bearer token. Added a similar helper method as we have in the httptransport package to default this.
1 parent da245fa commit 5284066

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

auth/grpctransport/dial_socketopt_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestDialWithDirectPathEnabled(t *testing.T) {
109109

110110
pool, err := Dial(ctx, true, &Options{
111111
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
112-
TokenProvider: staticTP("hey"),
112+
TokenProvider: &staticTP{tok: &auth.Token{Value: "hey"}},
113113
}),
114114
GRPCDialOpts: []grpc.DialOption{userDialer},
115115
Endpoint: "example.google.com:443",

auth/grpctransport/grpctransport.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,15 +287,24 @@ func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri ..
287287
return nil, fmt.Errorf("unable to transfer credentials PerRPCCredentials: %v", err)
288288
}
289289
}
290-
metadata := map[string]string{
291-
"authorization": token.Type + " " + token.Value,
292-
}
290+
metadata := make(map[string]string, len(c.metadata)+1)
291+
setAuthMetadata(token, metadata)
293292
for k, v := range c.metadata {
294293
metadata[k] = v
295294
}
296295
return metadata, nil
297296
}
298297

298+
// setAuthMetadata uses the provided token to set the Authorization metadata.
299+
// If the token.Type is empty, the type is assumed to be Bearer.
300+
func setAuthMetadata(token *auth.Token, m map[string]string) {
301+
typ := token.Type
302+
if typ == "" {
303+
typ = internal.TokenTypeBearer
304+
}
305+
m["authorization"] = typ + " " + token.Value
306+
}
307+
299308
func (c *grpcCredentialsProvider) RequireTransportSecurity() bool {
300309
return c.secure
301310
}

auth/grpctransport/grpctransport_test.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package grpctransport
1717
import (
1818
"context"
1919
"errors"
20+
"log"
2021
"net"
2122
"testing"
2223

@@ -83,7 +84,7 @@ func TestDial_FailsValidation(t *testing.T) {
8384
opts: &Options{
8485
DisableAuthentication: true,
8586
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
86-
TokenProvider: staticTP("fakeToken"),
87+
TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
8788
}),
8889
},
8990
},
@@ -272,6 +273,44 @@ func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) {
272273
}
273274
}
274275

276+
func TestGrpcCredentialsProvider_TokenType(t *testing.T) {
277+
tests := []struct {
278+
name string
279+
tok *auth.Token
280+
want string
281+
}{
282+
{
283+
name: "type set",
284+
tok: &auth.Token{
285+
Value: "token",
286+
Type: "Basic",
287+
},
288+
want: "Basic token",
289+
},
290+
{
291+
name: "type set",
292+
tok: &auth.Token{
293+
Value: "token",
294+
},
295+
want: "Bearer token",
296+
},
297+
}
298+
for _, tc := range tests {
299+
cp := grpcCredentialsProvider{
300+
creds: &auth.Credentials{
301+
TokenProvider: &staticTP{tok: tc.tok},
302+
},
303+
}
304+
m, err := cp.GetRequestMetadata(context.Background(), "")
305+
if err != nil {
306+
log.Fatalf("cp.GetRequestMetadata() = %v, want nil", err)
307+
}
308+
if got := m["authorization"]; got != tc.want {
309+
t.Fatalf("got %q, want %q", got, tc.want)
310+
}
311+
}
312+
}
313+
275314
func TestNewClient_DetectedServiceAccount(t *testing.T) {
276315
testQuota := "testquota"
277316
wantHeader := "bar"
@@ -329,12 +368,12 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) {
329368
}
330369
}
331370

332-
type staticTP string
371+
type staticTP struct {
372+
tok *auth.Token
373+
}
333374

334-
func (tp staticTP) Token(context.Context) (*auth.Token, error) {
335-
return &auth.Token{
336-
Value: string(tp),
337-
}, nil
375+
func (tp *staticTP) Token(context.Context) (*auth.Token, error) {
376+
return tp.tok, nil
338377
}
339378

340379
type fakeEchoService struct {

0 commit comments

Comments
 (0)