Skip to content

Commit a11e299

Browse files
nmadanNamrata Madan
and
Namrata Madan
authored
fix: rename RegisterModel inner steps to prevent duplicate step names (#3240)
Co-authored-by: Namrata Madan <[email protected]>
1 parent bd8ea40 commit a11e299

File tree

3 files changed

+37
-23
lines changed

3 files changed

+37
-23
lines changed

src/sagemaker/workflow/step_collections.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def properties(self):
5757
class RegisterModel(StepCollection): # pragma: no cover
5858
"""Register Model step collection for workflow."""
5959

60+
_REGISTER_MODEL_NAME_BASE = "RegisterModel"
61+
_REPACK_MODEL_NAME_BASE = "RepackModel"
62+
6063
def __init__(
6164
self,
6265
name: str,
@@ -168,7 +171,7 @@ def __init__(
168171
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
169172

170173
repack_model_step = _RepackModelStep(
171-
name=f"{name}RepackModel",
174+
name="{}-{}".format(self.name, self._REPACK_MODEL_NAME_BASE),
172175
depends_on=depends_on,
173176
retry_policies=repack_model_step_retry_policies,
174177
sagemaker_session=estimator.sagemaker_session,
@@ -212,7 +215,7 @@ def __init__(
212215
model_name = model_entity.name or model_entity._framework_name
213216

214217
repack_model_step = _RepackModelStep(
215-
name=f"{model_name}RepackModel",
218+
name="{}-{}".format(model_name, self._REPACK_MODEL_NAME_BASE),
216219
depends_on=depends_on,
217220
retry_policies=repack_model_step_retry_policies,
218221
sagemaker_session=sagemaker_session,
@@ -256,7 +259,7 @@ def __init__(
256259
)
257260

258261
register_model_step = _RegisterModelStep(
259-
name=name,
262+
name="{}-{}".format(self.name, self._REGISTER_MODEL_NAME_BASE),
260263
estimator=estimator,
261264
model_data=model_data,
262265
content_types=content_types,

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def test_conditional_pytorch_training_model_registration(
123123
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
124124
content_types=["*"],
125125
response_types=["*"],
126-
inference_instances=["*"],
127-
transform_instances=["*"],
126+
inference_instances=["ml.m5.xlarge"],
127+
transform_instances=["ml.m5.xlarge"],
128128
description="test-description",
129129
sample_payload_url=sample_payload_url,
130130
task=task,
@@ -234,7 +234,7 @@ def test_mxnet_model_registration(
234234
content_types=["*"],
235235
response_types=["*"],
236236
inference_instances=["ml.m5.xlarge"],
237-
transform_instances=["*"],
237+
transform_instances=["ml.m5.xlarge"],
238238
description="test-description",
239239
sample_payload_url=sample_payload_url,
240240
task=task,
@@ -670,7 +670,7 @@ def test_model_registration_with_drift_check_baselines(
670670
)
671671
continue
672672
assert execution_steps[0]["StepStatus"] == "Succeeded"
673-
assert execution_steps[0]["StepName"] == "MyRegisterModelStep"
673+
assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel"
674674

675675
response = sagemaker_session.sagemaker_client.describe_model_package(
676676
ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"]

tests/unit/sagemaker/workflow/test_step_collections.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
_REPACK_MODEL_NAME_BASE,
2727
)
2828
from sagemaker.workflow.parameters import ParameterString
29-
from sagemaker.workflow.pipeline import Pipeline
29+
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
3030
from sagemaker.workflow.pipeline_context import PipelineSession
3131
from sagemaker.workflow.utilities import list_to_request
3232
from tests.unit import DATA_DIR
@@ -268,7 +268,7 @@ def test_step_collection_properties(pipeline_session, sagemaker_session):
268268
steps = register_model.steps
269269
assert len(steps) == 1
270270
assert register_model.properties.ModelPackageName.expr == {
271-
"Get": f"Steps.{register_model_step_name}.ModelPackageName"
271+
"Get": f"Steps.{register_model_step_name}-RegisterModel.ModelPackageName"
272272
}
273273

274274
# Custom StepCollection
@@ -330,10 +330,9 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
330330
step_list = json.loads(pipeline.definition())["Steps"]
331331
assert len(step_list) == 7
332332
for step in step_list:
333-
if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}RepackModel"]:
333+
if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}-RepackModel"]:
334334
assert "DependsOn" not in step
335-
continue
336-
if step["Name"] == f"{model_name}RepackModel":
335+
elif step["Name"] == f"{model_name}-RepackModel":
337336
assert set(step["DependsOn"]) == {
338337
"MyStep1",
339338
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
@@ -344,9 +343,21 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
344343
"MyStep1",
345344
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
346345
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
347-
f"{model_name}RepackModel",
348-
register_model_name,
346+
f"{model_name}-RepackModel",
347+
f"{register_model_name}-RegisterModel",
349348
}
349+
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
350+
assert ordered(adjacency_list) == ordered(
351+
{
352+
"MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
353+
"MyStep2": [],
354+
"MyStep3": [],
355+
"MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"],
356+
"MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
357+
"MyModel-RepackModel": [],
358+
"RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"],
359+
}
360+
)
350361

351362

352363
def test_register_model(estimator, model_metrics, drift_check_baselines):
@@ -378,7 +389,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
378389
assert ordered(register_model.request_dicts()) == ordered(
379390
[
380391
{
381-
"Name": "RegisterModelStep",
392+
"Name": "RegisterModelStep-RegisterModel",
382393
"Type": "RegisterModel",
383394
"DependsOn": ["TestStep"],
384395
"DisplayName": "RegisterModelStep",
@@ -450,7 +461,7 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
450461
assert ordered(register_model.request_dicts()) == ordered(
451462
[
452463
{
453-
"Name": "RegisterModelStep",
464+
"Name": "RegisterModelStep-RegisterModel",
454465
"Type": "RegisterModel",
455466
"Description": "description",
456467
"Arguments": {
@@ -526,7 +537,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
526537
assert ordered(register_model.request_dicts()) == ordered(
527538
[
528539
{
529-
"Name": "RegisterModelStep",
540+
"Name": "RegisterModelStep-RegisterModel",
530541
"Type": "RegisterModel",
531542
"Description": "description",
532543
"DependsOn": ["TestStep"],
@@ -618,7 +629,7 @@ def test_register_model_with_model_repack_with_estimator(
618629

619630
for request_dict in request_dicts:
620631
if request_dict["Type"] == "Training":
621-
assert request_dict["Name"] == "RegisterModelStepRepackModel"
632+
assert request_dict["Name"] == "RegisterModelStep-RepackModel"
622633
assert len(request_dict["DependsOn"]) == 1
623634
assert request_dict["DependsOn"][0] == "TestStep"
624635
arguments = request_dict["Arguments"]
@@ -671,7 +682,7 @@ def test_register_model_with_model_repack_with_estimator(
671682
}
672683
)
673684
elif request_dict["Type"] == "RegisterModel":
674-
assert request_dict["Name"] == "RegisterModelStep"
685+
assert request_dict["Name"] == "RegisterModelStep-RegisterModel"
675686
assert "DependsOn" not in request_dict
676687
arguments = request_dict["Arguments"]
677688
assert len(arguments["InferenceSpecification"]["Containers"]) == 1
@@ -745,7 +756,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
745756

746757
for request_dict in request_dicts:
747758
if request_dict["Type"] == "Training":
748-
assert request_dict["Name"] == "modelNameRepackModel"
759+
assert request_dict["Name"] == "modelName-RepackModel"
749760
assert len(request_dict["DependsOn"]) == 1
750761
assert request_dict["DependsOn"][0] == "TestStep"
751762
arguments = request_dict["Arguments"]
@@ -798,7 +809,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
798809
}
799810
)
800811
elif request_dict["Type"] == "RegisterModel":
801-
assert request_dict["Name"] == "RegisterModelStep"
812+
assert request_dict["Name"] == "RegisterModelStep-RegisterModel"
802813
assert "DependsOn" not in request_dict
803814
arguments = request_dict["Arguments"]
804815
assert len(arguments["InferenceSpecification"]["Containers"]) == 1
@@ -874,7 +885,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
874885

875886
for request_dict in request_dicts:
876887
if request_dict["Type"] == "Training":
877-
assert request_dict["Name"] == "modelNameRepackModel"
888+
assert request_dict["Name"] == "modelName-RepackModel"
878889
assert len(request_dict["DependsOn"]) == 1
879890
assert request_dict["DependsOn"][0] == "TestStep"
880891
arguments = request_dict["Arguments"]
@@ -927,7 +938,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
927938
}
928939
)
929940
elif request_dict["Type"] == "RegisterModel":
930-
assert request_dict["Name"] == "RegisterModelStep"
941+
assert request_dict["Name"] == "RegisterModelStep-RegisterModel"
931942
assert "DependsOn" not in request_dict
932943
arguments = request_dict["Arguments"]
933944
assert len(arguments["InferenceSpecification"]["Containers"]) == 1

0 commit comments

Comments
 (0)