Skip to content

Commit c64e307

Browse files
committed
feature: Partition support for DJLModel using SM Training job
1 parent 46215ae commit c64e307

File tree

2 files changed

+283
-23
lines changed

2 files changed

+283
-23
lines changed

src/sagemaker/djl_inference/model.py

+209-23
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from json import JSONDecodeError
2121
from urllib.error import HTTPError, URLError
2222
from enum import Enum
23-
from typing import Optional, Union, Dict, Any
23+
from typing import Optional, Union, Dict, Any, List
2424

2525
import sagemaker
2626
from sagemaker import s3, Predictor, image_uris, fw_utils
@@ -31,6 +31,8 @@
3131
from sagemaker.session import Session
3232
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
3333
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.estimator import Estimator
35+
from sagemaker.s3 import S3Uploader
3436

3537
logger = logging.getLogger("sagemaker")
3638

@@ -180,6 +182,47 @@ def _get_model_config_properties_from_hf(model_id: str):
180182
return model_config
181183

182184

185+
def _create_estimator(instance_type: str,
                      s3_output_uri: str,
                      image_uri: str,
                      role: str,
                      sagemaker_session: Optional[Session],
                      volume_size: int = 30,
                      vpc_config: Optional[Dict[str, List[str]]] = None,
                      volume_kms_key: Optional[str] = None,
                      output_kms_key: Optional[str] = None,
                      use_spot_instances: bool = False,
                      max_wait: Optional[int] = None,
                      enable_network_isolation: bool = False,
                      ):
    """Create the single-instance SageMaker ``Estimator`` used to run a partitioning job.

    Args:
        instance_type (str): EC2 instance type to run the training (partitioning) job on.
        s3_output_uri (str): S3 location where the training job writes its output.
        image_uri (str): Container image to run; the same DJL inference image performs
            the ahead-of-time partitioning when launched as a training job.
        role (str): IAM role ARN the training job assumes.
        sagemaker_session (Session): Session that manages interactions with SageMaker.
        volume_size (int): EBS volume size in GB attached to the training instance
            (default: 30).
        vpc_config (dict): Optional VPC configuration with ``"Subnets"`` and
            ``"SecurityGroupIds"`` keys, as stored on a ``Model`` (default: None).
        volume_kms_key (str): Optional KMS key ID for encrypting the EBS volume.
        output_kms_key (str): Optional KMS key ID for encrypting the job output.
        use_spot_instances (bool): Whether to use managed spot instances
            (default: False). ``max_wait`` should be set when enabled.
        max_wait (int): Timeout in seconds to wait for a spot training job
            (default: None).
        enable_network_isolation (bool): Whether the container runs in network
            isolation mode (default: False).

    Returns:
        sagemaker.estimator.Estimator: Configured single-instance estimator.
    """

    subnets = None
    security_group_ids = None
    if vpc_config:
        subnets = vpc_config.get("Subnets")
        # BUGFIX: this was previously gated on `if security_group_ids:` — a variable
        # just initialized to None — so security groups were never read from
        # vpc_config. Both values come from the same vpc_config dict.
        security_group_ids = vpc_config.get("SecurityGroupIds")

    return Estimator(
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type=instance_type,
        volume_size=volume_size,
        volume_kms_key=volume_kms_key,
        output_path=s3_output_uri,
        output_kms_key=output_kms_key,
        sagemaker_session=sagemaker_session,
        subnets=subnets,
        security_group_ids=security_group_ids,
        use_spot_instances=use_spot_instances,
        max_wait=max_wait,
        enable_network_isolation=enable_network_isolation,
    )
