Skip to content

Commit 99a1aca

Browse files
committed
add User-Agent header to oauth2 requests
Signed-off-by: clayton-gonsalves <[email protected]>
1 parent 26d4974 commit 99a1aca

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

config/http_config.go

+28
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ type OAuth2 struct {
224224
ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
225225
// TLSConfig is used to connect to the token URL.
226226
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
227+
// UserAgent is used to set a custom User-Agent http header while making the oauth request.
228+
UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
227229
}
228230

229231
// SetDirectory joins any relative file paths with dir.
@@ -681,6 +683,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
681683
}
682684
}
683685

686+
if rt.config.UserAgent != "" {
687+
t = NewUserAgentRoundTripper(rt.config.UserAgent, t)
688+
}
689+
684690
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
685691
tokenSource := config.TokenSource(ctx)
686692

@@ -911,6 +917,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() {
911917
}
912918
}
913919

920+
type userAgentRoundTripper struct {
921+
userAgent string
922+
rt http.RoundTripper
923+
}
924+
925+
// NewUserAgentRoundTripper adds the user agent every request header.
926+
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
927+
return &userAgentRoundTripper{userAgent, rt}
928+
}
929+
930+
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
931+
req = cloneRequest(req)
932+
req.Header.Set("User-Agent", rt.userAgent)
933+
return rt.rt.RoundTrip(req)
934+
}
935+
936+
func (rt *userAgentRoundTripper) CloseIdleConnections() {
937+
if ci, ok := rt.rt.(closeIdler); ok {
938+
ci.CloseIdleConnections()
939+
}
940+
}
941+
914942
func (c HTTPClientConfig) String() string {
915943
b, err := yaml.Marshal(c)
916944
if err != nil {

config/http_config_test.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,12 @@ type oauth2TestServerResponse struct {
11831183

11841184
func TestOAuth2(t *testing.T) {
11851185
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1186+
if r.URL.Path == "/token" {
1187+
if r.Header.Get("User-Agent") != "myuseragent" {
1188+
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
1189+
}
1190+
}
1191+
11861192
res, _ := json.Marshal(oauth2TestServerResponse{
11871193
AccessToken: "12345",
11881194
TokenType: "Bearer",
@@ -1198,7 +1204,8 @@ client_secret: 2
11981204
scopes:
11991205
- A
12001206
- B
1201-
token_url: %s
1207+
token_url: %s/token
1208+
user_agent: myuseragent
12021209
endpoint_params:
12031210
hi: hello
12041211
`, ts.URL)
@@ -1207,7 +1214,8 @@ endpoint_params:
12071214
ClientSecret: "2",
12081215
Scopes: []string{"A", "B"},
12091216
EndpointParams: map[string]string{"hi": "hello"},
1210-
TokenURL: ts.URL,
1217+
TokenURL: fmt.Sprintf("%s/token", ts.URL),
1218+
UserAgent: "myuseragent",
12111219
}
12121220

12131221
var unmarshalledConfig OAuth2
@@ -1488,3 +1496,10 @@ func TestOAuth2Proxy(t *testing.T) {
14881496
t.Errorf("Error loading OAuth2 client config: %v", err)
14891497
}
14901498
}
1499+
1500+
func TestOAuth2UserAgent(t *testing.T) {
1501+
_, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml")
1502+
if err != nil {
1503+
t.Errorf("Error loading OAuth2 client config: %v", err)
1504+
}
1505+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
oauth2:
2+
client_id: "myclient"
3+
client_secret: "mysecret"
4+
token_url: "http://auth"
5+
user_agent: "myuseragent"

0 commit comments

Comments
 (0)