Skip to content

Commit 183147b

Browse files
authored
Deny context values from flowing across disjoint requests (#21863)
This prevents things like custom HTTP headers from showing up in an authentication request. Removed duplicate import from a test.
1 parent 9ef98fe commit 183147b

8 files changed

+52
-7
lines changed

sdk/azcore/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* Passing a `nil` credential value will no longer cause a panic. Instead, the authentication is skipped.
1818
* Calling `Error` on a zero-value `azcore.ResponseError` will no longer panic.
1919
* Fixed an issue in `fake.PagerResponder[T]` that would cause a trailing error to be omitted when iterating over pages.
20+
* Context values created by `azcore` will no longer flow across disjoint HTTP requests.
2021

2122
### Other Changes
2223

sdk/azcore/arm/runtime/policy_bearer_token_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1616
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
1717
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
18-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1918
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
2019
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
2120
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
@@ -292,7 +291,7 @@ func TestBearerTokenPolicyRequiresHTTPS(t *testing.T) {
292291
srv, close := mock.NewServer()
293292
defer close()
294293
b := NewBearerTokenPolicy(mockCredential{}, nil)
295-
pl := newTestPipeline(&policy.ClientOptions{Transport: srv, PerRetryPolicies: []policy.Policy{b}})
294+
pl := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, PerRetryPolicies: []azpolicy.Policy{b}})
296295
req, err := runtime.NewRequest(context.Background(), "GET", srv.URL())
297296
require.NoError(t, err)
298297
_, err = pl.Do(req)

