Skip to content

Commit 60f9634

Browse files
public-git-uiRaymond Lee
public-git-ui
authored and
Raymond Lee
committed
change: allow serving image to be specified when calling MXNet.deploy
1 parent ff06b6c commit 60f9634

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/sagemaker/mxnet/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def create_model(
140140
entry_point=None,
141141
source_dir=None,
142142
dependencies=None,
143+
image_name=None,
143144
):
144145
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
145146
``Endpoint``.
@@ -164,6 +165,12 @@ def create_model(
164165
dependencies (list[str]): A list of paths to directories (absolute or relative) with
165166
any additional libraries that will be exported to the container.
166167
If not specified, the dependencies from training are used.
168+
image_name (str): If specified, the estimator will use this image for hosting, instead
169+
of selecting the appropriate SageMaker official image based on framework_version
170+
and py_version. It can be an ECR url or dockerhub image and tag.
171+
Examples:
172+
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
173+
custom-image:latest.
167174
168175
Returns:
169176
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
@@ -180,7 +187,7 @@ def create_model(
180187
code_location=self.code_location,
181188
py_version=self.py_version,
182189
framework_version=self.framework_version,
183-
image=self.image_name,
190+
image=(image_name or self.image_name),
184191
model_server_workers=model_server_workers,
185192
sagemaker_session=self.sagemaker_session,
186193
vpc_config=self.get_vpc_config(vpc_config_override),

tests/unit/test_mxnet.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,26 @@ def test_empty_framework_version(warning, sagemaker_session):
704704

705705
assert mx.framework_version == defaults.MXNET_VERSION
706706
warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION)
707+
708+
709+
def test_create_model_with_custom_hosting_image(sagemaker_session):
710+
container_log_level = '"logging.INFO"'
711+
source_dir = "s3://mybucket/source"
712+
custom_image = "mxnet:2.0"
713+
custom_hosting_image = "mxnet_hosting:2.0"
714+
mx = MXNet(
715+
entry_point=SCRIPT_PATH,
716+
role=ROLE,
717+
sagemaker_session=sagemaker_session,
718+
train_instance_count=INSTANCE_COUNT,
719+
train_instance_type=INSTANCE_TYPE,
720+
image_name=custom_image,
721+
container_log_level=container_log_level,
722+
base_job_name="job",
723+
source_dir=source_dir,
724+
)
725+
726+
mx.fit(inputs="s3://mybucket/train", job_name="new_name")
727+
model = mx.create_model(image_name=custom_hosting_image)
728+
729+
assert model.image == custom_hosting_image

0 commit comments

Comments
 (0)