Skip to content

Commit 529752c

Browse files
staubhpPayton Staubahsan-z-khanicywang86rui
authored
[fix] Check py_version existence in RegisterModel step (#2320)
* Check py_version existence in RegisterModel step * black-format Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: icywang86rui <[email protected]>
1 parent f399fb7 commit 529752c

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def arguments(self) -> RequestType:
307307
model._framework_name,
308308
region_name,
309309
version=model.framework_version,
310-
py_version=model.py_version,
310+
py_version=model.py_version if hasattr(model, "py_version") else None,
311311
instance_type=self.kwargs.get("instance_type", self.estimator.instance_type),
312312
accelerator_type=self.kwargs.get("accelerator_type"),
313313
image_scope="inference",

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929

3030
from sagemaker.estimator import Estimator
31+
from sagemaker.tensorflow import TensorFlow
3132
from sagemaker.inputs import CreateModelInput, TransformInput
3233
from sagemaker.model_metrics import (
3334
MetricsSource,
@@ -120,6 +121,19 @@ def estimator(sagemaker_session):
120121
)
121122

122123

124+
@pytest.fixture
125+
def estimator_tf(sagemaker_session):
126+
return TensorFlow(
127+
entry_point="/some/script.py",
128+
framework_version="1.15.2",
129+
py_version="py3",
130+
role=ROLE,
131+
instance_type="ml.c4.2xlarge",
132+
instance_count=1,
133+
sagemaker_session=sagemaker_session,
134+
)
135+
136+
123137
@pytest.fixture
124138
def model_metrics():
125139
return ModelMetrics(
@@ -202,6 +216,56 @@ def test_register_model(estimator, model_metrics):
202216
)
203217

204218

219+
def test_register_model_tf(estimator_tf, model_metrics):
220+
model_data = f"s3://{BUCKET}/model.tar.gz"
221+
register_model = RegisterModel(
222+
name="RegisterModelStep",
223+
estimator=estimator_tf,
224+
model_data=model_data,
225+
content_types=["content_type"],
226+
response_types=["response_type"],
227+
inference_instances=["inference_instance"],
228+
transform_instances=["transform_instance"],
229+
model_package_group_name="mpg",
230+
model_metrics=model_metrics,
231+
approval_status="Approved",
232+
description="description",
233+
)
234+
assert ordered(register_model.request_dicts()) == ordered(
235+
[
236+
{
237+
"Name": "RegisterModelStep",
238+
"Type": "RegisterModel",
239+
"Arguments": {
240+
"InferenceSpecification": {
241+
"Containers": [
242+
{
243+
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
244+
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
245+
}
246+
],
247+
"SupportedContentTypes": ["content_type"],
248+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
249+
"SupportedResponseMIMETypes": ["response_type"],
250+
"SupportedTransformInstanceTypes": ["transform_instance"],
251+
},
252+
"ModelApprovalStatus": "Approved",
253+
"ModelMetrics": {
254+
"ModelQuality": {
255+
"Statistics": {
256+
"ContentType": "text/csv",
257+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
258+
},
259+
},
260+
},
261+
"ModelPackageDescription": "description",
262+
"ModelPackageGroupName": "mpg",
263+
},
264+
},
265+
]
266+
)
267+
268+
205269
def test_register_model_with_model_repack(estimator, model_metrics):
206270
model_data = f"s3://{BUCKET}/model.tar.gz"
207271
register_model = RegisterModel(

0 commit comments

Comments
 (0)