224+
225+
183226
class DJLModel(FrameworkModel):
184227
"""A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
185228

@@ -381,6 +424,91 @@ def right_size(self, checkpoint_data_type: str):
381424
"DJLModels do not currently support Inference Recommendation Jobs"
382425
)
383426

427+
def partition(
    self,
    instance_type: str,
    s3_output_uri: Optional[str] = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: Optional[int] = None,
    enable_network_isolation: bool = False
):
    """Partitions the model using SageMaker Training Job.
    This is a synchronous API call.

    Args:
        instance_type (str): The EC2 instance type to partition this Model.
            For example, 'ml.p4d.24xlarge'.
        s3_output_uri (str): S3 location for saving the training result (model
            artifacts and output files). If not specified, results are
            stored to a default bucket. If the bucket with the specific name
            does not exist, it will be created.
        job_name (str): Training job name. If not specified, a unique training job
            name will be created.
        volume_kms_key (str): Optional. KMS key ID for encrypting EBS
            volume attached to the training instance (default: None).
        output_kms_key (str): Optional. KMS key ID for encrypting the
            training output (default: None).
        use_spot_instances (bool): Specifies whether to use SageMaker
            Managed Spot instances for training. If enabled then the
            ``max_wait`` arg should also be set.

            More information:
            https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
            (default: ``False``).
        max_wait (int): Timeout in seconds waiting for spot training
            job (default: None). After this amount of time Amazon
            SageMaker will stop waiting for managed spot training job to
            complete (default: None).
        enable_network_isolation (bool): Specifies whether container will
            run in network isolation mode (default: ``False``). Network
            isolation mode restricts the container access to outside networks
            (such as the Internet). The container does not make any inbound or
            outbound network calls. Also known as Internet-free mode.
    Returns:
        None
    """

    # Prefix under which model artifacts and partition output are keyed in S3.
    deploy_key_prefix = fw_utils.model_code_key_prefix(
        self.key_prefix, self.name, self.image_uri
    )
    if s3_output_uri is None:
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        s3_output_uri = f"s3://{bucket}/{deploy_key_prefix}"
    else:
        s3_output_uri = f"{s3_output_uri}/{deploy_key_prefix}"

    # Set before uploading: generate_serving_properties emits this as
    # option.save_mp_checkpoint_path so the container writes the partitioned
    # checkpoints to this S3 location.
    self.save_mp_checkpoint_path = f"{s3_output_uri}/aot"

    # Upload model code/config uncompressed; the returned container definition's
    # ModelDataUrl becomes the training job's input channel.
    container_def = self._upload_model_to_s3(upload_as_tar=False)
    estimator = _create_estimator(instance_type=instance_type,
                                  s3_output_uri=s3_output_uri,
                                  image_uri=self.image_uri,
                                  role=self.role,
                                  sagemaker_session=self.sagemaker_session,
                                  vpc_config=self.vpc_config,
                                  volume_kms_key=volume_kms_key,
                                  output_kms_key=output_kms_key,
                                  use_spot_instances=use_spot_instances,
                                  max_wait=max_wait,
                                  enable_network_isolation=enable_network_isolation
                                  )

    # creates a training job to do partitions (wait=True makes this synchronous)
    estimator.fit(
        inputs=container_def["ModelDataUrl"],
        wait=True,
        logs="All",
        job_name=job_name,
        experiment_config=None,
    )

    # Point the model at the partitioned artifacts so deploy() serves them.
    self.model_id = self.save_mp_checkpoint_path
    # reset save_mp_checkpoint_path since partition is completed.
    self.save_mp_checkpoint_path = None
511+
384512
def deploy(
385513
self,
386514
instance_type,
@@ -477,18 +605,10 @@ def deploy(
477605
container_startup_health_check_timeout=container_startup_health_check_timeout,
478606
)
479607

480-
def prepare_container_def(
481-
self,
482-
instance_type=None,
483-
accelerator_type=None,
484-
serverless_inference_config=None,
485-
): # pylint: disable=unused-argument
486-
"""A container definition with framework configuration set in model environment variables.
487-
488-
Returns:
489-
dict[str, str]: A container definition object usable with the
490-
CreateModel API.
491-
"""
608+
def _upload_model_to_s3(self,
609+
upload_as_tar: bool = True
610+
):
611+
"""Placeholder docstring"""
492612

493613
if not self.image_uri:
494614
region_name = self.sagemaker_session.boto_session.region_name
@@ -528,19 +648,41 @@ def prepare_container_def(
528648
self.key_prefix, self.name, self.image_uri
529649
)
530650
bucket = self.bucket or self.sagemaker_session.default_bucket()
531-
uploaded_code = fw_utils.tar_and_upload_dir(
532-
self.sagemaker_session.boto_session,
533-
bucket,
534-
deploy_key_prefix,
535-
self.entry_point,
536-
directory=tmp_code_dir,
537-
dependencies=self.dependencies,
538-
kms_key=self.model_kms_key,
539-
)
651+
if upload_as_tar:
652+
uploaded_code = fw_utils.tar_and_upload_dir(
653+
self.sagemaker_session.boto_session,
654+
bucket,
655+
deploy_key_prefix,
656+
self.entry_point,
657+
directory=tmp_code_dir,
658+
dependencies=self.dependencies,
659+
kms_key=self.model_kms_key,
660+
)
661+
model_data_url = uploaded_code.s3_prefix
662+
else:
663+
model_data_url = S3Uploader.upload(tmp_code_dir,
664+
"s3://%s/%s" % (bucket, key),
665+
self.model_kms_key,
666+
self.sagemaker_session)
540667
return sagemaker.container_def(
541-
self.image_uri, model_data_url=uploaded_code.s3_prefix, env=environment
668+
self.image_uri, model_data_url=model_data_url, env=environment
542669
)
543670

671+
def prepare_container_def(
    self,
    instance_type=None,
    accelerator_type=None,
    serverless_inference_config=None,
):  # pylint: disable=unused-argument
    """A container definition with framework configuration set in model environment variables.

    The three parameters are accepted for interface compatibility with
    ``Model.prepare_container_def`` but are not used here.

    Returns:
        dict[str, str]: A container definition object usable with the
            CreateModel API.
    """

    # Delegates to the shared upload helper; deployment uploads the model
    # artifacts as a tarball (partition() uses upload_as_tar=False instead).
    return self._upload_model_to_s3(upload_as_tar=True)
685+
544686
def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]:
545687
"""Generates the DJL Serving configuration to use for the model.
546688
@@ -699,6 +841,8 @@ def __init__(
699841
self.enable_cuda_graph = enable_cuda_graph
700842
self.triangular_masking = triangular_masking
701843
self.return_tuple = return_tuple
844+
self.save_mp_checkpoint_path = None
845+
self.checkpoint = None
702846

703847
def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]:
704848
"""Generates the DJL Serving configuration to use for the model.
@@ -733,9 +877,35 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]
733877
serving_properties["option.triangular_masking"] = self.triangular_masking
734878
if self.return_tuple:
735879
serving_properties["option.return_tuple"] = self.return_tuple
880+
if self.save_mp_checkpoint_path:
881+
serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path
882+
if self.checkpoint:
883+
serving_properties["option.checkpoint"] = self.checkpoint
736884

737885
return serving_properties
738886

887+
def partition(
    self,
    instance_type: str,
    s3_output_uri: Optional[str] = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: Optional[int] = None,
    enable_network_isolation: bool = False
):
    """Partitions the model using a SageMaker Training Job, then records the
    DeepSpeed checkpoint descriptor produced by partitioning.

    This is a synchronous API call; see ``DJLModel.partition`` for the full
    description of every argument.

    Args:
        instance_type (str): The EC2 instance type to partition this Model.
            For example, 'ml.p4d.24xlarge'.
        s3_output_uri (str): S3 location for saving the training result. If not
            specified, results are stored to a default bucket.
        job_name (str): Training job name. If not specified, a unique name is
            generated.
        volume_kms_key (str): Optional KMS key ID for encrypting the EBS volume
            (default: None).
        output_kms_key (str): Optional KMS key ID for encrypting the training
            output (default: None).
        use_spot_instances (bool): Whether to use managed spot instances
            (default: ``False``). Set ``max_wait`` when enabled.
        max_wait (int): Timeout in seconds waiting for a spot training job
            (default: None).
        enable_network_isolation (bool): Whether the container runs in network
            isolation mode (default: ``False``).

    Returns:
        None
    """
    # Python 3 zero-argument super(); behavior identical to
    # super(DeepSpeedModel, self) but not tied to the class name.
    super().partition(instance_type,
                      s3_output_uri,
                      job_name,
                      volume_kms_key=volume_kms_key,
                      output_kms_key=output_kms_key,
                      use_spot_instances=use_spot_instances,
                      max_wait=max_wait,
                      enable_network_isolation=enable_network_isolation)

    # The partitioning job writes this DeepSpeed inference config alongside the
    # sharded checkpoints; generate_serving_properties emits it as
    # option.checkpoint so serving can load the partitioned model.
    self.checkpoint = "ds_inference_config.json"
908+
739909

740910
class HuggingFaceAccelerateModel(DJLModel):
741911
"""A DJL Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
@@ -846,3 +1016,19 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
8461016
serving_properties["option.dtype"] = "auto"
8471017
serving_properties.pop("option.load_in_8bit", None)
8481018
return serving_properties
1019+
1020+
def partition(
    self,
    instance_type: str,
    s3_output_uri: Optional[str] = None,
    job_name: Optional[str] = None,
    volume_kms_key: Optional[str] = None,
    output_kms_key: Optional[str] = None,
    use_spot_instances: bool = False,
    max_wait: Optional[int] = None,
    enable_network_isolation: bool = False
):
    """Not supported by the HuggingFace Accelerate engine.

    The signature mirrors ``DJLModel.partition`` for interface compatibility;
    all arguments are ignored.

    Raises:
        NotImplementedError: Always. The HuggingFace engine does not support
            tensor parallelism, so ahead-of-time partitioning is not possible.
    """
    # BUGFIX: corrected the garbled error message
    # ("Hence ahead of partitioning cannot be done").
    raise NotImplementedError(
        "HuggingFace engine does not currently support tensor parallelism. "
        "Hence ahead-of-time partitioning cannot be done"
    )

