Skip to content

Commit 72c7f6b

Browse files
authored
feat(auth/httptransport): add ability to customize transport (#10023)
This was a known limitation of the current implementation that had been forgotten to be implemented. See removal of todo in related PR. Updates: #9812 Updates: #9814 Related: googleapis/google-api-go-client#2541
1 parent 50994e7 commit 72c7f6b

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

auth/httptransport/httptransport.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ type Options struct {
4444
// Headers are extra HTTP headers that will be appended to every outgoing
4545
// request.
4646
Headers http.Header
47+
// BaseRoundTripper overrides the base transport used for serving requests.
48+
// If specified ClientCertProvider is ignored.
49+
BaseRoundTripper http.RoundTripper
4750
// Endpoint overrides the default endpoint to be used for a service.
4851
Endpoint string
4952
// APIKey specifies an API key to be used as the basis for authentication.
@@ -181,7 +184,11 @@ func NewClient(opts *Options) (*http.Client, error) {
181184
if err != nil {
182185
return nil, err
183186
}
184-
trans, err := newTransport(defaultBaseTransport(clientCertProvider, dialTLSContext), opts)
187+
baseRoundTripper := opts.BaseRoundTripper
188+
if baseRoundTripper == nil {
189+
baseRoundTripper = defaultBaseTransport(clientCertProvider, dialTLSContext)
190+
}
191+
trans, err := newTransport(baseRoundTripper, opts)
185192
if err != nil {
186193
return nil, err
187194
}

auth/httptransport/httptransport_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,46 @@ func TestNewClient_APIKey(t *testing.T) {
350350
}
351351
}
352352

353+
func TestNewClient_BaseRoundTripper(t *testing.T) {
354+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
355+
got := r.Header.Get("Foo")
356+
if want := "foo"; got != want {
357+
t.Errorf("got %q, want %q", got, want)
358+
}
359+
got = r.Header.Get("Bar")
360+
if want := "bar"; got != want {
361+
t.Errorf("got %q, want %q", got, want)
362+
}
363+
}))
364+
defer ts.Close()
365+
client, err := NewClient(&Options{
366+
BaseRoundTripper: &rt{key: "Bar", value: "bar"},
367+
Headers: http.Header{"Foo": []string{"foo"}},
368+
APIKey: "key",
369+
})
370+
if err != nil {
371+
t.Fatalf("NewClient() = %v", err)
372+
}
373+
if _, err := client.Get(ts.URL); err != nil {
374+
t.Fatalf("client.Get() = %v", err)
375+
}
376+
}
377+
353378
type staticTP string
354379

355380
func (tp staticTP) Token(context.Context) (*auth.Token, error) {
356381
return &auth.Token{
357382
Value: string(tp),
358383
}, nil
359384
}
385+
386+
type rt struct {
387+
key string
388+
value string
389+
}
390+
391+
func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) {
392+
req2 := req.Clone(req.Context())
393+
req2.Header.Add(r.key, r.value)
394+
return http.DefaultTransport.RoundTrip(req2)
395+
}

0 commit comments

Comments
 (0)