Skip to content

Commit e2846fb

Browse files
sreedesBasil BeiroutiPayton Staubahsan-z-khanmufaddal-rohawala
authored andcommitted
fix: Model Registration with BYO scripts (aws#2797)
Co-authored-by: Basil Beirouti <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Basil Beirouti <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Shreya Pandit <[email protected]>
1 parent ee707cb commit e2846fb

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

src/sagemaker/model.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,26 @@ def register(
178178
"""
179179
if self.model_data is None:
180180
raise ValueError("SageMaker Model Package cannot be created without model data.")
181+
if image_uri is not None:
182+
self.image_uri = image_uri
183+
if model_package_group_name is not None:
184+
container_def = self.prepare_container_def()
185+
else:
186+
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
181187

182188
model_pkg_args = sagemaker.get_model_package_args(
183189
content_types,
184190
response_types,
185191
inference_instances,
186192
transform_instances,
187-
model_package_name,
188-
model_package_group_name,
189-
self.model_data,
190-
image_uri or self.image_uri,
191-
model_metrics,
192-
metadata_properties,
193-
marketplace_cert,
194-
approval_status,
195-
description,
193+
model_package_name=model_package_name,
194+
model_package_group_name=model_package_group_name,
195+
model_metrics=model_metrics,
196+
metadata_properties=metadata_properties,
197+
marketplace_cert=marketplace_cert,
198+
approval_status=approval_status,
199+
description=description,
200+
container_def_list=[container_def],
196201
drift_check_baselines=drift_check_baselines,
197202
)
198203
model_package = self.sagemaker_session.create_model_package_from_containers(

tests/integ/test_mxnet.py

+48
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,54 @@ def test_register_model_package(
231231
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name)
232232

233233

234+
def test_register_model_package_versioned(
235+
mxnet_training_job,
236+
sagemaker_session,
237+
mxnet_inference_latest_version,
238+
mxnet_inference_latest_py_version,
239+
cpu_instance_type,
240+
):
241+
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
242+
243+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
244+
desc = sagemaker_session.sagemaker_client.describe_training_job(
245+
TrainingJobName=mxnet_training_job
246+
)
247+
model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp())
248+
sagemaker_session.sagemaker_client.create_model_package_group(
249+
ModelPackageGroupName=model_package_group_name
250+
)
251+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
252+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
253+
model = MXNetModel(
254+
model_data,
255+
"SageMakerRole",
256+
entry_point=script_path,
257+
py_version=mxnet_inference_latest_py_version,
258+
sagemaker_session=sagemaker_session,
259+
framework_version=mxnet_inference_latest_version,
260+
)
261+
model_pkg = model.register(
262+
content_types=["application/json"],
263+
response_types=["application/json"],
264+
inference_instances=["ml.m5.large"],
265+
transform_instances=["ml.m5.large"],
266+
model_package_group_name=model_package_group_name,
267+
approval_status="Approved",
268+
)
269+
assert isinstance(model_pkg, ModelPackage)
270+
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
271+
data = numpy.zeros(shape=(1, 1, 28, 28))
272+
result = predictor.predict(data)
273+
assert result is not None
274+
sagemaker_session.sagemaker_client.delete_model_package(
275+
ModelPackageName=model_pkg.model_package_arn
276+
)
277+
sagemaker_session.sagemaker_client.delete_model_package_group(
278+
ModelPackageGroupName=model_package_group_name
279+
)
280+
281+
234282
def test_deploy_model_with_tags_and_kms(
235283
mxnet_training_job,
236284
sagemaker_session,

0 commit comments

Comments
 (0)