Skip to content

Commit 9d49f16

Browse files
authored
chore(cors): Allow a custom validation function which receives the full gin context (#140)
* Allow a origin validation function with context * Revert "Allow a origin validation function with context" This reverts commit 82827c2. * Allow origin validation function which receives the full request context * fix logic in conditional * add test, fix logic * slightly re-work to pass linter * update comments * restructure to shorten line lengths to pass linter * remove punctuation at the end of error string * Add multi-group preflight test * remove comment
1 parent 7f30a1f commit 9d49f16

File tree

3 files changed

+128
-24
lines changed

3 files changed

+128
-24
lines changed

config.go

+27-17
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ import (
88
)
99

1010
type cors struct {
11-
allowAllOrigins bool
12-
allowCredentials bool
13-
allowOriginFunc func(string) bool
14-
allowOrigins []string
15-
normalHeaders http.Header
16-
preflightHeaders http.Header
17-
wildcardOrigins [][]string
18-
optionsResponseStatusCode int
11+
allowAllOrigins bool
12+
allowCredentials bool
13+
allowOriginFunc func(string) bool
14+
allowOriginWithContextFunc func(*gin.Context, string) bool
15+
allowOrigins []string
16+
normalHeaders http.Header
17+
preflightHeaders http.Header
18+
wildcardOrigins [][]string
19+
optionsResponseStatusCode int
1920
}
2021

2122
var (
@@ -54,14 +55,15 @@ func newCors(config Config) *cors {
5455
}
5556

5657
return &cors{
57-
allowOriginFunc: config.AllowOriginFunc,
58-
allowAllOrigins: config.AllowAllOrigins,
59-
allowCredentials: config.AllowCredentials,
60-
allowOrigins: normalize(config.AllowOrigins),
61-
normalHeaders: generateNormalHeaders(config),
62-
preflightHeaders: generatePreflightHeaders(config),
63-
wildcardOrigins: config.parseWildcardRules(),
64-
optionsResponseStatusCode: config.OptionsResponseStatusCode,
58+
allowOriginFunc: config.AllowOriginFunc,
59+
allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
60+
allowAllOrigins: config.AllowAllOrigins,
61+
allowCredentials: config.AllowCredentials,
62+
allowOrigins: normalize(config.AllowOrigins),
63+
normalHeaders: generateNormalHeaders(config),
64+
preflightHeaders: generatePreflightHeaders(config),
65+
wildcardOrigins: config.parseWildcardRules(),
66+
optionsResponseStatusCode: config.OptionsResponseStatusCode,
6567
}
6668
}
6769

@@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) {
7981
return
8082
}
8183

