Skip to content

Commit bc917e8

Browse files
committed
feat: Support deploying uncompressed ML model from S3 to SageMaker Hosting endpoints
1 parent 1d886c4 commit bc917e8

File tree

9 files changed

+370
-23
lines changed

9 files changed

+370
-23
lines changed

src/sagemaker/estimator.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,13 +1737,35 @@ def register(
17371737

17381738
@property
17391739
def model_data(self):
1740-
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1740+
"""Str or dict: The model location in S3. Only set if Estimator has been ``fit()``."""
17411741
if self.latest_training_job is not None and not isinstance(
17421742
self.sagemaker_session, PipelineSession
17431743
):
1744-
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
1744+
job_details = self.sagemaker_session.sagemaker_client.describe_training_job(
17451745
TrainingJobName=self.latest_training_job.name
1746-
)["ModelArtifacts"]["S3ModelArtifacts"]
1746+
)
1747+
model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
1748+
compression_type = job_details.get("OutputDataConfig", {}).get(
1749+
"CompressionType", "GZIP"
1750+
)
1751+
if compression_type == "GZIP":
1752+
return model_uri
1753+
# fail fast if we don't recognize training output compression type
1754+
if compression_type not in {"GZIP", "NONE"}:
1755+
raise ValueError(
1756+
f'Unrecognized training job output data compression type "{compression_type}"'
1757+
)
1758+
# model data is in uncompressed form. NOTE: SageMaker Hosting mandates presence of
1759+
# trailing forward slash in S3 model data URI, so append one if necessary.
1760+
if not model_uri.endswith("/"):
1761+
model_uri += "/"
1762+
return {
1763+
"S3DataSource": {
1764+
"S3Uri": model_uri,
1765+
"S3DataType": "S3Prefix",
1766+
"CompressionType": "None",
1767+
}
1768+
}
17471769
else:
17481770
logger.warning(
17491771
"No finished training job found associated with this estimator. Please make sure "
@@ -1752,7 +1774,7 @@ def model_data(self):
17521774
model_uri = os.path.join(
17531775
self.output_path, self._current_job_name, "output", "model.tar.gz"
17541776
)
1755-
return model_uri
1777+
return model_uri
17561778

17571779
@abstractmethod
17581780
def create_model(self, **kwargs):

src/sagemaker/model.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class Model(ModelBase, InferenceRecommenderMixin):
131131
def __init__(
132132
self,
133133
image_uri: Union[str, PipelineVariable],
134-
model_data: Optional[Union[str, PipelineVariable]] = None,
134+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
135135
role: Optional[str] = None,
136136
predictor_cls: Optional[callable] = None,
137137
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -152,8 +152,8 @@ def __init__(
152152
153153
Args:
154154
image_uri (str or PipelineVariable): A Docker image URI.
155-
model_data (str or PipelineVariable): The S3 location of a SageMaker
156-
model data ``.tar.gz`` file (default: None).
155+
model_data (str or PipelineVariable or dict): Location
156+
of SageMaker model data (default: None).
157157
role (str): An AWS IAM role (either name or full ARN). The Amazon
158158
SageMaker training jobs and APIs that create Amazon SageMaker
159159
endpoints use this role to access training data and model
@@ -455,6 +455,11 @@ def register(
455455
"""
456456
if self.model_data is None:
457457
raise ValueError("SageMaker Model Package cannot be created without model data.")
458+
if isinstance(self.model_data, dict):
459+
raise ValueError(
460+
"SageMaker Model Package currently cannot be created with ModelDataSource."
461+
)
462+
458463
if image_uri is not None:
459464
self.image_uri = image_uri
460465

@@ -600,6 +605,7 @@ def prepare_container_def(
600605
)
601606
self._upload_code(deploy_key_prefix, repack=is_repack)
602607
deploy_env.update(self._script_mode_env_vars())
608+
603609
return sagemaker.container_def(
604610
self.image_uri,
605611
self.repacked_model_data or self.model_data,
@@ -639,6 +645,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
639645
)
640646

641647
if repack and self.model_data is not None and self.entry_point is not None:
648+
if isinstance(self.model_data, dict):
649+
logging.warning("ModelDataSource currently doesn't support model repacking")
650+
return
642651
if is_pipeline_variable(self.model_data):
643652
# model is not yet there, defer repacking to later during pipeline execution
644653
if not isinstance(self.sagemaker_session, PipelineSession):
@@ -765,10 +774,16 @@ def _create_sagemaker_model(
765774
# _base_name, model_name are not needed under PipelineSession.
766775
# the model_data may be Pipeline variable
767776
# which may break the _base_name generation
777+
model_uri = None
778+
if isinstance(self.model_data, (str, PipelineVariable)):
779+
model_uri = self.model_data
780+
elif isinstance(self.model_data, dict):
781+
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
782+
768783
self._ensure_base_name_if_needed(
769784
image_uri=container_def["Image"],
770785
script_uri=self.source_dir,
771-
model_uri=self.model_data,
786+
model_uri=model_uri,
772787
)
773788
self._set_model_name_if_needed()
774789

@@ -1110,6 +1125,8 @@ def compile(
11101125
raise ValueError("You must provide a compilation job name")
11111126
if self.model_data is None:
11121127
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
1128+
if isinstance(self.model_data, dict):
1129+
raise ValueError("Compiling model data from ModelDataSource is currently not supported")
11131130

11141131
framework_version = framework_version or self._get_framework_version()
11151132

@@ -1301,7 +1318,7 @@ def deploy(
13011318

13021319
tags = add_jumpstart_tags(
13031320
tags=tags,
1304-
inference_model_uri=self.model_data,
1321+
inference_model_uri=self.model_data if isinstance(self.model_data, str) else None,
13051322
inference_script_uri=self.source_dir,
13061323
)
13071324

@@ -1545,7 +1562,7 @@ class FrameworkModel(Model):
15451562

15461563
def __init__(
15471564
self,
1548-
model_data: Union[str, PipelineVariable],
1565+
model_data: Union[str, PipelineVariable, dict],
15491566
image_uri: Union[str, PipelineVariable],
15501567
role: Optional[str] = None,
15511568
entry_point: Optional[str] = None,
@@ -1563,8 +1580,8 @@ def __init__(
15631580
"""Initialize a ``FrameworkModel``.
15641581
15651582
Args:
1566-
model_data (str or PipelineVariable): The S3 location of a SageMaker
1567-
model data ``.tar.gz`` file.
1583+
model_data (str or PipelineVariable or dict): The S3 location of
1584+
SageMaker model data.
15681585
image_uri (str or PipelineVariable): A Docker image URI.
15691586
role (str): An IAM role name or ARN for SageMaker to access AWS
15701587
resources on your behalf.
@@ -1758,6 +1775,11 @@ def __init__(
17581775
``model_data`` is not required.
17591776
**kwargs: Additional kwargs passed to the Model constructor.
17601777
"""
1778+
if isinstance(model_data, dict):
1779+
raise ValueError(
1780+
"Creating ModelPackage with ModelDataSource is currently not supported"
1781+
)
1782+
17611783
super(ModelPackage, self).__init__(
17621784
role=role, model_data=model_data, image_uri=None, **kwargs
17631785
)

src/sagemaker/session.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3584,7 +3584,7 @@ def create_model_from_job(
35843584
)
35853585
primary_container = container_def(
35863586
image_uri or training_job["AlgorithmSpecification"]["TrainingImage"],
3587-
model_data_url=model_data_url or training_job["ModelArtifacts"]["S3ModelArtifacts"],
3587+
model_data_url=model_data_url or self._gen_s3_model_data_source(training_job),
35883588
env=env,
35893589
)
35903590
vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override)
@@ -4486,14 +4486,14 @@ def endpoint_from_job(
44864486
str: Name of the ``Endpoint`` that is created.
44874487
"""
44884488
job_desc = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
4489-
output_url = job_desc["ModelArtifacts"]["S3ModelArtifacts"]
4489+
model_s3_location = self._gen_s3_model_data_source(job_desc)
44904490
image_uri = image_uri or job_desc["AlgorithmSpecification"]["TrainingImage"]
44914491
role = role or job_desc["RoleArn"]
44924492
name = name or job_name
44934493
vpc_config_override = _vpc_config_from_training_job(job_desc, vpc_config_override)
44944494

44954495
return self.endpoint_from_model_data(
4496-
model_s3_location=output_url,
4496+
model_s3_location=model_s3_location,
44974497
image_uri=image_uri,
44984498
initial_instance_count=initial_instance_count,
44994499
instance_type=instance_type,
@@ -4506,6 +4506,40 @@ def endpoint_from_job(
45064506
data_capture_config=data_capture_config,
45074507
)
45084508

4509+
def _gen_s3_model_data_source(self, training_job_spec):
4510+
"""Generates ``ModelDataSource`` value from given DescribeTrainingJob API response.
4511+
4512+
Args:
4513+
training_job_spec (dict): SageMaker DescribeTrainingJob API response.
4514+
4515+
Returns:
4516+
dict: A ``ModelDataSource`` value.
4517+
"""
4518+
model_data_s3_uri = training_job_spec["ModelArtifacts"]["S3ModelArtifacts"]
4519+
compression_type = training_job_spec.get("OutputDataConfig", {}).get(
4520+
"CompressionType", "GZIP"
4521+
)
4522+
# See https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_OutputDataConfig.html
4523+
# and https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3ModelDataSource.html
4524+
if compression_type in {"NONE", "GZIP"}:
4525+
model_compression_type = compression_type.title()
4526+
else:
4527+
raise ValueError(
4528+
f'Unrecognized training job output data compression type "{compression_type}"'
4529+
)
4530+
s3_model_data_type = "S3Object" if model_compression_type == "Gzip" else "S3Prefix"
4531+
# if model data is in S3Prefix type and has no trailing forward slash in its URI,
4532+
# append one so that it meets SageMaker Hosting's mandate for deploying uncompressed model.
4533+
if s3_model_data_type == "S3Prefix" and not model_data_s3_uri.endswith("/"):
4534+
model_data_s3_uri += "/"
4535+
return {
4536+
"S3DataSource": {
4537+
"S3Uri": model_data_s3_uri,
4538+
"S3DataType": s3_model_data_type,
4539+
"CompressionType": model_compression_type,
4540+
}
4541+
}
4542+
45094543
def endpoint_from_model_data(
45104544
self,
45114545
model_s3_location,
@@ -4524,7 +4558,8 @@ def endpoint_from_model_data(
45244558
"""Create and deploy to an ``Endpoint`` using existing model data stored in S3.
45254559
45264560
Args:
4527-
model_s3_location (str): S3 URI of the model artifacts to use for the endpoint.
4561+
model_s3_location (str or dict): S3 location of the model artifacts
4562+
to use for the endpoint.
45284563
image_uri (str): The Docker image URI which defines the runtime code to be
45294564
used as the entry point for accepting prediction requests.
45304565
initial_instance_count (int): Minimum number of EC2 instances to launch. The actual
@@ -5925,8 +5960,10 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
59255960
59265961
Args:
59275962
image_uri (str): Docker image URI to run for this container.
5928-
model_data_url (str): S3 URI of data required by this container,
5929-
e.g. SageMaker training job model artifacts (default: None).
5963+
model_data_url (str or dict[str, Any]): S3 location of model data required by this
5964+
container, e.g. SageMaker training job model artifacts. It can either be a string
5965+
representing S3 URI of model data, or a dictionary representing a
5966+
``ModelDataSource`` object. (default: None).
59305967
env (dict[str, str]): Environment variables to set inside the container (default: None).
59315968
container_mode (str): The model container mode. Valid modes:
59325969
* MultiModel: Indicates that model container can support hosting multiple models
@@ -5943,8 +5980,12 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
59435980
if env is None:
59445981
env = {}
59455982
c_def = {"Image": image_uri, "Environment": env}
5946-
if model_data_url:
5983+
5984+
if isinstance(model_data_url, dict):
5985+
c_def["ModelDataSource"] = model_data_url
5986+
elif model_data_url:
59475987
c_def["ModelDataUrl"] = model_data_url
5988+
59485989
if container_mode:
59495990
c_def["Mode"] = container_mode
59505991
if image_config:

tests/unit/sagemaker/model/test_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,39 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses
810810
get_model_package_args"""
811811

812812

813+
def test_register_calls_model_data_source_not_supported(sagemaker_session):
    """``Model.register`` must reject model data supplied as a ModelDataSource dict."""
    model_data_source = {
        "S3DataSource": {
            "S3Uri": "s3://bucket/model/prefix/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    }
    model = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir="s3://blah/blah/blah",
        image_uri=IMAGE_URI,
        model_data=model_data_source,
    )

    expected_error = "SageMaker Model Package currently cannot be created with ModelDataSource."
    with pytest.raises(ValueError, match=expected_error):
        model.register(
            SUPPORTED_CONTENT_TYPES,
            SUPPORTED_RESPONSE_MIME_TYPES,
            SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES,
            SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES,
            marketplace_cert=True,
            description=MODEL_DESCRIPTION,
            model_package_name=MODEL_NAME,
            validation_specification=VALIDATION_SPECIFICATION,
        )
844+
845+
813846
@patch("sagemaker.utils.repack_model")
814847
def test_model_local_download_dir(repack_model, sagemaker_session):
815848

tests/unit/sagemaker/model/test_model_package.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,24 @@ def test_create_sagemaker_model_include_tags(sagemaker_session):
209209
)
210210

211211

212+
def test_model_package_model_data_source_not_supported(sagemaker_session):
    """Constructing a ``ModelPackage`` from a ModelDataSource dict must fail."""
    model_data_source = {
        "S3DataSource": {
            "S3Uri": "s3://bucket/model/prefix/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    }
    with pytest.raises(
        ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported"
    ):
        ModelPackage(
            role="role",
            model_package_arn="my-model-package",
            model_data=model_data_source,
            sagemaker_session=sagemaker_session,
        )
228+
229+
212230
@patch("sagemaker.utils.name_from_base")
213231
def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session):
214232
model_package_name = "my-model-package"

tests/unit/sagemaker/model/test_neo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,31 @@ def test_compile_validates_model_data():
244244
assert "You must provide an S3 path to the compressed model artifacts." in str(e)
245245

246246

247+
def test_compile_validates_model_data_source():
    """``Model.compile`` must reject model data supplied as a ModelDataSource dict."""
    model = Model(
        MODEL_IMAGE,
        model_data={
            "S3DataSource": {
                "S3Uri": "s3://bucket/model/prefix",
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        },
    )

    expected_error = "Compiling model data from ModelDataSource is currently not supported"
    with pytest.raises(ValueError, match=expected_error) as e:
        model.compile(
            target_instance_family="ml_c4",
            input_shape={"data": [1, 3, 1024, 1024]},
            output_path="s3://output",
            role="role",
            framework="tensorflow",
            job_name="compile-model",
        )

    assert expected_error in str(e)
270+
271+
247272
def test_deploy_honors_provided_model_name(sagemaker_session):
248273
model = _create_model(sagemaker_session)
249274
model._is_compiled_model = True

tests/unit/test_endpoint_from_job.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323
ACCELERATOR_TYPE = "ml.eia.medium"
2424
IMAGE = "myimage"
2525
S3_MODEL_ARTIFACTS = "s3://mybucket/mymodel"
# Expected ModelDataSource for a training job whose output is GZIP-compressed
# (the default): a single compressed S3 object rather than an S3 prefix.
S3_MODEL_SRC_COMPRESSED = {
    "S3DataSource": {
        "S3Uri": S3_MODEL_ARTIFACTS,
        "S3DataType": "S3Object",
        "CompressionType": "Gzip",
    }
}
2633
TRAIN_ROLE = "mytrainrole"
2734
VPC_CONFIG = {"Subnets": ["subnet-foo"], "SecurityGroupIds": ["sg-foo"]}
2835
TRAINING_JOB_RESPONSE = {
@@ -68,7 +75,7 @@ def test_all_defaults_no_existing_entities(sagemaker_session):
6875

6976
expected_args = original_args.copy()
7077
expected_args.pop("job_name")
71-
expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS
78+
expected_args["model_s3_location"] = S3_MODEL_SRC_COMPRESSED
7279
expected_args["image_uri"] = IMAGE
7380
expected_args["role"] = TRAIN_ROLE
7481
expected_args["name"] = JOB_NAME
@@ -100,7 +107,7 @@ def test_no_defaults_no_existing_entities(sagemaker_session):
100107

101108
expected_args = original_args.copy()
102109
expected_args.pop("job_name")
103-
expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS
110+
expected_args["model_s3_location"] = S3_MODEL_SRC_COMPRESSED
104111
expected_args["model_vpc_config"] = expected_args.pop("vpc_config_override")
105112
expected_args["data_capture_config"] = None
106113
sagemaker_session.endpoint_from_model_data.assert_called_once_with(**expected_args)

0 commit comments

Comments
 (0)