sdk/azcore/arm/runtime/policy_register_rp.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,13 @@ func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error)
126126
u: r.endpoint,
127127
subID: subID,
128128
}
129-
if _, err = rpOps.Register(req.Raw().Context(), rp); err != nil {
129+
if _, err = rpOps.Register(&shared.ContextWithDeniedValues{Context: req.Raw().Context()}, rp); err != nil {
130130
logRegistrationExit(err)
131131
return resp, err
132132
}
133+
133134
// RP was registered, however we need to wait for the registration to complete
134-
pollCtx, pollCancel := context.WithTimeout(req.Raw().Context(), r.options.PollingDuration)
135+
pollCtx, pollCancel := context.WithTimeout(&shared.ContextWithDeniedValues{Context: req.Raw().Context()}, r.options.PollingDuration)
135136
var lastRegState string
136137
for {
137138
// get the current registration state

sdk/azcore/internal/shared/shared.go

+22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import (
1818
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
1919
)
2020

21+
// NOTE: when adding a new context key type, it likely needs to be
22+
// added to the deny-list of key types in ContextWithDeniedValues
23+
2124
// CtxWithHTTPHeaderKey is used as a context key for adding/retrieving http.Header.
2225
type CtxWithHTTPHeaderKey struct{}
2326

@@ -110,6 +113,25 @@ func ExtractModuleName(clientName string) (string, string, error) {
110113
return matches[3], matches[2], nil
111114
}
112115

116+
// ContextWithDeniedValues wraps an existing [context.Context], denying access to certain context values.
117+
// Pipeline policies that create new requests to be sent down their own pipeline MUST wrap the caller's
118+
// context with an instance of this type. This is to prevent context values from flowing across disjoint
119+
// requests which can have unintended side-effects.
120+
type ContextWithDeniedValues struct {
121+
context.Context
122+
}
123+
124+
// Value implements part of the [context.Context] interface.
125+
// It acts as a deny-list for certain context keys.
126+
func (c *ContextWithDeniedValues) Value(key any) any {
127+
switch key.(type) {
128+
case CtxAPINameKey, CtxWithCaptureResponse, CtxWithHTTPHeaderKey, CtxWithRetryOptionsKey, CtxWithTracingTracer:
129+
return nil
130+
default:
131+
return c.Context.Value(key)
132+
}
133+
}
134+
113135
// NonRetriableError marks the specified error as non-retriable.
114136
func NonRetriableError(err error) error {
115137
return &nonRetriableError{err}

sdk/azcore/internal/shared/shared_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,23 @@ func TestExtractModuleName(t *testing.T) {
136136
require.Empty(t, mod)
137137
require.Empty(t, client)
138138
}
139+
140+
func TestContextWithDeniedValues(t *testing.T) {
141+
type testKey struct{}
142+
const value = "value"
143+
144+
ctx := context.WithValue(context.Background(), testKey{}, value)
145+
ctx = context.WithValue(ctx, CtxAPINameKey{}, value)
146+
ctx = context.WithValue(ctx, CtxWithCaptureResponse{}, value)
147+
ctx = context.WithValue(ctx, CtxWithHTTPHeaderKey{}, value)
148+
ctx = context.WithValue(ctx, CtxWithRetryOptionsKey{}, value)
149+
ctx = context.WithValue(ctx, CtxWithTracingTracer{}, value)
150+
ctx = &ContextWithDeniedValues{Context: ctx}
151+
152+
require.Nil(t, ctx.Value(CtxAPINameKey{}))
153+
require.Nil(t, ctx.Value(CtxWithCaptureResponse{}))
154+
require.Nil(t, ctx.Value(CtxWithHTTPHeaderKey{}))
155+
require.Nil(t, ctx.Value(CtxWithRetryOptionsKey{}))
156+
require.Nil(t, ctx.Value(CtxWithTracingTracer{}))
157+
require.NotNil(t, ctx.Value(testKey{}))
158+
}

sdk/azcore/runtime/policy_bearer_token.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type acquiringResourceState struct {
3434
// acquire acquires or updates the resource; only one
3535
// thread/goroutine at a time ever calls this function
3636
func acquire(state acquiringResourceState) (newResource exported.AccessToken, newExpiration time.Time, err error) {
37-
tk, err := state.p.cred.GetToken(state.req.Raw().Context(), state.tro)
37+
tk, err := state.p.cred.GetToken(&shared.ContextWithDeniedValues{Context: state.req.Raw().Context()}, state.tro)
3838
if err != nil {
3939
return exported.AccessToken{}, time.Time{}, err
4040
}

sdk/azcore/runtime/policy_http_trace.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ func StartSpan(ctx context.Context, name string, tracer tracing.Tracer, options
110110
if !tracer.Enabled() {
111111
return ctx, func(err error) {}
112112
}
113+
114+
// we MUST propagate the active tracer before returning so that the trace policy can access it
115+
ctx = context.WithValue(ctx, shared.CtxWithTracingTracer{}, tracer)
116+
113117
const newSpanKind = tracing.SpanKindInternal
114118
if activeSpan := ctx.Value(ctxActiveSpan{}); activeSpan != nil {
115119
// per the design guidelines, if a SDK method Foo() calls SDK method Bar(),
@@ -125,7 +129,6 @@ func StartSpan(ctx context.Context, name string, tracer tracing.Tracer, options
125129
ctx, span := tracer.Start(ctx, name, &tracing.SpanOptions{
126130
Kind: newSpanKind,
127131
})
128-
ctx = context.WithValue(ctx, shared.CtxWithTracingTracer{}, tracer)
129132
ctx = context.WithValue(ctx, ctxActiveSpan{}, newSpanKind)
130133
return ctx, func(err error) {
131134
if err != nil {

sdk/azcore/runtime/policy_http_trace_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ func TestStartSpansDontNest(t *testing.T) {
205205

206206
barMethod := func(ctx context.Context) {
207207
ourCtx, endSpan := StartSpan(ctx, "BarMethod", tr, nil)
208-
require.Same(t, ctx, ourCtx)
209208
defer endSpan(nil)
210209
req, err := exported.NewRequest(ourCtx, http.MethodGet, srv.URL()+"/bar")
211210
require.NoError(t, err)

0 commit comments

Comments
 (0)