Skip to content

Commit f197a8b

Browse files
authored
feat(context): add ContextWithFallback feature flag (#3166) (#3172)
Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
1 parent 92ba8e1 commit f197a8b

File tree

3 files changed

+110
-16
lines changed

3 files changed

+110
-16
lines changed

context.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -1158,23 +1158,23 @@ func (c *Context) SetAccepted(formats ...string) {
11581158

11591159
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
11601160
func (c *Context) Deadline() (deadline time.Time, ok bool) {
1161-
if c.Request == nil || c.Request.Context() == nil {
1161+
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
11621162
return
11631163
}
11641164
return c.Request.Context().Deadline()
11651165
}
11661166

11671167
// Done returns nil (chan which will wait forever) when c.Request has no Context.
11681168
func (c *Context) Done() <-chan struct{} {
1169-
if c.Request == nil || c.Request.Context() == nil {
1169+
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
11701170
return nil
11711171
}
11721172
return c.Request.Context().Done()
11731173
}
11741174

11751175
// Err returns nil when c.Request has no Context.
11761176
func (c *Context) Err() error {
1177-
if c.Request == nil || c.Request.Context() == nil {
1177+
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
11781178
return nil
11791179
}
11801180
return c.Request.Context().Err()
@@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
11951195
return val
11961196
}
11971197
}
1198-
if c.Request == nil || c.Request.Context() == nil {
1198+
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
11991199
return nil
12001200
}
12011201
return c.Request.Context().Value(key)

context_test.go

+103-12
Original file line numberDiff line numberDiff line change
@@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
20972097
}
20982098

20992099
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
2100-
c := &Context{}
2100+
c, _ := CreateTestContext(httptest.NewRecorder())
2101+
// enable ContextWithFallback feature flag
2102+
c.engine.ContextWithFallback = true
2103+
21012104
deadline, ok := c.Deadline()
21022105
assert.Zero(t, deadline)
21032106
assert.False(t, ok)
21042107

2105-
c2 := &Context{}
2108+
c2, _ := CreateTestContext(httptest.NewRecorder())
2109+
// enable ContextWithFallback feature flag
2110+
c2.engine.ContextWithFallback = true
2111+
21062112
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
21072113
d := time.Now().Add(time.Second)
21082114
ctx, cancel := context.WithDeadline(context.Background(), d)
@@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
21142120
}
21152121

21162122
func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
2117-
c := &Context{}
2123+
c, _ := CreateTestContext(httptest.NewRecorder())
2124+
// enable ContextWithFallback feature flag
2125+
c.engine.ContextWithFallback = true
2126+
21182127
assert.Nil(t, c.Done())
21192128

2120-
c2 := &Context{}
2129+
c2, _ := CreateTestContext(httptest.NewRecorder())
2130+
// enable ContextWithFallback feature flag
2131+
c2.engine.ContextWithFallback = true
2132+
21212133
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
21222134
ctx, cancel := context.WithCancel(context.Background())
21232135
c2.Request = c2.Request.WithContext(ctx)
@@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
21262138
}
21272139

21282140
func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
2129-
c := &Context{}
2141+
c, _ := CreateTestContext(httptest.NewRecorder())
2142+
// enable ContextWithFallback feature flag
2143+
c.engine.ContextWithFallback = true
2144+
21302145
assert.Nil(t, c.Err())
21312146

2132-
c2 := &Context{}
2147+
c2, _ := CreateTestContext(httptest.NewRecorder())
2148+
// enable ContextWithFallback feature flag
2149+
c2.engine.ContextWithFallback = true
2150+
21332151
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
21342152
ctx, cancel := context.WithCancel(context.Background())
21352153
c2.Request = c2.Request.WithContext(ctx)
@@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
21382156
assert.EqualError(t, c2.Err(), context.Canceled.Error())
21392157
}
21402158

