Skip to content

Commit 316097c

Browse files
author
Julien Pivotto
committed
Use WithUserAgent
Signed-off-by: Julien Pivotto <[email protected]>
1 parent 99a1aca commit 316097c

File tree

3 files changed

+64
-28
lines changed

3 files changed

+64
-28
lines changed

config/http_config.go

+18-6
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,6 @@ type OAuth2 struct {
224224
ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
225225
// TLSConfig is used to connect to the token URL.
226226
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
227-
// UserAgent is used to set a custom User-Agent http header while making the oauth request.
228-
UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
229227
}
230228

231229
// SetDirectory joins any relative file paths with dir.
@@ -374,6 +372,7 @@ type httpClientOptions struct {
374372
keepAlivesEnabled bool
375373
http2Enabled bool
376374
idleConnTimeout time.Duration
375+
userAgent string
377376
}
378377

379378
// HTTPClientOption defines an option that can be applied to the HTTP client.
@@ -407,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption {
407406
}
408407
}
409408

409+
// WithIdleConnTimeout allows setting the user agent.
410+
func WithUserAgent(ua string) HTTPClientOption {
411+
return func(opts *httpClientOptions) {
412+
opts.userAgent = ua
413+
}
414+
}
415+
410416
// NewClient returns a http.Client using the specified http.RoundTripper.
411417
func newClient(rt http.RoundTripper) *http.Client {
412418
return &http.Client{Transport: rt}
@@ -499,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
499505
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
500506
}
501507

508+
if opts.userAgent != "" {
509+
rt = NewUserAgentRoundTripper(opts.userAgent, rt)
510+
}
511+
502512
if cfg.OAuth2 != nil {
503-
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
513+
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts)
504514
}
505515
// Return a new configured RoundTripper.
506516
return rt, nil
@@ -621,12 +631,14 @@ type oauth2RoundTripper struct {
621631
next http.RoundTripper
622632
secret string
623633
mtx sync.RWMutex
634+
opts *httpClientOptions
624635
}
625636

626-
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
637+
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
627638
return &oauth2RoundTripper{
628639
config: config,
629640
next: next,
641+
opts: opts,
630642
}
631643
}
632644

@@ -683,8 +695,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
683695
}
684696
}
685697

686-
if rt.config.UserAgent != "" {
687-
t = NewUserAgentRoundTripper(rt.config.UserAgent, t)
698+
if rt.opts.userAgent != "" {
699+
t = NewUserAgentRoundTripper(rt.opts.userAgent, t)
688700
}
689701

690702
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})

config/http_config_test.go

+46-17
Original file line numberDiff line numberDiff line change
@@ -1183,12 +1183,6 @@ type oauth2TestServerResponse struct {
11831183

11841184
func TestOAuth2(t *testing.T) {
11851185
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1186-
if r.URL.Path == "/token" {
1187-
if r.Header.Get("User-Agent") != "myuseragent" {
1188-
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
1189-
}
1190-
}
1191-
11921186
res, _ := json.Marshal(oauth2TestServerResponse{
11931187
AccessToken: "12345",
11941188
TokenType: "Bearer",
@@ -1205,7 +1199,6 @@ scopes:
12051199
- A
12061200
- B
12071201
token_url: %s/token
1208-
user_agent: myuseragent
12091202
endpoint_params:
12101203
hi: hello
12111204
`, ts.URL)
@@ -1215,7 +1208,6 @@ endpoint_params:
12151208
Scopes: []string{"A", "B"},
12161209
EndpointParams: map[string]string{"hi": "hello"},
12171210
TokenURL: fmt.Sprintf("%s/token", ts.URL),
1218-
UserAgent: "myuseragent",
12191211
}
12201212

12211213
var unmarshalledConfig OAuth2
@@ -1227,7 +1219,7 @@ endpoint_params:
12271219
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
12281220
}
12291221

1230-
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
1222+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)
12311223

12321224
client := http.Client{
12331225
Transport: rt,
@@ -1240,6 +1232,50 @@ endpoint_params:
12401232
}
12411233
}
12421234

1235+
func TestOAuth2UserAgent(t *testing.T) {
1236+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1237+
if r.URL.Path != "/" {
1238+
if r.Header.Get("User-Agent") != "myuseragent" {
1239+
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
1240+
}
1241+
}
1242+
1243+
res, _ := json.Marshal(oauth2TestServerResponse{
1244+
AccessToken: "12345",
1245+
TokenType: "Bearer",
1246+
})
1247+
w.Header().Add("Content-Type", "application/json")
1248+
_, _ = w.Write(res)
1249+
}))
1250+
defer ts.Close()
1251+
1252+
config := &OAuth2{
1253+
ClientID: "1",
1254+
ClientSecret: "2",
1255+
Scopes: []string{"A", "B"},
1256+
EndpointParams: map[string]string{"hi": "hello"},
1257+
TokenURL: fmt.Sprintf("%s/token", ts.URL),
1258+
}
1259+
1260+
opts := defaultHTTPClientOptions
1261+
WithUserAgent("myuseragent")(&opts)
1262+
1263+
rt := NewOAuth2RoundTripper(config, http.DefaultTransport, &opts)
1264+
1265+
client := http.Client{
1266+
Transport: rt,
1267+
}
1268+
resp, err := client.Get(ts.URL)
1269+
if err != nil {
1270+
t.Fatal(err)
1271+
}
1272+
1273+
authorization := resp.Request.Header.Get("Authorization")
1274+
if authorization != "Bearer 12345" {
1275+
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
1276+
}
1277+
}
1278+
12431279
func TestOAuth2WithFile(t *testing.T) {
12441280
var expectedAuth *string
12451281
var previousAuth string
@@ -1302,7 +1338,7 @@ endpoint_params:
13021338
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
13031339
}
13041340

1305-
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
1341+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)
13061342

13071343
client := http.Client{
13081344
Transport: rt,
@@ -1496,10 +1532,3 @@ func TestOAuth2Proxy(t *testing.T) {
14961532
t.Errorf("Error loading OAuth2 client config: %v", err)
14971533
}
14981534
}
1499-
1500-
func TestOAuth2UserAgent(t *testing.T) {
1501-
_, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml")
1502-
if err != nil {
1503-
t.Errorf("Error loading OAuth2 client config: %v", err)
1504-
}
1505-
}

config/testdata/http.conf.oauth2-user-agent.good.yml

-5
This file was deleted.

0 commit comments

Comments
 (0)