feat: Deploy uncompressed ML model from S3 to SageMaker Hosting endpoints #4005

Merged: 4 commits, Jul 28, 2023
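
With this change, model artifacts that a training job left uncompressed in S3 can be deployed by passing a `ModelDataSource`-style dict as `model_data` instead of a `model.tar.gz` URI. A rough usage sketch of the new code path (the image URI, role, and S3 prefix below are placeholders, not values from this PR):

```python
from sagemaker.model import Model

# Placeholder image URI, execution role, and S3 prefix. The trailing "/" on the
# prefix matters: SageMaker Hosting requires it for uncompressed (S3Prefix) model data.
model = Model(
    image_uri="<inference-image-uri>",
    role="<execution-role-arn>",
    model_data={
        "S3DataSource": {
            "S3Uri": "s3://my-bucket/my-training-job/output/model/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
)

# deploy() forwards the dict to CreateModel as a ModelDataSource
# (see the container_def() change in src/sagemaker/session.py below).
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
```

Alternatively, an estimator fitted against a training job whose `OutputDataConfig` uses `CompressionType: NONE` now exposes the same dict through its `model_data` property, so `estimator.deploy(...)` works without manual wiring.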
38 changes: 29 additions & 9 deletions src/sagemaker/estimator.py
@@ -1737,21 +1737,41 @@ def register(

@property
def model_data(self):
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
"""Str or dict: The model location in S3. Only set if Estimator has been ``fit()``."""
if self.latest_training_job is not None and not isinstance(
self.sagemaker_session, PipelineSession
):
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
job_details = self.sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=self.latest_training_job.name
)["ModelArtifacts"]["S3ModelArtifacts"]
else:
logger.warning(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_uri = os.path.join(
self.output_path, self._current_job_name, "output", "model.tar.gz"
model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
compression_type = job_details.get("OutputDataConfig", {}).get(
"CompressionType", "GZIP"
)
if compression_type == "GZIP":
return model_uri
# fail fast if we don't recognize training output compression type
if compression_type not in {"GZIP", "NONE"}:
raise ValueError(
f'Unrecognized training job output data compression type "{compression_type}"'
)
# model data is in uncompressed form. NOTE: SageMaker Hosting mandates a
# trailing forward slash in the S3 model data URI, so append one if necessary.
if not model_uri.endswith("/"):
model_uri += "/"
return {
"S3DataSource": {
"S3Uri": model_uri,
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
}

logger.warning(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_uri = os.path.join(self.output_path, self._current_job_name, "output", "model.tar.gz")
return model_uri

@abstractmethod
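
To make the branching above concrete, here is a small standalone sketch that mirrors the property's logic; the helper name `model_data_from_describe` and the URIs are illustrative, not part of the SDK:

```python
def model_data_from_describe(job_details: dict):
    """Mirror of the new Estimator.model_data branching, for illustration only."""
    model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
    compression = job_details.get("OutputDataConfig", {}).get("CompressionType", "GZIP")
    if compression == "GZIP":
        # default behavior is unchanged: return the plain model.tar.gz URI
        return model_uri
    if compression != "NONE":
        raise ValueError(f'Unrecognized training job output data compression type "{compression}"')
    # uncompressed output: SageMaker Hosting expects an S3 prefix with a trailing slash
    if not model_uri.endswith("/"):
        model_uri += "/"
    return {
        "S3DataSource": {
            "S3Uri": model_uri,
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    }


# GZIP (the default) keeps the old string-valued behavior:
print(model_data_from_describe({
    "ModelArtifacts": {"S3ModelArtifacts": "s3://my-bucket/my-job/output/model.tar.gz"},
    "OutputDataConfig": {"CompressionType": "GZIP"},
}))

# NONE returns a ModelDataSource dict with "/" appended to the prefix:
print(model_data_from_describe({
    "ModelArtifacts": {"S3ModelArtifacts": "s3://my-bucket/my-job/output/model"},
    "OutputDataConfig": {"CompressionType": "NONE"},
}))
```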
38 changes: 30 additions & 8 deletions src/sagemaker/model.py
@@ -131,7 +131,7 @@ class Model(ModelBase, InferenceRecommenderMixin):
def __init__(
self,
image_uri: Union[str, PipelineVariable],
model_data: Optional[Union[str, PipelineVariable]] = None,
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -152,8 +152,8 @@ def __init__(
Args:
image_uri (str or PipelineVariable): A Docker image URI.
model_data (str or PipelineVariable): The S3 location of a SageMaker
model data ``.tar.gz`` file (default: None).
model_data (str or PipelineVariable or dict): Location
of SageMaker model data (default: None).
role (str): An AWS IAM role (either name or full ARN). The Amazon
SageMaker training jobs and APIs that create Amazon SageMaker
endpoints use this role to access training data and model
@@ -455,6 +455,11 @@ def register(
"""
if self.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")
if isinstance(self.model_data, dict):
raise ValueError(
"SageMaker Model Package currently cannot be created with ModelDataSource."
)

if image_uri is not None:
self.image_uri = image_uri

@@ -600,6 +605,7 @@ def prepare_container_def(
)
self._upload_code(deploy_key_prefix, repack=is_repack)
deploy_env.update(self._script_mode_env_vars())

return sagemaker.container_def(
self.image_uri,
self.repacked_model_data or self.model_data,
@@ -639,6 +645,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
)

if repack and self.model_data is not None and self.entry_point is not None:
if isinstance(self.model_data, dict):
logging.warning("ModelDataSource currently doesn't support model repacking")
return
if is_pipeline_variable(self.model_data):
# model is not yet there, defer repacking to later during pipeline execution
if not isinstance(self.sagemaker_session, PipelineSession):
@@ -765,10 +774,16 @@ def _create_sagemaker_model(
# _base_name, model_name are not needed under PipelineSession.
# the model_data may be Pipeline variable
# which may break the _base_name generation
model_uri = None
if isinstance(self.model_data, (str, PipelineVariable)):
model_uri = self.model_data
elif isinstance(self.model_data, dict):
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)

self._ensure_base_name_if_needed(
image_uri=container_def["Image"],
script_uri=self.source_dir,
model_uri=self.model_data,
model_uri=model_uri,
)
self._set_model_name_if_needed()

@@ -1110,6 +1125,8 @@ def compile(
raise ValueError("You must provide a compilation job name")
if self.model_data is None:
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
if isinstance(self.model_data, dict):
raise ValueError("Compiling model data from ModelDataSource is currently not supported")

framework_version = framework_version or self._get_framework_version()

@@ -1301,7 +1318,7 @@ def deploy(

tags = add_jumpstart_tags(
tags=tags,
inference_model_uri=self.model_data,
inference_model_uri=self.model_data if isinstance(self.model_data, str) else None,
inference_script_uri=self.source_dir,
)

@@ -1545,7 +1562,7 @@ class FrameworkModel(Model):

def __init__(
self,
model_data: Union[str, PipelineVariable],
model_data: Union[str, PipelineVariable, dict],
image_uri: Union[str, PipelineVariable],
role: Optional[str] = None,
entry_point: Optional[str] = None,
@@ -1563,8 +1580,8 @@ def __init__(
"""Initialize a ``FrameworkModel``.
Args:
model_data (str or PipelineVariable): The S3 location of a SageMaker
model data ``.tar.gz`` file.
model_data (str or PipelineVariable or dict): The S3 location of
SageMaker model data.
image_uri (str or PipelineVariable): A Docker image URI.
role (str): An IAM role name or ARN for SageMaker to access AWS
resources on your behalf.
@@ -1758,6 +1775,11 @@ def __init__(
``model_data`` is not required.
**kwargs: Additional kwargs passed to the Model constructor.
"""
if isinstance(model_data, dict):
raise ValueError(
"Creating ModelPackage with ModelDataSource is currently not supported"
)

super(ModelPackage, self).__init__(
role=role, model_data=model_data, image_uri=None, **kwargs
)
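
The dict form is accepted for hosting, but the packaging, repacking, and compilation paths guard against it for now. A hedged sketch of the `compile()` guardrail, closely mirroring the unit test added in `tests/unit/sagemaker/model/test_neo.py` below (the image URI, role, and S3 paths are placeholders):

```python
from sagemaker.model import Model

model = Model(
    image_uri="<inference-image-uri>",
    model_data={
        "S3DataSource": {
            "S3Uri": "s3://my-bucket/model/prefix/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
)

try:
    model.compile(
        target_instance_family="ml_c4",
        input_shape={"data": [1, 3, 224, 224]},
        output_path="s3://my-bucket/compiled",
        role="<execution-role-arn>",
        framework="tensorflow",
        job_name="example-compile",
    )
except ValueError as err:
    print(err)  # Compiling model data from ModelDataSource is currently not supported
```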
55 changes: 48 additions & 7 deletions src/sagemaker/session.py
@@ -3584,7 +3584,7 @@ def create_model_from_job(
)
primary_container = container_def(
image_uri or training_job["AlgorithmSpecification"]["TrainingImage"],
model_data_url=model_data_url or training_job["ModelArtifacts"]["S3ModelArtifacts"],
model_data_url=model_data_url or self._gen_s3_model_data_source(training_job),
env=env,
)
vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override)
@@ -4486,14 +4486,14 @@ def endpoint_from_job(
str: Name of the ``Endpoint`` that is created.
"""
job_desc = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
output_url = job_desc["ModelArtifacts"]["S3ModelArtifacts"]
model_s3_location = self._gen_s3_model_data_source(job_desc)
image_uri = image_uri or job_desc["AlgorithmSpecification"]["TrainingImage"]
role = role or job_desc["RoleArn"]
name = name or job_name
vpc_config_override = _vpc_config_from_training_job(job_desc, vpc_config_override)

return self.endpoint_from_model_data(
model_s3_location=output_url,
model_s3_location=model_s3_location,
image_uri=image_uri,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
@@ -4506,6 +4506,40 @@
data_capture_config=data_capture_config,
)

def _gen_s3_model_data_source(self, training_job_spec):
"""Generates ``ModelDataSource`` value from given DescribeTrainingJob API response.
Args:
training_job_spec (dict): SageMaker DescribeTrainingJob API response.
Returns:
dict: A ``ModelDataSource`` value.
"""
model_data_s3_uri = training_job_spec["ModelArtifacts"]["S3ModelArtifacts"]
compression_type = training_job_spec.get("OutputDataConfig", {}).get(
"CompressionType", "GZIP"
)
# See https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_OutputDataConfig.html
# and https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3ModelDataSource.html
if compression_type in {"NONE", "GZIP"}:
model_compression_type = compression_type.title()
else:
raise ValueError(
f'Unrecognized training job output data compression type "{compression_type}"'
)
s3_model_data_type = "S3Object" if model_compression_type == "Gzip" else "S3Prefix"
# if model data is in S3Prefix type and has no trailing forward slash in its URI,
# append one so that it meets SageMaker Hosting's mandate for deploying uncompressed model.
if s3_model_data_type == "S3Prefix" and not model_data_s3_uri.endswith("/"):
model_data_s3_uri += "/"
return {
"S3DataSource": {
"S3Uri": model_data_s3_uri,
"S3DataType": s3_model_data_type,
"CompressionType": model_compression_type,
}
}

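For reference, a sketch of the mapping this helper applies. Calling the private helper directly is for illustration only and assumes default AWS credentials and region for constructing a `Session`; the URIs are placeholders:

```python
from sagemaker.session import Session

session = Session()  # assumes default credentials/region; the helper itself does not touch AWS

# GZIP (or absent) maps to a single compressed object:
print(session._gen_s3_model_data_source({
    "ModelArtifacts": {"S3ModelArtifacts": "s3://my-bucket/my-job/output/model.tar.gz"},
    "OutputDataConfig": {"CompressionType": "GZIP"},
}))
# {'S3DataSource': {'S3Uri': 's3://my-bucket/my-job/output/model.tar.gz',
#                   'S3DataType': 'S3Object', 'CompressionType': 'Gzip'}}

# NONE maps to an S3 prefix, with a trailing "/" appended for SageMaker Hosting:
print(session._gen_s3_model_data_source({
    "ModelArtifacts": {"S3ModelArtifacts": "s3://my-bucket/my-job/output/model"},
    "OutputDataConfig": {"CompressionType": "NONE"},
}))
# {'S3DataSource': {'S3Uri': 's3://my-bucket/my-job/output/model/',
#                   'S3DataType': 'S3Prefix', 'CompressionType': 'None'}}
```
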
def endpoint_from_model_data(
self,
model_s3_location,
@@ -4524,7 +4558,8 @@ def endpoint_from_model_data(
"""Create and deploy to an ``Endpoint`` using existing model data stored in S3.
Args:
model_s3_location (str): S3 URI of the model artifacts to use for the endpoint.
model_s3_location (str or dict): S3 location of the model artifacts
to use for the endpoint.
image_uri (str): The Docker image URI which defines the runtime code to be
used as the entry point for accepting prediction requests.
initial_instance_count (int): Minimum number of EC2 instances to launch. The actual
@@ -5925,8 +5960,10 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
Args:
image_uri (str): Docker image URI to run for this container.
model_data_url (str): S3 URI of data required by this container,
e.g. SageMaker training job model artifacts (default: None).
model_data_url (str or dict[str, Any]): S3 location of model data required by this
container, e.g. SageMaker training job model artifacts. It can either be a string
representing S3 URI of model data, or a dictionary representing a
``ModelDataSource`` object. (default: None).
env (dict[str, str]): Environment variables to set inside the container (default: None).
container_mode (str): The model container mode. Valid modes:
* MultiModel: Indicates that model container can support hosting multiple models
@@ -5943,8 +5980,12 @@
if env is None:
env = {}
c_def = {"Image": image_uri, "Environment": env}
if model_data_url:

if isinstance(model_data_url, dict):
c_def["ModelDataSource"] = model_data_url
elif model_data_url:
c_def["ModelDataUrl"] = model_data_url

if container_mode:
c_def["Mode"] = container_mode
if image_config:
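
To illustrate the `container_def()` branching at the end of this file, a small sketch of the two accepted shapes of `model_data_url` (the image URI and S3 locations are placeholders):

```python
import sagemaker

# Legacy string form: behavior is unchanged and the key stays ModelDataUrl.
print(sagemaker.container_def("<image-uri>", model_data_url="s3://my-bucket/model.tar.gz"))
# {'Image': '<image-uri>', 'Environment': {}, 'ModelDataUrl': 's3://my-bucket/model.tar.gz'}

# New dict form: the ModelDataSource value is forwarded verbatim.
print(sagemaker.container_def(
    "<image-uri>",
    model_data_url={
        "S3DataSource": {
            "S3Uri": "s3://my-bucket/model/prefix/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
))
# {'Image': '<image-uri>', 'Environment': {}, 'ModelDataSource': {...}}
```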
33 changes: 33 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
@@ -810,6 +810,39 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses
get_model_package_args"""


def test_register_calls_model_data_source_not_supported(sagemaker_session):
source_dir = "s3://blah/blah/blah"
t = Model(
entry_point=ENTRY_POINT_INFERENCE,
role=ROLE,
sagemaker_session=sagemaker_session,
source_dir=source_dir,
image_uri=IMAGE_URI,
model_data={
"S3DataSource": {
"S3Uri": "s3://bucket/model/prefix/",
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
},
)

with pytest.raises(
ValueError,
match="SageMaker Model Package currently cannot be created with ModelDataSource.",
):
t.register(
SUPPORTED_CONTENT_TYPES,
SUPPORTED_RESPONSE_MIME_TYPES,
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES,
SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES,
marketplace_cert=True,
description=MODEL_DESCRIPTION,
model_package_name=MODEL_NAME,
validation_specification=VALIDATION_SPECIFICATION,
)


@patch("sagemaker.utils.repack_model")
def test_model_local_download_dir(repack_model, sagemaker_session):

18 changes: 18 additions & 0 deletions tests/unit/sagemaker/model/test_model_package.py
@@ -209,6 +209,24 @@ def test_create_sagemaker_model_include_tags(sagemaker_session):
)


def test_model_package_model_data_source_not_supported(sagemaker_session):
with pytest.raises(
ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported"
):
ModelPackage(
role="role",
model_package_arn="my-model-package",
model_data={
"S3DataSource": {
"S3Uri": "s3://bucket/model/prefix/",
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
},
sagemaker_session=sagemaker_session,
)


@patch("sagemaker.utils.name_from_base")
def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session):
model_package_name = "my-model-package"
25 changes: 25 additions & 0 deletions tests/unit/sagemaker/model/test_neo.py
@@ -244,6 +244,31 @@ def test_compile_validates_model_data():
assert "You must provide an S3 path to the compressed model artifacts." in str(e)


def test_compile_validates_model_data_source():
model_data_src = {
"S3DataSource": {
"S3Uri": "s3://bucket/model/prefix",
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
}
model = Model(MODEL_IMAGE, model_data=model_data_src)

with pytest.raises(
ValueError, match="Compiling model data from ModelDataSource is currently not supported"
) as e:
model.compile(
target_instance_family="ml_c4",
input_shape={"data": [1, 3, 1024, 1024]},
output_path="s3://output",
role="role",
framework="tensorflow",
job_name="compile-model",
)

assert "Compiling model data from ModelDataSource is currently not supported" in str(e)


def test_deploy_honors_provided_model_name(sagemaker_session):
model = _create_model(sagemaker_session)
model._is_compiled_model = True