36
36
_set_serve_properties ,
37
37
_get_admissible_tensor_parallel_degrees ,
38
38
_get_admissible_dtypes ,
39
+ _get_default_tensor_parallel_degree ,
40
+ )
41
+ from sagemaker .serve .utils .local_hardware import (
42
+ _get_nb_instance ,
43
+ _get_ram_usage_mb ,
44
+ _get_gpu_info ,
45
+ _get_gpu_info_fallback ,
39
46
)
40
- from sagemaker .serve .utils .local_hardware import _get_nb_instance , _get_ram_usage_mb
41
47
from sagemaker .serve .model_server .djl_serving .prepare import (
42
48
prepare_for_djl_serving ,
43
49
_create_dir_structure ,
@@ -164,13 +170,6 @@ def _create_djl_model(self) -> Type[Model]:
164
170
@_capture_telemetry ("djl.deploy" )
165
171
def _djl_model_builder_deploy_wrapper (self , * args , ** kwargs ) -> Type [PredictorBase ]:
166
172
"""Placeholder docstring"""
167
- prepare_for_djl_serving (
168
- model_path = self .model_path ,
169
- model = self .pysdk_model ,
170
- dependencies = self .dependencies ,
171
- overwrite_props_from_file = self .overwrite_props_from_file ,
172
- )
173
-
174
173
timeout = kwargs .get ("model_data_download_timeout" )
175
174
if timeout :
176
175
self .env_vars .update ({"MODEL_LOADING_TIMEOUT" : str (timeout )})
@@ -192,6 +191,34 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
192
191
else :
193
192
raise ValueError ("Mode %s is not supported!" % overwrite_mode )
194
193
194
+ manual_set_props = None
195
+ if self .mode == Mode .SAGEMAKER_ENDPOINT :
196
+ if self .nb_instance_type and "instance_type" not in kwargs :
197
+ kwargs .update ({"instance_type" : self .nb_instance_type })
198
+ elif not self .nb_instance_type and "instance_type" not in kwargs :
199
+ raise ValueError (
200
+ "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
201
+ )
202
+ else :
203
+ try :
204
+ tot_gpus = _get_gpu_info (kwargs .get ("instance_type" ), self .sagemaker_session )
205
+ except Exception : # pylint: disable=W0703
206
+ tot_gpus = _get_gpu_info_fallback (kwargs .get ("instance_type" ))
207
+ default_tensor_parallel_degree = _get_default_tensor_parallel_degree (
208
+ self .hf_model_config , tot_gpus
209
+ )
210
+ manual_set_props = {
211
+ "option.tensor_parallel_degree" : str (default_tensor_parallel_degree ) + "\n "
212
+ }
213
+
214
+ prepare_for_djl_serving (
215
+ model_path = self .model_path ,
216
+ model = self .pysdk_model ,
217
+ dependencies = self .dependencies ,
218
+ overwrite_props_from_file = self .overwrite_props_from_file ,
219
+ manual_set_props = manual_set_props ,
220
+ )
221
+
195
222
serializer = self .schema_builder .input_serializer
196
223
deserializer = self .schema_builder ._output_deserializer
197
224
if self .mode == Mode .LOCAL_CONTAINER :
@@ -237,8 +264,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
237
264
238
265
if "endpoint_logging" not in kwargs :
239
266
kwargs ["endpoint_logging" ] = True
240
- if self .nb_instance_type and "instance_type" not in kwargs :
241
- kwargs .update ({"instance_type" : self .nb_instance_type })
242
267
243
268
predictor = self ._original_deploy (* args , ** kwargs )
244
269
@@ -252,6 +277,7 @@ def _build_for_hf_djl(self):
252
277
"""Placeholder docstring"""
253
278
self .overwrite_props_from_file = True
254
279
self .nb_instance_type = _get_nb_instance ()
280
+
255
281
_create_dir_structure (self .model_path )
256
282
self .engine , self .hf_model_config = _auto_detect_engine (
257
283
self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
0 commit comments