Skip to content

Commit bb1fc2e

Browse files
Bence VidositsBence Vidosits
Bence Vidosits
and
Bence Vidosits
authored
fix Request.Context() checks (#3512)
Co-authored-by: Bence Vidosits <[email protected]>
1 parent 2d4bbec commit bb1fc2e

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

context.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,25 +1180,32 @@ func (c *Context) SetAccepted(formats ...string) {
11801180
/***** GOLANG.ORG/X/NET/CONTEXT *****/
11811181
/************************************/
11821182

1183+
// hasRequestContext returns whether c.Request has Context and fallback.
1184+
func (c *Context) hasRequestContext() bool {
1185+
hasFallback := c.engine != nil && c.engine.ContextWithFallback
1186+
hasRequestContext := c.Request != nil && c.Request.Context() != nil
1187+
return hasFallback && hasRequestContext
1188+
}
1189+
11831190
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
11841191
func (c *Context) Deadline() (deadline time.Time, ok bool) {
1185-
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
1192+
if !c.hasRequestContext() {
11861193
return
11871194
}
11881195
return c.Request.Context().Deadline()
11891196
}
11901197

11911198
// Done returns nil (chan which will wait forever) when c.Request has no Context.
11921199
func (c *Context) Done() <-chan struct{} {
1193-
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
1200+
if !c.hasRequestContext() {
11941201
return nil
11951202
}
11961203
return c.Request.Context().Done()
11971204
}
11981205

11991206
// Err returns nil when c.Request has no Context.
12001207
func (c *Context) Err() error {
1201-
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
1208+
if !c.hasRequestContext() {
12021209
return nil
12031210
}
12041211
return c.Request.Context().Err()
@@ -1219,7 +1226,7 @@ func (c *Context) Value(key any) any {
12191226
return val
12201227
}
12211228
}
1222-
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
1229+
if !c.hasRequestContext() {
12231230
return nil
12241231
}
12251232
return c.Request.Context().Value(key)

context_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,6 +2176,24 @@ func TestRemoteIPFail(t *testing.T) {
21762176
assert.False(t, trust)
21772177
}
21782178

2179+
func TestHasRequestContext(t *testing.T) {
2180+
c, _ := CreateTestContext(httptest.NewRecorder())
2181+
assert.False(t, c.hasRequestContext(), "no request, no fallback")
2182+
c.engine.ContextWithFallback = true
2183+
assert.False(t, c.hasRequestContext(), "no request, has fallback")
2184+
c.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
2185+
assert.True(t, c.hasRequestContext(), "has request, has fallback")
2186+
c.Request, _ = http.NewRequestWithContext(nil, "", "", nil) //nolint:staticcheck
2187+
assert.False(t, c.hasRequestContext(), "has request with nil ctx, has fallback")
2188+
c.engine.ContextWithFallback = false
2189+
assert.False(t, c.hasRequestContext(), "has request, no fallback")
2190+
2191+
c = &Context{}
2192+
assert.False(t, c.hasRequestContext(), "no request, no engine")
2193+
c.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
2194+
assert.False(t, c.hasRequestContext(), "has request, no engine")
2195+
}
2196+
21792197
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
21802198
c, _ := CreateTestContext(httptest.NewRecorder())
21812199
// enable ContextWithFallback feature flag

0 commit comments

Comments
 (0)