@@ -48,6 +48,7 @@ def read_media(file_name: str) -> bytes:
48
48
49
49
50
50
LOAD_GW_EVENT = load_event ("apiGatewayProxyEvent.json" )
51
+ LOAD_GW_EVENT_NO_ORIGIN = load_event ("apiGatewayProxyEventNoOrigin.json" )
51
52
LOAD_GW_EVENT_TRAILING_SLASH = load_event ("apiGatewayProxyEventPathTrailingSlash.json" )
52
53
53
54
@@ -324,7 +325,7 @@ def handler(event, context):
324
325
def test_cors ():
325
326
# GIVEN a function with cors=True
326
327
# AND http method set to GET
327
- app = ApiGatewayResolver ()
328
+ app = ApiGatewayResolver (cors = CORSConfig ( "https://aws.amazon.com" , allow_credentials = True ) )
328
329
329
330
@app .get ("/my/path" , cors = True )
330
331
def with_cors () -> Response :
@@ -345,6 +346,69 @@ def handler(event, context):
345
346
headers = result ["multiValueHeaders" ]
346
347
assert headers ["Content-Type" ] == [content_types .TEXT_HTML ]
347
348
assert headers ["Access-Control-Allow-Origin" ] == ["https://aws.amazon.com" ]
349
+ assert "Access-Control-Allow-Credentials" in headers
350
+ assert headers ["Access-Control-Allow-Headers" ] == ["," .join (sorted (CORSConfig ._REQUIRED_HEADERS ))]
351
+
352
+ # THEN for routes without cors flag return no cors headers
353
+ mock_event = {"path" : "/my/request" , "httpMethod" : "GET" }
354
+ result = handler (mock_event , None )
355
+ assert "Access-Control-Allow-Origin" not in result ["multiValueHeaders" ]
356
+
357
+
358
+ def test_cors_no_request_origin ():
359
+ # GIVEN a function with cors=True
360
+ # AND http method set to GET
361
+ app = ApiGatewayResolver ()
362
+
363
+ @app .get ("/my/path" , cors = True )
364
+ def with_cors () -> Response :
365
+ return Response (200 , content_types .TEXT_HTML , "test" )
366
+
367
+ def handler (event , context ):
368
+ return app .resolve (event , context )
369
+
370
+ event = LOAD_GW_EVENT_NO_ORIGIN
371
+
372
+ # WHEN calling the event handler
373
+ result = handler (event , None )
374
+
375
+ # THEN the headers should include cors headers
376
+ assert "multiValueHeaders" in result
377
+ headers = result ["multiValueHeaders" ]
378
+ assert headers ["Content-Type" ] == [content_types .TEXT_HTML ]
379
+ assert "Access-Control-Allow-Credentials" not in headers
380
+ assert "Access-Control-Allow-Origin" not in result ["multiValueHeaders" ]
381
+
382
+
383
+ def test_cors_allow_all_request_origins ():
384
+ # GIVEN a function with cors=True
385
+ # AND http method set to GET
386
+ app = ApiGatewayResolver (
387
+ cors = CORSConfig (
388
+ allow_origin = "*" ,
389
+ allow_credentials = True ,
390
+ ),
391
+ )
392
+
393
+ @app .get ("/my/path" , cors = True )
394
+ def with_cors () -> Response :
395
+ return Response (200 , content_types .TEXT_HTML , "test" )
396
+
397
+ @app .get ("/without-cors" )
398
+ def without_cors () -> Response :
399
+ return Response (200 , content_types .TEXT_HTML , "test" )
400
+
401
+ def handler (event , context ):
402
+ return app .resolve (event , context )
403
+
404
+ # WHEN calling the event handler
405
+ result = handler (LOAD_GW_EVENT , None )
406
+
407
+ # THEN the headers should include cors headers
408
+ assert "multiValueHeaders" in result
409
+ headers = result ["multiValueHeaders" ]
410
+ assert headers ["Content-Type" ] == [content_types .TEXT_HTML ]
411
+ assert headers ["Access-Control-Allow-Origin" ] == ["*" ]
348
412
assert "Access-Control-Allow-Credentials" not in headers
349
413
assert headers ["Access-Control-Allow-Headers" ] == ["," .join (sorted (CORSConfig ._REQUIRED_HEADERS ))]
350
414
0 commit comments