42
42
_update_environment_variables ,
43
43
_extract_speculative_draft_model_provider ,
44
44
_is_image_compatible_with_optimization_job ,
45
- _extracts_and_validates_speculative_model_source ,
46
45
_generate_channel_name ,
47
- _generate_additional_model_data_sources ,
48
- _is_s3_uri ,
46
+ _extract_optimization_config_and_env ,
47
+ _is_optimized ,
48
+ _custom_speculative_decoding ,
49
+ SPECULATIVE_DRAFT_MODEL ,
49
50
)
50
51
from sagemaker .serve .utils .predictors import (
51
52
DjlLocalModePredictor ,
@@ -121,7 +122,7 @@ def __init__(self):
121
122
self .speculative_decoding_draft_model_source = None
122
123
123
124
@abstractmethod
124
- def _prepare_for_mode (self ):
125
+ def _prepare_for_mode (self , ** kwargs ):
125
126
"""Placeholder docstring"""
126
127
127
128
@abstractmethod
@@ -130,6 +131,9 @@ def _get_client_translators(self):
130
131
131
132
def _is_jumpstart_model_id (self ) -> bool :
132
133
"""Placeholder docstring"""
134
+ if self .model is None :
135
+ return False
136
+
133
137
try :
134
138
model_uris .retrieve (model_id = self .model , model_version = "*" , model_scope = _JS_SCOPE )
135
139
except KeyError :
@@ -141,8 +145,9 @@ def _is_jumpstart_model_id(self) -> bool:
141
145
142
146
def _create_pre_trained_js_model (self ) -> Type [Model ]:
143
147
"""Placeholder docstring"""
144
- pysdk_model = JumpStartModel (self .model , vpc_config = self .vpc_config )
145
- pysdk_model .sagemaker_session = self .sagemaker_session
148
+ pysdk_model = JumpStartModel (
149
+ self .model , vpc_config = self .vpc_config , sagemaker_session = self .sagemaker_session
150
+ )
146
151
147
152
self ._original_deploy = pysdk_model .deploy
148
153
pysdk_model .deploy = self ._js_builder_deploy_wrapper
@@ -151,6 +156,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]:
151
156
@_capture_telemetry ("jumpstart.deploy" )
152
157
def _js_builder_deploy_wrapper (self , * args , ** kwargs ) -> Type [PredictorBase ]:
153
158
"""Placeholder docstring"""
159
+ env = {}
154
160
if "mode" in kwargs and kwargs .get ("mode" ) != self .mode :
155
161
overwrite_mode = kwargs .get ("mode" )
156
162
# mode overwritten by customer during model.deploy()
@@ -167,7 +173,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
167
173
or not hasattr (self , "prepared_for_tgi" )
168
174
or not hasattr (self , "prepared_for_mms" )
169
175
):
170
- self .pysdk_model .model_data , env = self ._prepare_for_mode ()
176
+ if not _is_optimized (self .pysdk_model ):
177
+ self .pysdk_model .model_data , env = self ._prepare_for_mode ()
171
178
elif overwrite_mode == Mode .LOCAL_CONTAINER :
172
179
self .mode = self .pysdk_model .mode = Mode .LOCAL_CONTAINER
173
180
@@ -198,7 +205,6 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
198
205
)
199
206
200
207
self ._prepare_for_mode ()
201
- env = {}
202
208
else :
203
209
raise ValueError ("Mode %s is not supported!" % overwrite_mode )
204
210
@@ -726,25 +732,17 @@ def _optimize_for_jumpstart(
726
732
)
727
733
728
734
model_source = _generate_model_source (self .pysdk_model .model_data , accept_eula )
729
-
730
- optimization_config = {}
731
- if quantization_config :
732
- optimization_config ["ModelQuantizationConfig" ] = quantization_config
733
- pysdk_model_env_vars = _update_environment_variables (
734
- pysdk_model_env_vars , quantization_config ["OverrideEnvironment" ]
735
- )
736
- if compilation_config :
737
- optimization_config ["ModelCompilationConfig" ] = compilation_config
738
- pysdk_model_env_vars = _update_environment_variables (
739
- pysdk_model_env_vars , compilation_config ["OverrideEnvironment" ]
740
- )
735
+ optimization_config , env = _extract_optimization_config_and_env (
736
+ quantization_config , compilation_config
737
+ )
738
+ pysdk_model_env_vars = _update_environment_variables (pysdk_model_env_vars , env )
741
739
742
740
output_config = {"S3OutputLocation" : output_path }
743
741
if kms_key :
744
742
output_config ["KmsKeyId" ] = kms_key
745
743
if not instance_type :
746
- instance_type = self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get (
747
- "InstanceType"
744
+ instance_type = self .pysdk_model .deployment_config .get ("DeploymentArgs" , {} ).get (
745
+ "InstanceType" , _get_nb_instance ()
748
746
)
749
747
750
748
create_optimization_job_args = {
@@ -771,6 +769,10 @@ def _optimize_for_jumpstart(
771
769
self .pysdk_model .env .update (pysdk_model_env_vars )
772
770
if accept_eula :
773
771
self .pysdk_model .accept_eula = accept_eula
772
+ if isinstance (self .pysdk_model .model_data , dict ):
773
+ self .pysdk_model .model_data ["S3DataSource" ]["ModelAccessConfig" ] = {
774
+ "AcceptEula" : True
775
+ }
774
776
775
777
if quantization_config or compilation_config :
776
778
return create_optimization_job_args
@@ -806,7 +808,6 @@ def _set_additional_model_source(
806
808
if speculative_decoding_config :
807
809
model_provider = _extract_speculative_draft_model_provider (speculative_decoding_config )
808
810
channel_name = _generate_channel_name (self .pysdk_model .additional_model_data_sources )
809
- speculative_draft_model = f"/opt/ml/additional-model-data-sources/{ channel_name } "
810
811
811
812
if model_provider == "sagemaker" :
812
813
additional_model_data_sources = self .pysdk_model .deployment_config .get (
@@ -825,32 +826,18 @@ def _set_additional_model_source(
825
826
raise ValueError (
826
827
"Cannot find deployment config compatible for optimization job."
827
828
)
829
+
830
+ self .pysdk_model .env .update (
831
+ {"OPTION_SPECULATIVE_DRAFT_MODEL" : f"{ SPECULATIVE_DRAFT_MODEL } /{ channel_name } " }
832
+ )
833
+ self .pysdk_model .add_tags (
834
+ {"Key" : Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER , "Value" : "sagemaker" },
835
+ )
828
836
else :
829
- model_source = _extracts_and_validates_speculative_model_source (
830
- speculative_decoding_config
837
+ self . pysdk_model = _custom_speculative_decoding (
838
+ self . pysdk_model , speculative_decoding_config , accept_eula
831
839
)
832
840
833
- if _is_s3_uri (model_source ):
834
- self .pysdk_model .additional_model_data_sources = (
835
- _generate_additional_model_data_sources (
836
- model_source , channel_name , accept_eula
837
- )
838
- )
839
- else :
840
- speculative_draft_model = model_source
841
-
842
- self .pysdk_model .env = _update_environment_variables (
843
- self .pysdk_model .env ,
844
- {"OPTION_SPECULATIVE_DRAFT_MODEL" : speculative_draft_model },
845
- )
846
- self .pysdk_model .add_tags (
847
- {"Key" : Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER , "Value" : model_provider },
848
- )
849
- if accept_eula and isinstance (self .pysdk_model .model_data , dict ):
850
- self .pysdk_model .model_data ["S3DataSource" ]["ModelAccessConfig" ] = {
851
- "AcceptEula" : True
852
- }
853
-
854
841
def _find_compatible_deployment_config (
855
842
self , speculative_decoding_config : Optional [Dict ] = None
856
843
) -> Optional [Dict [str , Any ]]:
0 commit comments