4
4
import re
5
5
import traceback
6
6
import warnings
7
- import zlib
8
7
from abc import ABC , abstractmethod
9
8
from enum import Enum
10
9
from functools import partial
17
16
Match ,
18
17
Optional ,
19
18
Pattern ,
19
+ Sequence ,
20
20
Set ,
21
21
Tuple ,
22
22
Type ,
23
23
Union ,
24
24
)
25
25
26
+ import zlib
27
+ from pydantic .fields import ModelField
28
+ from pydantic .schema import get_flat_models_from_fields , get_model_name_map , model_process_schema
29
+
26
30
from aws_lambda_powertools .event_handler import content_types
27
31
from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
32
+ from aws_lambda_powertools .event_handler .openapi .dependant import get_dependant
33
+ from aws_lambda_powertools .event_handler .openapi .models import Contact , License , OpenAPI , Server , Tag
34
+ from aws_lambda_powertools .event_handler .openapi .utils import get_flat_params
35
+ from aws_lambda_powertools .event_handler .route import Route
28
36
from aws_lambda_powertools .shared .cookies import Cookie
29
37
from aws_lambda_powertools .shared .functions import powertools_dev_is_set
30
38
from aws_lambda_powertools .shared .json_encoder import Encoder
@@ -207,26 +215,6 @@ def __init__(
207
215
self .headers .setdefault ("Content-Type" , content_type )
208
216
209
217
210
- class Route :
211
- """Internally used Route Configuration"""
212
-
213
- def __init__ (
214
- self ,
215
- method : str ,
216
- rule : Pattern ,
217
- func : Callable ,
218
- cors : bool ,
219
- compress : bool ,
220
- cache_control : Optional [str ],
221
- ):
222
- self .method = method .upper ()
223
- self .rule = rule
224
- self .func = func
225
- self .cors = cors
226
- self .compress = compress
227
- self .cache_control = cache_control
228
-
229
-
230
218
class ResponseBuilder :
231
219
"""Internally used Response builder"""
232
220
@@ -554,6 +542,119 @@ def __init__(
554
542
# Allow for a custom serializer or a concise json serialization
555
543
self ._serializer = serializer or partial (json .dumps , separators = ("," , ":" ), cls = Encoder )
556
544
545
+ def get_openapi_schema (
546
+ self ,
547
+ * ,
548
+ title : str ,
549
+ version : str ,
550
+ openapi_version : str = "3.1.0" ,
551
+ summary : Optional [str ] = None ,
552
+ description : Optional [str ] = None ,
553
+ tags : Optional [List [Tag ]] = None ,
554
+ servers : Optional [List [Server ]] = None ,
555
+ terms_of_service : Optional [str ] = None ,
556
+ contact : Optional [Contact ] = None ,
557
+ license_info : Optional [License ] = None ,
558
+ ) -> OpenAPI :
559
+ info : Dict [str , Any ] = {"title" : title , "version" : version }
560
+ if summary :
561
+ info ["summary" ] = summary
562
+ if description :
563
+ info ["description" ] = description
564
+ if terms_of_service :
565
+ info ["termsOfService" ] = terms_of_service
566
+ if contact :
567
+ info ["contact" ] = contact
568
+ if license_info :
569
+ info ["license" ] = license_info
570
+
571
+ output : Dict [str , Any ] = {"openapi" : openapi_version , "info" : info }
572
+ if servers :
573
+ output ["servers" ] = servers
574
+ else :
575
+ # If the servers property is not provided, or is an empty array, the default value would be a Server Object
576
+ # with a url value of /.
577
+ output ["servers" ] = [Server (url = "/" )]
578
+
579
+ components : Dict [str , Dict [str , Any ]] = {}
580
+ paths : Dict [str , Dict [str , Any ]] = {}
581
+ operation_ids : Set [str ] = set ()
582
+
583
+ all_routes = self ._dynamic_routes + self ._static_routes
584
+ all_fields = self ._get_fields_from_routes (all_routes )
585
+ models = get_flat_models_from_fields (all_fields , known_models = set ())
586
+ model_name_map = get_model_name_map (models )
587
+
588
+ definitions : Dict [str , Dict [str , Any ]] = {}
589
+ for model in models :
590
+ m_schema , m_definitions , _ = model_process_schema (
591
+ model ,
592
+ model_name_map = model_name_map ,
593
+ ref_prefix = "#/components/schemas/" ,
594
+ )
595
+ definitions .update (m_definitions )
596
+ model_name = model_name_map [model ]
597
+ if "description" in m_schema :
598
+ m_schema ["description" ] = m_schema ["description" ].split ("\f " )[0 ]
599
+ definitions [model_name ] = m_schema
600
+
601
+ for route in all_routes :
602
+ dependant = get_dependant (
603
+ path = route .func .__name__ ,
604
+ call = route .func ,
605
+ )
606
+
607
+ result = route ._openapi_path (
608
+ dependant = dependant ,
609
+ operation_ids = operation_ids ,
610
+ model_name_map = model_name_map ,
611
+ )
612
+ if result :
613
+ path , path_definitions = result
614
+ if path :
615
+ paths .setdefault (route .path , {}).update (path )
616
+ if path_definitions :
617
+ definitions .update (path_definitions )
618
+
619
+ if definitions :
620
+ components ["schemas" ] = {k : definitions [k ] for k in sorted (definitions )}
621
+ if components :
622
+ output ["components" ] = components
623
+ if tags :
624
+ output ["tags" ] = tags
625
+
626
+ output ["paths" ] = paths
627
+
628
+ return OpenAPI (** output ) # .dict(by_alias=True, exclude_none=True)
629
+
630
+ def get_openapi_json_schema (
631
+ self ,
632
+ * ,
633
+ title : str ,
634
+ version : str ,
635
+ openapi_version : str = "3.1.0" ,
636
+ summary : Optional [str ] = None ,
637
+ description : Optional [str ] = None ,
638
+ tags : Optional [List [Tag ]] = None ,
639
+ servers : Optional [List [Server ]] = None ,
640
+ terms_of_service : Optional [str ] = None ,
641
+ contact : Optional [Contact ] = None ,
642
+ license_info : Optional [License ] = None ,
643
+ ) -> str :
644
+ """Returns the OpenAPI schema as a JSON serializable dict"""
645
+ return self .get_openapi_schema (
646
+ title = title ,
647
+ version = version ,
648
+ openapi_version = openapi_version ,
649
+ summary = summary ,
650
+ description = description ,
651
+ tags = tags ,
652
+ servers = servers ,
653
+ terms_of_service = terms_of_service ,
654
+ contact = contact ,
655
+ license_info = license_info ,
656
+ ).json (by_alias = True , exclude_none = True , indent = 2 )
657
+
557
658
def route (
558
659
self ,
559
660
rule : str ,
@@ -573,7 +674,7 @@ def register_resolver(func: Callable):
573
674
cors_enabled = cors
574
675
575
676
for item in methods :
576
- _route = Route (item , self ._compile_regex (rule ), func , cors_enabled , compress , cache_control )
677
+ _route = Route (item , rule , self ._compile_regex (rule ), func , cors_enabled , compress , cache_control )
577
678
578
679
# The more specific route wins.
579
680
# We store dynamic (/studies/{studyid}) and static routes (/studies/fetch) separately.
@@ -889,6 +990,22 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
889
990
890
991
self .route (* new_route )(func )
891
992
993
+ @staticmethod
994
+ def _get_fields_from_routes (routes : Sequence [Route ]) -> List [ModelField ]:
995
+ responses_from_routes : List [ModelField ] = []
996
+ request_fields_from_routes : List [ModelField ] = []
997
+
998
+ for route in routes :
999
+ dependant = get_dependant (path = route .path , call = route .func )
1000
+ params = get_flat_params (dependant )
1001
+ request_fields_from_routes .extend (params )
1002
+
1003
+ if dependant .return_param :
1004
+ responses_from_routes .append (dependant .return_param )
1005
+
1006
+ flat_models = list (responses_from_routes + request_fields_from_routes )
1007
+ return flat_models
1008
+
892
1009
893
1010
class Router (BaseRouter ):
894
1011
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
0 commit comments