tests/unit/test_djl_inference.py

+74
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616

1717
import json
18+
import time
1819
from json import JSONDecodeError
1920

2021
import pytest
@@ -490,6 +491,7 @@ def test_deploy_model_no_local_code(
490491
mock_path_exists,
491492
mock_mkdir,
492493
mock_tar_upload,
494+
mock_upload,
493495
mock_create_code_dir,
494496
mock_tmpdir,
495497
mock_container_def,
@@ -534,3 +536,75 @@ def test_deploy_model_no_local_code(
534536
mock_container_def.assert_called_once_with(
535537
IMAGE_URI, model_data_url="s3prefix", env=expected_env
536538
)
539+
540+
541+
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
542+
@patch("shutil.rmtree")
543+
@patch("sagemaker.utils.base_name_from_image")
544+
@patch("tempfile.mkdtemp")
545+
@patch("sagemaker.container_def")
546+
@patch("sagemaker.utils._tmpdir")
547+
@patch("sagemaker.utils._create_or_update_code_dir")
548+
@patch("os.mkdir")
549+
@patch("os.path.exists")
550+
@patch("sagemaker.s3.S3Downloader.read_file")
551+
@patch("sagemaker.s3.S3Downloader.list")
552+
@patch("sagemaker.s3.S3Uploader.upload")
553+
@patch("sagemaker.estimator.Estimator.fit")
554+
@patch("sagemaker.fw_utils.model_code_key_prefix")
555+
def test_partition(
556+
mock_model_key_prefix,
557+
mock_estimator_fit,
558+
mock_upload,
559+
mock_s3_list,
560+
mock_read_file,
561+
mock_path_exists,
562+
mock_mkdir,
563+
mock_create_code_dir,
564+
mock_tmpdir,
565+
mock_container_def,
566+
mock_mktmp,
567+
mock_name_from_base,
568+
mock_shutil_rmtree,
569+
mock_imguri_retrieve,
570+
sagemaker_session,
571+
):
572+
mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"]
573+
model_config = {
574+
"model_type": "bloom",
575+
"n_heads": 120,
576+
}
577+
mock_read_file.return_value = json.dumps(model_config)
578+
model = DJLModel(
579+
VALID_UNCOMPRESSED_MODEL_DATA,
580+
ROLE,
581+
sagemaker_session=sagemaker_session,
582+
number_of_partitions=4,
583+
data_type="fp16",
584+
container_log_level=logging.DEBUG,
585+
env=ENV,
586+
)
587+
588+
589+
assert model.image_uri is None
590+
591+
mock_path_exists.side_effect = [True, False, True]
592+
mock_mktmp.return_value = "/tmp/dir"
593+
expected_env = {"ENV_VAR": "env_value", "SERVING_OPTS": '"-Dai.djl.logging.level=debug"'}
594+
mock_upload.return_value = "s3prefix"
595+
596+
s3_output_uri = f's3://{BUCKET}/partitions/'
597+
mock_model_key_prefix.return_value = "s3prefix"
598+
with patch("builtins.open", mock_open()) as fake_serving_properties:
599+
model.partition(GPU_INSTANCE, s3_output_uri)
600+
601+
mock_mktmp.assert_called_once_with(prefix="tmp", suffix="", dir=None)
602+
mock_mkdir.assert_called()
603+
assert fake_serving_properties.call_count == 2
604+
fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "w+")
605+
fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "r")
606+
mock_container_def.assert_called_once_with(
607+
IMAGE_URI, model_data_url="s3prefix", env=expected_env
608+
)
609+
610+
assert model.model_id == f'{s3_output_uri}/s3prefix/aot'

0 commit comments

Comments
 (0)