82-
if !cors.validateOrigin(origin) {
84+
if !cors.isOriginValid(c, origin) {
8385
c.AbortWithStatus(http.StatusForbidden)
8486
return
8587
}
@@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool {
112114
return false
113115
}
114116

117+
func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
118+
valid := cors.validateOrigin(origin)
119+
if !valid && cors.allowOriginWithContextFunc != nil {
120+
valid = cors.allowOriginWithContextFunc(c, origin)
121+
}
122+
return valid
123+
}
124+
115125
func (cors *cors) validateOrigin(origin string) bool {
116126
if cors.allowAllOrigins {
117127
return true

cors.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cors
22

33
import (
44
"errors"
5+
"fmt"
56
"strings"
67
"time"
78

@@ -22,6 +23,12 @@ type Config struct {
2223
// set, the content of AllowOrigins is ignored.
2324
AllowOriginFunc func(origin string) bool
2425

26+
// Same as AllowOriginFunc except also receives the full request context.
27+
// This function should use the context as a read only source and not
28+
// have any side effects on the request, such as aborting or injecting
29+
// values on the request.
30+
AllowOriginWithContextFunc func(c *gin.Context, origin string) bool
31+
2532
// AllowMethods is a list of methods the client is allowed to use with
2633
// cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS)
2734
AllowMethods []string
@@ -108,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool {
108115

109116
// Validate is check configuration of user defined.
110117
func (c Config) Validate() error {
111-
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
112-
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
118+
hasOriginFn := c.AllowOriginFunc != nil
119+
hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil
120+
121+
if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) {
122+
originFields := strings.Join([]string{
123+
"AllowOriginFunc",
124+
"AllowOriginFuncWithContext",
125+
"AllowOrigins",
126+
}, " or ")
127+
return fmt.Errorf(
128+
"conflict settings: all origins enabled. %s is not needed",
129+
originFields,
130+
)
113131
}
114-
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
132+
if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 {
115133
return errors.New("conflict settings: all origins disabled")
116134
}
117135
for _, origin := range c.AllowOrigins {

cors_test.go

+80-4
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,34 @@ func newTestRouter(config Config) *gin.Engine {
2828
return router
2929
}
3030

31+
func multiGroupRouter(config Config) *gin.Engine {
32+
router := gin.New()
33+
router.Use(New(config))
34+
35+
app1 := router.Group("/app1")
36+
app1.GET("", func(c *gin.Context) {
37+
c.String(http.StatusOK, "app1")
38+
})
39+
40+
app2 := router.Group("/app2")
41+
app2.GET("", func(c *gin.Context) {
42+
c.String(http.StatusOK, "app2")
43+
})
44+
45+
app3 := router.Group("/app3")
46+
app3.GET("", func(c *gin.Context) {
47+
c.String(http.StatusOK, "app3")
48+
})
49+
50+
return router
51+
}
52+
3153
func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
32-
return performRequestWithHeaders(r, method, origin, http.Header{})
54+
return performRequestWithHeaders(r, method, "/", origin, http.Header{})
3355
}
3456

35-
func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder {
36-
req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil)
57+
func performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) *httptest.ResponseRecorder {
58+
req, _ := http.NewRequestWithContext(context.Background(), method, path, nil)
3759
// From go/net/http/request.go:
3860
// For incoming requests, the Host header is promoted to the
3961
// Request.Host field and removed from the Header map.
@@ -299,6 +321,9 @@ func TestPassesAllowOrigins(t *testing.T) {
299321
AllowOriginFunc: func(origin string) bool {
300322
return origin == "http://github.com"
301323
},
324+
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
325+
return origin == "http://sample.com"
326+
},
302327
})
303328

304329
// no CORS request, origin == ""
@@ -311,7 +336,7 @@ func TestPassesAllowOrigins(t *testing.T) {
311336
// no CORS request, origin == host
312337
h := http.Header{}
313338
h.Set("Host", "facebook.com")
314-
w = performRequestWithHeaders(router, "GET", "http://facebook.com", h)
339+
w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h)
315340
assert.Equal(t, "get", w.Body.String())
316341
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
317342
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
@@ -346,6 +371,15 @@ func TestPassesAllowOrigins(t *testing.T) {
346371
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
347372
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))
348373

374+
// allowed CORS prefligh request: allowed via AllowOriginWithContextFunc
375+
w = performRequest(router, "OPTIONS", "http://sample.com")
376+
assert.Equal(t, http.StatusNoContent, w.Code)
377+
assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin"))
378+
assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
379+
assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
380+
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
381+
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))
382+
349383
// deny CORS prefligh request
350384
w = performRequest(router, "OPTIONS", "http://example.com")
351385
assert.Equal(t, http.StatusForbidden, w.Code)
@@ -432,6 +466,48 @@ func TestWildcard(t *testing.T) {
432466
assert.Equal(t, 200, w.Code)
433467
}
434468

469+
func TestMultiGroupRouter(t *testing.T) {
470+
router := multiGroupRouter(Config{
471+
AllowMethods: []string{"GET"},
472+
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
473+
path := c.Request.URL.Path
474+
if strings.HasPrefix(path, "/app1") {
475+
return "http://app1.example.com" == origin
476+
}
477+
478+
if strings.HasPrefix(path, "/app2") {
479+
return "http://app2.example.com" == origin
480+
}
481+
482+
// app 3 allows all origins
483+
return true
484+
},
485+
})
486+
487+
// allowed CORS prefligh request
488+
emptyHeaders := http.Header{}
489+
app1Origin := "http://app1.example.com"
490+
app2Origin := "http://app2.example.com"
491+
randomOrgin := "http://random.com"
492+
493+
// allowed CORS preflight
494+
w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders)
495+
assert.Equal(t, http.StatusNoContent, w.Code)
496+
497+
w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders)
498+
assert.Equal(t, http.StatusNoContent, w.Code)
499+
500+
w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders)
501+
assert.Equal(t, http.StatusNoContent, w.Code)
502+
503+
// disallowed CORS preflight
504+
w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders)
505+
assert.Equal(t, http.StatusForbidden, w.Code)
506+
507+
w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders)
508+
assert.Equal(t, http.StatusForbidden, w.Code)
509+
}
510+
435511
func TestParseWildcardRules_NoWildcard(t *testing.T) {
436512
config := Config{
437513
AllowOrigins: []string{

0 commit comments

Comments
 (0)