23
23
from botocore .exceptions import ClientError
24
24
25
25
from sagemaker .enums import Tag
26
- from sagemaker .jumpstart import enums
27
- from sagemaker .jumpstart .utils import verify_model_region_and_return_specs , get_eula_message
28
26
from sagemaker .model import Model
29
27
from sagemaker import model_uris
30
28
from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
40
38
SkipTuningComboException ,
41
39
)
42
40
from sagemaker .serve .utils .optimize_utils import (
43
- _extract_supported_deployment_config ,
44
- _is_speculation_enabled ,
45
41
_is_compatible_with_optimization_job ,
42
+ _extract_model_source ,
43
+ _update_environment_variables ,
46
44
)
47
45
from sagemaker .serve .utils .predictors import (
48
46
DjlLocalModePredictor ,
@@ -643,7 +641,7 @@ def _optimize_for_jumpstart(
643
641
vpc_config : Optional [Dict ] = None ,
644
642
kms_key : Optional [str ] = None ,
645
643
max_runtime_in_sec : Optional [int ] = None ,
646
- ) -> None :
644
+ ) -> Dict [ str , Any ] :
647
645
"""Runs a model optimization job.
648
646
649
647
Args:
@@ -669,79 +667,60 @@ def _optimize_for_jumpstart(
669
667
to S3. Defaults to ``None``.
670
668
max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
671
669
``None``.
672
- """
673
- model_specs = verify_model_region_and_return_specs (
674
- region = self .sagemaker_session .boto_region_name ,
675
- model_id = self .pysdk_model .model_id ,
676
- version = self .pysdk_model .model_version ,
677
- sagemaker_session = self .sagemaker_session ,
678
- scope = enums .JumpStartScriptScope .INFERENCE ,
679
- model_type = self .pysdk_model .model_type ,
680
- )
681
670
682
- if model_specs .is_gated_model () and accept_eula is not True :
683
- raise ValueError (get_eula_message (model_specs , self .sagemaker_session .boto_region_name ))
684
-
685
- if not (self .pysdk_model .model_data and self .pysdk_model .model_data .get ("S3DataSource" )):
686
- raise ValueError ("Model Optimization Job only supports model backed by S3." )
671
+ Returns:
672
+ Dict[str, Any]: Model optimization job input arguments.
673
+ """
674
+ if self ._is_gated_model () and accept_eula is not True :
675
+ raise ValueError (
676
+ f"ValueError: Model '{ self .model } ' "
677
+ f"requires accepting end-user license agreement (EULA)."
678
+ )
687
679
688
- has_alternative_config = self .pysdk_model .deployment_config is not None
689
- merged_env_vars = None
690
- # TODO: Match Optimization Input Schema
691
- model_source = {
692
- "S3" : {"S3Uri" : self .pysdk_model .model_data .get ("S3DataSource" ).get ("S3Uri" )},
693
- "SageMakerModel" : {"ModelName" : self .model },
694
- }
680
+ optimization_env_vars = None
681
+ pysdk_model_env_vars = None
682
+ model_source = _extract_model_source (self .pysdk_model .model_data , accept_eula )
695
683
696
- if has_alternative_config :
697
- image_uri = self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get ("ImageUri" )
698
- instance_type = self .pysdk_model .deployment_config .get ("InstanceType" )
684
+ if speculative_decoding_config :
685
+ self ._set_additional_model_source (speculative_decoding_config )
686
+ optimization_env_vars = self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get (
687
+ "Environment"
688
+ )
699
689
else :
700
- image_uri = self .pysdk_model .image_uri
701
-
702
- if not _is_compatible_with_optimization_job (instance_type , image_uri ) or (
703
- speculative_decoding_config
704
- and not _is_speculation_enabled (self .pysdk_model .deployment_config )
705
- ):
706
- deployment_config = _extract_supported_deployment_config (
707
- self .pysdk_model .list_deployment_configs (), speculative_decoding_config is None
690
+ image_uri = None
691
+ if quantization_config and quantization_config .get ("Image" ):
692
+ image_uri = quantization_config .get ("Image" )
693
+ elif compilation_config and compilation_config .get ("Image" ):
694
+ image_uri = compilation_config .get ("Image" )
695
+ instance_type = (
696
+ instance_type
697
+ or self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get ("InstanceType" )
698
+ or _get_nb_instance ()
708
699
)
700
+ if not _is_compatible_with_optimization_job (instance_type , image_uri ):
701
+ deployment_config = self ._find_compatible_deployment_config (None )
702
+ if deployment_config :
703
+ optimization_env_vars = deployment_config .get ("DeploymentArgs" ).get (
704
+ "Environment"
705
+ )
706
+ self .pysdk_model .set_deployment_config (
707
+ config_name = deployment_config .get ("DeploymentConfigName" ),
708
+ instance_type = deployment_config .get ("InstanceType" ),
709
+ )
709
710
710
- if deployment_config :
711
- self .pysdk_model .set_deployment_config (
712
- config_name = deployment_config .get ("DeploymentConfigName" ),
713
- instance_type = deployment_config .get ("InstanceType" ),
714
- )
715
- merged_env_vars = self .pysdk_model .deployment_config .get ("Environment" )
716
-
717
- if speculative_decoding_config :
718
- # TODO: Match Optimization Input Schema
719
- s3 = {
720
- "S3Uri" : self .pysdk_model .additional_model_data_sources [
721
- "SpeculativeDecoding"
722
- ][0 ]["S3DataSource" ]["S3Uri" ]
723
- }
724
- model_source ["S3" ].update (s3 )
725
- elif speculative_decoding_config :
726
- raise ValueError ("Can't find deployment config for model optimization job." )
711
+ optimization_env_vars = _update_environment_variables (optimization_env_vars , env_vars )
727
712
728
713
optimization_config = {}
729
- if env_vars :
730
- if merged_env_vars :
731
- merged_env_vars .update (env_vars )
732
- else :
733
- merged_env_vars = env_vars
734
714
if quantization_config :
735
715
optimization_config ["ModelQuantizationConfig" ] = quantization_config
716
+ pysdk_model_env_vars = _update_environment_variables (
717
+ pysdk_model_env_vars , quantization_config ["OverrideEnvironment" ]
718
+ )
736
719
if compilation_config :
737
720
optimization_config ["ModelCompilationConfig" ] = compilation_config
738
-
739
- if accept_eula :
740
- self .pysdk_model .accept_eula = accept_eula
741
- self .pysdk_model .model_data ["S3DataSource" ].update (
742
- {"ModelAccessConfig" : {"AcceptEula" : accept_eula }}
721
+ pysdk_model_env_vars = _update_environment_variables (
722
+ pysdk_model_env_vars , compilation_config ["OverrideEnvironment" ]
743
723
)
744
- model_source ["S3" ].update ({"ModelAccessConfig" : {"AcceptEula" : accept_eula }})
745
724
746
725
output_config = {"S3OutputLocation" : output_path }
747
726
if kms_key :
@@ -751,12 +730,13 @@ def _optimize_for_jumpstart(
751
730
"OptimizationJobName" : job_name ,
752
731
"ModelSource" : model_source ,
753
732
"DeploymentInstanceType" : instance_type ,
754
- "Environment" : merged_env_vars ,
755
733
"OptimizationConfigs" : [optimization_config ],
756
734
"OutputConfig" : output_config ,
757
735
"RoleArn" : role ,
758
736
}
759
737
738
+ if optimization_env_vars :
739
+ create_optimization_job_args ["Environment" ] = optimization_env_vars
760
740
if max_runtime_in_sec :
761
741
create_optimization_job_args ["StoppingCondition" ] = {
762
742
"MaxRuntimeInSeconds" : max_runtime_in_sec
@@ -766,22 +746,106 @@ def _optimize_for_jumpstart(
766
746
if vpc_config :
767
747
create_optimization_job_args ["VpcConfig" ] = vpc_config
768
748
769
- self .sagemaker_session .sagemaker_client .create_optimization_job (
770
- ** create_optimization_job_args
771
- )
749
+ self .pysdk_model .env .update (pysdk_model_env_vars )
750
+ return create_optimization_job_args
772
751
773
- def _is_gated_model (self , model : Model ) -> bool :
752
+ def _is_gated_model (self , model = None ) -> bool :
774
753
"""Determine if ``this`` Model is Gated
775
754
776
755
Args:
777
756
model (Model): Jumpstart Model
778
757
Returns:
779
758
bool: ``True`` if ``this`` Model is Gated
780
759
"""
781
- s3_uri = model .model_data
760
+ s3_uri = model .model_data if model else self . pysdk_model . model_data
782
761
if isinstance (s3_uri , dict ):
783
762
s3_uri = s3_uri .get ("S3DataSource" ).get ("S3Uri" )
784
763
785
764
if s3_uri is None :
786
765
return False
787
766
return "private" in s3_uri
767
+
768
+ def _set_additional_model_source (
769
+ self , speculative_decoding_config : Optional [Dict [str , Any ]] = None
770
+ ) -> None :
771
+ """Set Additional Model Source to ``this`` model.
772
+
773
+ Args:
774
+ speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config.
775
+ """
776
+ if speculative_decoding_config :
777
+ model_provider : str = speculative_decoding_config ["ModelProvider" ]
778
+
779
+ if model_provider .lower () == "sagemaker" :
780
+ if not self ._is_speculation_enabled (self .pysdk_model .deployment_config ):
781
+ deployment_config = self ._find_compatible_deployment_config (
782
+ speculative_decoding_config
783
+ )
784
+ if deployment_config :
785
+ self .pysdk_model .set_deployment_config (
786
+ config_name = deployment_config .get ("DeploymentConfigName" ),
787
+ instance_type = deployment_config .get ("InstanceType" ),
788
+ )
789
+ self .pysdk_model .add_tags (
790
+ {"key" : Tag .SPECULATIVE_DRAFT_MODL_PROVIDER , "value" : "sagemaker" },
791
+ )
792
+ else :
793
+ raise ValueError (
794
+ "Cannot find deployment config compatible for optimization job."
795
+ )
796
+ else :
797
+ s3_uri = speculative_decoding_config .get ("ModelSource" )
798
+ if not s3_uri :
799
+ raise ValueError ("Custom S3 Uri cannot be none." )
800
+
801
+ self .pysdk_model .additional_model_data_sources ["speculative_decoding" ][0 ][
802
+ "s3_data_source"
803
+ ]["s3_uri" ] = s3_uri
804
+ self .pysdk_model .add_tags (
805
+ {"key" : Tag .SPECULATIVE_DRAFT_MODL_PROVIDER , "value" : "customer" },
806
+ )
807
+
808
+ def _find_compatible_deployment_config (
809
+ self , speculative_decoding_config : Optional [Dict ] = None
810
+ ) -> Optional [Dict [str , Any ]]:
811
+ """Finds compatible model deployment config for optimization job.
812
+
813
+ Args:
814
+ speculative_decoding_config (Optional[Dict]): Speculative decoding config.
815
+
816
+ Returns:
817
+ Optional[Dict[str, Any]]: A compatible model deployment config for optimization job.
818
+ """
819
+ for deployment_config in self .pysdk_model .list_deployment_configs ():
820
+ instance_type = deployment_config .get ("deployment_config" ).get ("InstanceType" )
821
+ image_uri = deployment_config .get ("deployment_config" ).get ("ImageUri" )
822
+
823
+ if _is_compatible_with_optimization_job (instance_type , image_uri ):
824
+ if not speculative_decoding_config :
825
+ return deployment_config
826
+
827
+ if self ._is_speculation_enabled (deployment_config ):
828
+ return deployment_config
829
+
830
+ return None
831
+
832
+ def _is_speculation_enabled (self , deployment_config : Optional [Dict [str , Any ]]) -> bool :
833
+ """Checks whether speculative is enabled for the given deployment config.
834
+
835
+ Args:
836
+ deployment_config (Dict[str, Any]): A deployment config.
837
+
838
+ Returns:
839
+ bool: Whether speculative is enabled for this deployment config.
840
+ """
841
+ if deployment_config is None :
842
+ return False
843
+
844
+ acceleration_configs = deployment_config .get ("AccelerationConfigs" )
845
+ if acceleration_configs :
846
+ for acceleration_config in acceleration_configs :
847
+ if acceleration_config .get (
848
+ "type" , "default"
849
+ ).lower () == "speculative" and acceleration_config .get ("enabled" ):
850
+ return True
851
+ return False
0 commit comments