@@ -300,6 +300,11 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
300
300
returns:
301
301
Tuned Model.
302
302
"""
303
+ if self .mode == Mode .SAGEMAKER_ENDPOINT :
304
+ logger .warning (
305
+ "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
306
+ )
307
+ return self .pysdk_model
303
308
304
309
num_shard_env_var_name = "SM_NUM_GPUS"
305
310
if "OPTION_TENSOR_PARALLEL_DEGREE" in self .pysdk_model .env .keys ():
@@ -468,58 +473,47 @@ def _build_for_jumpstart(self):
468
473
self .secret_key = None
469
474
self .jumpstart = True
470
475
471
- self .pysdk_model = self ._create_pre_trained_js_model ()
472
- self .pysdk_model .tune = lambda * args , ** kwargs : self ._default_tune ()
473
-
474
- logger .info (
475
- "JumpStart ID %s is packaged with Image URI: %s" , self .model , self .pysdk_model .image_uri
476
- )
477
-
478
- if self .mode != Mode .SAGEMAKER_ENDPOINT :
479
- if self ._is_gated_model (self .pysdk_model ):
480
- raise ValueError (
481
- "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
482
- )
483
-
484
- if "djl-inference" in self .pysdk_model .image_uri :
485
- logger .info ("Building for DJL JumpStart Model ID..." )
486
- self .model_server = ModelServer .DJL_SERVING
487
- self .image_uri = self .pysdk_model .image_uri
488
-
489
- self ._build_for_djl_jumpstart ()
490
-
491
- self .pysdk_model .tune = self .tune_for_djl_jumpstart
492
- elif "tgi-inference" in self .pysdk_model .image_uri :
493
- logger .info ("Building for TGI JumpStart Model ID..." )
494
- self .model_server = ModelServer .TGI
495
- self .image_uri = self .pysdk_model .image_uri
496
-
497
- self ._build_for_tgi_jumpstart ()
476
+ pysdk_model = self ._create_pre_trained_js_model ()
477
+ image_uri = pysdk_model .image_uri
498
478
499
- self .pysdk_model .tune = self .tune_for_tgi_jumpstart
500
- elif "huggingface-pytorch-inference:" in self .pysdk_model .image_uri :
501
- logger .info ("Building for MMS JumpStart Model ID..." )
502
- self .model_server = ModelServer .MMS
503
- self .image_uri = self .pysdk_model .image_uri
479
+ logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
504
480
505
- self ._build_for_mms_jumpstart ()
506
- else :
507
- raise ValueError (
508
- "JumpStart Model ID was not packaged "
509
- "with djl-inference, tgi-inference, or mms-inference container."
510
- )
511
-
512
- return self .pysdk_model
481
+ if self ._is_gated_model (pysdk_model ) and self .mode != Mode .SAGEMAKER_ENDPOINT :
482
+ raise ValueError (
483
+ "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
484
+ )
513
485
514
- def _default_tune (self ):
515
- """Logs a warning message if tune is invoked on endpoint mode.
486
+ if "djl-inference" in image_uri :
487
+ logger .info ("Building for DJL JumpStart Model ID..." )
488
+ self .model_server = ModelServer .DJL_SERVING
489
+ self .pysdk_model = pysdk_model
490
+ self .image_uri = self .pysdk_model .image_uri
491
+
492
+ self ._build_for_djl_jumpstart ()
493
+
494
+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
495
+ elif "tgi-inference" in image_uri :
496
+ logger .info ("Building for TGI JumpStart Model ID..." )
497
+ self .model_server = ModelServer .TGI
498
+ self .pysdk_model = pysdk_model
499
+ self .image_uri = self .pysdk_model .image_uri
500
+
501
+ self ._build_for_tgi_jumpstart ()
502
+
503
+ self .pysdk_model .tune = self .tune_for_tgi_jumpstart
504
+ elif "huggingface-pytorch-inference:" in image_uri :
505
+ logger .info ("Building for MMS JumpStart Model ID..." )
506
+ self .model_server = ModelServer .MMS
507
+ self .pysdk_model = pysdk_model
508
+ self .image_uri = self .pysdk_model .image_uri
509
+
510
+ self ._build_for_mms_jumpstart ()
511
+ elif self .mode != Mode .SAGEMAKER_ENDPOINT :
512
+ raise ValueError (
513
+ "JumpStart Model ID was not packaged "
514
+ "with djl-inference, tgi-inference, or mms-inference container."
515
+ )
516
516
517
- Returns:
518
- Jumpstart Model: ``This`` model
519
- """
520
- logger .warning (
521
- "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
522
- )
523
517
return self .pysdk_model
524
518
525
519
def _is_gated_model (self , model ) -> bool :
0 commit comments