13
13
"""This module stores types related to SageMaker JumpStart HubAPI requests and responses."""
14
14
from __future__ import absolute_import
15
15
16
+ from enum import Enum
16
17
import re
17
18
import json
18
19
import datetime
19
20
20
21
from typing import Any , Dict , List , Union , Optional
22
+ from sagemaker .jumpstart .enums import JumpStartScriptScope
21
23
from sagemaker .jumpstart .types import (
22
24
HubContentType ,
23
25
HubArnExtractedInfo ,
26
+ JumpStartConfigComponent ,
27
+ JumpStartConfigRanking ,
28
+ JumpStartMetadataConfig ,
29
+ JumpStartMetadataConfigs ,
24
30
JumpStartPredictorSpecs ,
25
31
JumpStartHyperparameter ,
26
32
JumpStartDataHolderType ,
34
40
)
35
41
36
42
43
+ class _ComponentType (str , Enum ):
44
+ """Enum for different component types."""
45
+
46
+ INFERENCE = "Inference"
47
+ TRAINING = "Training"
48
+
49
+
37
50
class HubDataHolderType (JumpStartDataHolderType ):
38
51
"""Base class for many Hub API interfaces."""
39
52
@@ -456,6 +469,9 @@ class HubModelDocument(HubDataHolderType):
456
469
"hosting_use_script_uri" ,
457
470
"hosting_eula_uri" ,
458
471
"hosting_model_package_arn" ,
472
+ "inference_configs" ,
473
+ "inference_config_components" ,
474
+ "inference_config_rankings" ,
459
475
"training_artifact_s3_data_type" ,
460
476
"training_artifact_compression_type" ,
461
477
"training_model_package_artifact_uri" ,
@@ -467,6 +483,9 @@ class HubModelDocument(HubDataHolderType):
467
483
"training_ecr_uri" ,
468
484
"training_metrics" ,
469
485
"training_artifact_uri" ,
486
+ "training_configs" ,
487
+ "training_config_components" ,
488
+ "training_config_rankings" ,
470
489
"inference_dependencies" ,
471
490
"training_dependencies" ,
472
491
"default_inference_instance_type" ,
@@ -566,6 +585,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
566
585
)
567
586
self .hosting_eula_uri : Optional [str ] = json_obj .get ("HostingEulaUri" )
568
587
self .hosting_model_package_arn : Optional [str ] = json_obj .get ("HostingModelPackageArn" )
588
+
589
+ self .inference_config_rankings = self ._get_config_rankings (json_obj )
590
+ self .inference_config_components = self ._get_config_components (json_obj )
591
+ self .inference_configs = self ._get_configs (json_obj )
592
+
569
593
self .default_inference_instance_type : Optional [str ] = json_obj .get (
570
594
"DefaultInferenceInstanceType"
571
595
)
@@ -667,6 +691,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
667
691
"TrainingMetrics" , None
668
692
)
669
693
self .training_artifact_uri : Optional [str ] = json_obj .get ("TrainingArtifactUri" )
694
+
695
+ self .training_config_rankings = self ._get_config_rankings (
696
+ json_obj , _ComponentType .TRAINING
697
+ )
698
+ self .training_config_components = self ._get_config_components (
699
+ json_obj , _ComponentType .TRAINING
700
+ )
701
+ self .training_configs = self ._get_configs (json_obj , _ComponentType .TRAINING )
702
+
670
703
self .training_dependencies : Optional [str ] = json_obj .get ("TrainingDependencies" )
671
704
self .default_training_instance_type : Optional [str ] = json_obj .get (
672
705
"DefaultTrainingInstanceType"
@@ -707,6 +740,64 @@ def get_region(self) -> str:
707
740
"""Returns hub region."""
708
741
return self ._region
709
742
743
+ def _get_config_rankings (
744
+ self , json_obj : Dict [str , Any ], component_type = _ComponentType .INFERENCE
745
+ ) -> Optional [Dict [str , JumpStartConfigRanking ]]:
746
+ """Returns config rankings."""
747
+ config_rankings = json_obj .get (f"{ component_type .value } ConfigRankings" )
748
+ return (
749
+ {
750
+ alias : JumpStartConfigRanking (ranking , is_hub_content = True )
751
+ for alias , ranking in config_rankings .items ()
752
+ }
753
+ if config_rankings
754
+ else None
755
+ )
756
+
757
+ def _get_config_components (
758
+ self , json_obj : Dict [str , Any ], component_type = _ComponentType .INFERENCE
759
+ ) -> Optional [Dict [str , JumpStartConfigComponent ]]:
760
+ """Returns config components."""
761
+ config_components = json_obj .get (f"{ component_type .value } ConfigComponents" )
762
+ return (
763
+ {
764
+ alias : JumpStartConfigComponent (alias , config , is_hub_content = True )
765
+ for alias , config in config_components .items ()
766
+ }
767
+ if config_components
768
+ else None
769
+ )
770
+
771
+ def _get_configs (
772
+ self , json_obj : Dict [str , Any ], component_type = _ComponentType .INFERENCE
773
+ ) -> Optional [JumpStartMetadataConfigs ]:
774
+ """Returns configs."""
775
+ if not (configs := json_obj .get (f"{ component_type .value } Configs" )):
776
+ return None
777
+
778
+ configs_dict = {}
779
+ for alias , config in configs .items ():
780
+ config_components = None
781
+ if isinstance (config , dict ) and (component_names := config .get ("ComponentNames" )):
782
+ config_components = {
783
+ name : getattr (self , f"{ component_type .value .lower ()} _config_components" ).get (
784
+ name
785
+ )
786
+ for name in component_names
787
+ }
788
+ configs_dict [alias ] = JumpStartMetadataConfig (
789
+ alias , config , json_obj , config_components , is_hub_content = True
790
+ )
791
+
792
+ if component_type == _ComponentType .INFERENCE :
793
+ config_rankings = self .inference_config_rankings
794
+ scope = JumpStartScriptScope .INFERENCE
795
+ else :
796
+ config_rankings = self .training_config_rankings
797
+ scope = JumpStartScriptScope .TRAINING
798
+
799
+ return JumpStartMetadataConfigs (configs_dict , config_rankings , scope )
800
+
710
801
711
802
class HubNotebookDocument (HubDataHolderType ):
712
803
"""Data class for notebook type HubContentDocument from session.describe_hub_content()."""
0 commit comments