Skip to content

change: allow hosting image to be specified in MXNet.create_model #959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
image_name=None,
):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -164,6 +165,12 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
image_name (str): If specified, the estimator will use this image for hosting, instead
of selecting the appropriate SageMaker official image based on framework_version
and py_version. It can be an ECR url or dockerhub image and tag.
Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.

Returns:
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
Expand All @@ -180,7 +187,7 @@ def create_model(
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
image=self.image_name,
image=(image_name or self.image_name),
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,26 @@ def test_empty_framework_version(warning, sagemaker_session):

assert mx.framework_version == defaults.MXNET_VERSION
warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION)


def test_create_model_with_custom_hosting_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
custom_image = "mxnet:2.0"
custom_hosting_image = "mxnet_hosting:2.0"
mx = MXNet(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
image_name=custom_image,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
)

mx.fit(inputs="s3://mybucket/train", job_name="new_name")
model = mx.create_model(image_name=custom_hosting_image)

assert model.image == custom_hosting_image