Skip to content

Commit 111b9cb

Browse files
committed
Partition support for DJLModel using SM Training job
1 parent 46215ae commit 111b9cb

File tree

1 file changed

+208
-22
lines changed

1 file changed

+208
-22
lines changed

src/sagemaker/djl_inference/model.py

+208-22
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sagemaker.session import Session
3232
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
3333
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.estimator import Estimator
3435

3536
logger = logging.getLogger("sagemaker")
3637

@@ -180,6 +181,47 @@ def _get_model_config_properties_from_hf(model_id: str):
180181
return model_config
181182

182183

184+
def _create_estimator(
    instance_type: str,
    s3_output_uri: str,
    image_uri: str,
    role: str,
    sagemaker_session: Optional[Session],
    volume_size: int = 30,
    vpc_config: Optional[Dict[str, List[str]]] = None,
    volume_kms_key=None,
    output_kms_key=None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Creates the single-instance Estimator used to run the partitioning job.

    Args:
        instance_type (str): EC2 instance type for the training job.
        s3_output_uri (str): S3 location where the job writes its output.
        image_uri (str): Container image the training job runs.
        role (str): IAM role ARN assumed by the training job.
        sagemaker_session (Session): Session that manages AWS interactions.
        volume_size (int): EBS volume size in GB (default: 30).
        vpc_config (dict): Optional VpcConfig containing ``"Subnets"`` and
            ``"SecurityGroupIds"`` lists (default: None).
        volume_kms_key: KMS key ID for encrypting the EBS volume (default: None).
        output_kms_key: KMS key ID for encrypting job output (default: None).
        use_spot_instances (bool): Whether to use managed spot instances
            (default: False).
        max_wait (int): Timeout in seconds waiting for a spot training job
            (default: None).
        enable_network_isolation (bool): Whether the container runs in network
            isolation mode (default: False).

    Returns:
        Estimator: configured for one instance of ``instance_type``.
    """
    subnets = None
    security_group_ids = None
    if vpc_config:
        subnets = vpc_config.get("Subnets")
        # Bug fix: this lookup was guarded by `if security_group_ids:` while
        # security_group_ids was still None, so SecurityGroupIds from the
        # model's vpc_config were never passed to the training job.
        security_group_ids = vpc_config.get("SecurityGroupIds")

    return Estimator(
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type=instance_type,
        volume_size=volume_size,
        volume_kms_key=volume_kms_key,
        output_path=s3_output_uri,
        output_kms_key=output_kms_key,
        sagemaker_session=sagemaker_session,
        subnets=subnets,
        security_group_ids=security_group_ids,
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )
223+
224+
183225
class DJLModel(FrameworkModel):
184226
"""A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
185227

@@ -381,6 +423,91 @@ def right_size(self, checkpoint_data_type: str):
381423
"DJLModels do not currently support Inference Recommendation Jobs"
382424
)
383425