2141-
type contextKey string
2142-
21432159
func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
2160+
type contextKey string
2161+
21442162
tests := []struct {
21452163
name string
21462164
getContextAndKey func() (*Context, any)
@@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
21502168
name: "c with struct context key",
21512169
getContextAndKey: func() (*Context, any) {
21522170
var key struct{}
2153-
c := &Context{}
2171+
c, _ := CreateTestContext(httptest.NewRecorder())
2172+
// enable ContextWithFallback feature flag
2173+
c.engine.ContextWithFallback = true
21542174
c.Request, _ = http.NewRequest("POST", "/", nil)
21552175
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
21562176
return c, key
@@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
21602180
{
21612181
name: "c with string context key",
21622182
getContextAndKey: func() (*Context, any) {
2163-
c := &Context{}
2183+
c, _ := CreateTestContext(httptest.NewRecorder())
2184+
// enable ContextWithFallback feature flag
2185+
c.engine.ContextWithFallback = true
21642186
c.Request, _ = http.NewRequest("POST", "/", nil)
21652187
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
21662188
return c, contextKey("key")
@@ -2170,15 +2192,20 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
21702192
{
21712193
name: "c with nil http.Request",
21722194
getContextAndKey: func() (*Context, any) {
2173-
c := &Context{}
2195+
c, _ := CreateTestContext(httptest.NewRecorder())
2196+
// enable ContextWithFallback feature flag
2197+
c.engine.ContextWithFallback = true
2198+
c.Request = nil
21742199
return c, "key"
21752200
},
21762201
value: nil,
21772202
},
21782203
{
21792204
name: "c with nil http.Request.Context()",
21802205
getContextAndKey: func() (*Context, any) {
2181-
c := &Context{}
2206+
c, _ := CreateTestContext(httptest.NewRecorder())
2207+
// enable ContextWithFallback feature flag
2208+
c.engine.ContextWithFallback = true
21822209
c.Request, _ = http.NewRequest("POST", "/", nil)
21832210
return c, "key"
21842211
},
@@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
21932220
}
21942221
}
21952222

2223+
func TestContextCopyShouldNotCancel(t *testing.T) {
2224+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
2225+
w.WriteHeader(http.StatusOK)
2226+
}))
2227+
defer srv.Close()
2228+
2229+
ensureRequestIsOver := make(chan struct{})
2230+
2231+
wg := &sync.WaitGroup{}
2232+
2233+
r := New()
2234+
r.GET("/", func(ginctx *Context) {
2235+
wg.Add(1)
2236+
2237+
ginctx = ginctx.Copy()
2238+
2239+
// start async goroutine for calling srv
2240+
go func() {
2241+
defer wg.Done()
2242+
2243+
<-ensureRequestIsOver // ensure request is done
2244+
2245+
req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
2246+
must(err)
2247+
2248+
res, err := http.DefaultClient.Do(req)
2249+
if err != nil {
2250+
t.Error(fmt.Errorf("request error: %w", err))
2251+
return
2252+
}
2253+
2254+
if res.StatusCode != http.StatusOK {
2255+
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
2256+
}
2257+
}()
2258+
})
2259+
2260+
l, err := net.Listen("tcp", ":0")
2261+
must(err)
2262+
go func() {
2263+
s := &http.Server{
2264+
Handler: r,
2265+
}
2266+
2267+
must(s.Serve(l))
2268+
}()
2269+
2270+
addr := strings.Split(l.Addr().String(), ":")
2271+
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
2272+
if err != nil {
2273+
t.Error(fmt.Errorf("request error: %w", err))
2274+
return
2275+
}
2276+
2277+
close(ensureRequestIsOver)
2278+
2279+
if res.StatusCode != http.StatusOK {
2280+
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
2281+
return
2282+
}
2283+
2284+
wg.Wait()
2285+
}
2286+
21962287
func TestContextAddParam(t *testing.T) {
21972288
c := &Context{}
21982289
id := "id"

gin.go

+3
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ type Engine struct {
147147
// UseH2C enable h2c support.
148148
UseH2C bool
149149

150+
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
151+
ContextWithFallback bool
152+
150153
delims render.Delims
151154
secureJSONPrefix string
152155
HTMLRender render.HTMLRender

0 commit comments

Comments
 (0)