426+
def partition(
    self,
    instance_type: str,
    s3_output_uri: Optional[str] = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Partitions the model using SageMaker Training Job.

    This is a synchronous API call.

    Args:
        instance_type (str): The EC2 instance type to partition this Model.
            For example, 'ml.p4d.24xlarge'.
        s3_output_uri (str): S3 location for saving the training result (model
            artifacts and output files). If not specified, results are
            stored to a default bucket. If the bucket with the specific name
            does not exist, it will be created.
        job_name (str): Training job name. If not specified, a unique training job
            name will be created.
        volume_kms_key (str): Optional. KMS key ID for encrypting EBS
            volume attached to the training instance (default: None).
        output_kms_key (str): Optional. KMS key ID for encrypting the
            training output (default: None).
        use_spot_instances (bool): Specifies whether to use SageMaker
            Managed Spot instances for training. If enabled then the
            ``max_wait`` arg should also be set.

            More information:
            https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
            (default: ``False``).
        max_wait (int): Timeout in seconds waiting for spot training
            job (default: None). After this amount of time Amazon
            SageMaker will stop waiting for managed spot training job to
            complete (default: None).
        enable_network_isolation (bool): Specifies whether container will
            run in network isolation mode (default: ``False``). Network
            isolation mode restricts the container access to outside networks
            (such as the Internet). The container does not make any inbound or
            outbound network calls. Also known as Internet-free mode.

    Returns:
        None
    """
    deploy_key_prefix = fw_utils.model_code_key_prefix(
        self.key_prefix, self.name, self.image_uri
    )
    if s3_output_uri is None:
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        # Bug fix: the default-bucket path must carry the s3:// scheme just
        # like the user-supplied path, otherwise downstream S3 consumers
        # receive a bare "bucket/prefix" string that is not a valid S3 URI.
        s3_output_uri = f"s3://{bucket}/{deploy_key_prefix}"
    else:
        s3_output_uri = f"{s3_output_uri}/{deploy_key_prefix}"

    # Signals generate_serving_properties() to emit
    # option.save_mp_checkpoint_path so the container writes partitioned
    # checkpoints under <output>/aot.
    self.save_mp_checkpoint_path = f"{s3_output_uri}/aot"

    # Upload the un-tarred model artifacts; the returned container definition
    # carries the S3 location used as the training job's input channel.
    container_def = self._upload_model_to_s3(upload_as_tar=False)
    estimator = _create_estimator(
        instance_type=instance_type,
        s3_output_uri=s3_output_uri,
        image_uri=self.image_uri,
        role=self.role,
        sagemaker_session=self.sagemaker_session,
        vpc_config=self.vpc_config,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )

    # Creates a training job to do partitions; blocks until it completes.
    estimator.fit(
        inputs=container_def["ModelDataUrl"],
        wait=True,
        logs="All",
        job_name=job_name,
        experiment_config=None,
    )

    # Serve from the partitioned checkpoints produced by the job.
    self.model_id = self.save_mp_checkpoint_path
    # Reset save_mp_checkpoint_path since partition is completed.
    self.save_mp_checkpoint_path = None
510+
384511
def deploy(
385512
self,
386513
instance_type,
@@ -477,18 +604,10 @@ def deploy(
477604
container_startup_health_check_timeout=container_startup_health_check_timeout,
478605
)
479606

480-
def prepare_container_def(
481-
self,
482-
instance_type=None,
483-
accelerator_type=None,
484-
serverless_inference_config=None,
485-
): # pylint: disable=unused-argument
486-
"""A container definition with framework configuration set in model environment variables.
487-
488-
Returns:
489-
dict[str, str]: A container definition object usable with the
490-
CreateModel API.
491-
"""
607+
def _upload_model_to_s3(self,
608+
upload_as_tar: bool = True
609+
):
610+
"""Placeholder docstring"""
492611

493612
if not self.image_uri:
494613
region_name = self.sagemaker_session.boto_session.region_name
@@ -528,19 +647,42 @@ def prepare_container_def(
528647
self.key_prefix, self.name, self.image_uri
529648
)
530649
bucket = self.bucket or self.sagemaker_session.default_bucket()
531-
uploaded_code = fw_utils.tar_and_upload_dir(
532-
self.sagemaker_session.boto_session,
533-
bucket,
534-
deploy_key_prefix,
535-
self.entry_point,
536-
directory=tmp_code_dir,
537-
dependencies=self.dependencies,
538-
kms_key=self.model_kms_key,
539-
)
650+
if upload_as_tar:
651+
uploaded_code = fw_utils.tar_and_upload_dir(
652+
self.sagemaker_session.boto_session,
653+
bucket,
654+
deploy_key_prefix,
655+
self.entry_point,
656+
directory=tmp_code_dir,
657+
dependencies=self.dependencies,
658+
kms_key=self.model_kms_key,
659+
)
660+
model_data_url = uploaded_code.s3_prefix
661+
else:
662+
from sagemaker.s3 import S3Uploader
663+
model_data_url = S3Uploader.upload(tmp_code_dir,
664+
"s3://%s/%s" % (bucket, key),
665+
self.model_kms_key,
666+
self.sagemaker_session)
540667
return sagemaker.container_def(
541-
self.image_uri, model_data_url=uploaded_code.s3_prefix, env=environment
668+
self.image_uri, model_data_url=model_data_url, env=environment
542669
)
543670

671+
def prepare_container_def(
    self,
    instance_type=None,
    accelerator_type=None,
    serverless_inference_config=None,
):  # pylint: disable=unused-argument
    """Builds the container definition for this model.

    The model artifacts are tarred and uploaded to S3, and the framework
    configuration is set in the model environment variables.

    Returns:
        dict[str, str]: A container definition object usable with the
            CreateModel API.
    """
    # Delegates entirely to the shared upload helper; the unused arguments
    # exist only to satisfy the base-class signature.
    return self._upload_model_to_s3(upload_as_tar=True)
685+
544686
def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]:
545687
"""Generates the DJL Serving configuration to use for the model.
546688
@@ -699,6 +841,8 @@ def __init__(
699841
self.enable_cuda_graph = enable_cuda_graph
700842
self.triangular_masking = triangular_masking
701843
self.return_tuple = return_tuple
844+
self.save_mp_checkpoint_path = None
845+
self.checkpoint = None
702846

703847
def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]:
704848
"""Generates the DJL Serving configuration to use for the model.
@@ -733,9 +877,35 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]
733877
serving_properties["option.triangular_masking"] = self.triangular_masking
734878
if self.return_tuple:
735879
serving_properties["option.return_tuple"] = self.return_tuple
880+
if self.save_mp_checkpoint_path:
881+
serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path
882+
if self.checkpoint:
883+
serving_properties["option.checkpoint"] = self.checkpoint
736884

737885
return serving_properties
738886

887+
def partition(
    self,
    instance_type: str,
    s3_output_uri: str,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Partitions the DeepSpeed model using a SageMaker Training Job.

    This is a synchronous API call. See ``DJLModel.partition`` for the full
    description of the arguments.

    Returns:
        None
    """
    super(DeepSpeedModel, self).partition(
        instance_type,
        s3_output_uri,
        job_name,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )

    # DeepSpeed AOT partitioning emits this config file alongside the
    # checkpoints; recording it makes generate_serving_properties() emit
    # option.checkpoint so DJL Serving loads the partitioned model.
    self.checkpoint = "ds_inference_config.json"
908+
739909

740910
class HuggingFaceAccelerateModel(DJLModel):
741911
"""A DJL Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
@@ -846,3 +1016,19 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
8461016
serving_properties["option.dtype"] = "auto"
8471017
serving_properties.pop("option.load_in_8bit", None)
8481018
return serving_properties
1019+
1020+
def partition(
    self,
    instance_type: str,
    s3_output_uri: str,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: int = None,
    enable_network_isolation: bool = False,
):
    """Not supported by the HuggingFace Accelerate engine.

    Ahead-of-time partitioning requires tensor parallelism, which this
    engine does not provide, so this override always fails.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "HuggingFace engine does not currently support tensor parallelism. "
        "Hence ahead of partitioning cannot be done"
    )

0 commit comments

Comments
 (0)