diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index dd9529916e..270b838164 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -57,6 +57,9 @@ def properties(self): class RegisterModel(StepCollection): # pragma: no cover """Register Model step collection for workflow.""" + _REGISTER_MODEL_NAME_BASE = "RegisterModel" + _REPACK_MODEL_NAME_BASE = "RepackModel" + def __init__( self, name: str, @@ -168,7 +171,7 @@ def __init__( kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None)) repack_model_step = _RepackModelStep( - name=f"{name}RepackModel", + name="{}-{}".format(self.name, self._REPACK_MODEL_NAME_BASE), depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, @@ -212,7 +215,7 @@ def __init__( model_name = model_entity.name or model_entity._framework_name repack_model_step = _RepackModelStep( - name=f"{model_name}RepackModel", + name="{}-{}".format(model_name, self._REPACK_MODEL_NAME_BASE), depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=sagemaker_session, @@ -256,7 +259,7 @@ def __init__( ) register_model_step = _RegisterModelStep( - name=name, + name="{}-{}".format(self.name, self._REGISTER_MODEL_NAME_BASE), estimator=estimator, model_data=model_data, content_types=content_types, diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 56611fb696..1045a8ef0c 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -123,8 +123,8 @@ def test_conditional_pytorch_training_model_registration( model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, content_types=["*"], response_types=["*"], - inference_instances=["*"], - transform_instances=["*"], + inference_instances=["ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], description="test-description", sample_payload_url=sample_payload_url, task=task, @@ -234,7 +234,7 @@ def test_mxnet_model_registration( content_types=["*"], response_types=["*"], inference_instances=["ml.m5.xlarge"], - transform_instances=["*"], + transform_instances=["ml.m5.xlarge"], description="test-description", sample_payload_url=sample_payload_url, task=task, @@ -670,7 +670,7 @@ def test_model_registration_with_drift_check_baselines( ) continue assert execution_steps[0]["StepStatus"] == "Succeeded" - assert execution_steps[0]["StepName"] == "MyRegisterModelStep" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" response = sagemaker_session.sagemaker_client.describe_model_package( ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index fd84bf4b77..d3b2a19fe3 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -26,7 +26,7 @@ _REPACK_MODEL_NAME_BASE, ) from sagemaker.workflow.parameters import ParameterString -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR @@ -268,7 +268,7 @@ def test_step_collection_properties(pipeline_session, sagemaker_session): steps = register_model.steps assert len(steps) == 1 assert register_model.properties.ModelPackageName.expr == { - "Get": f"Steps.{register_model_step_name}.ModelPackageName" + "Get": f"Steps.{register_model_step_name}-RegisterModel.ModelPackageName" } # Custom StepCollection @@ -330,10 +330,9 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): step_list = json.loads(pipeline.definition())["Steps"] assert len(step_list) == 7 for step in step_list: - if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}RepackModel"]: + if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}-RepackModel"]: assert "DependsOn" not in step - continue - if step["Name"] == f"{model_name}RepackModel": + elif step["Name"] == f"{model_name}-RepackModel": assert set(step["DependsOn"]) == { "MyStep1", f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", @@ -344,9 +343,21 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): "MyStep1", f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", - f"{model_name}RepackModel", - register_model_name, + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyStep2": [], + "MyStep3": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyModel-RepackModel": [], + "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"], + } + ) def test_register_model(estimator, model_metrics, drift_check_baselines): @@ -378,7 +389,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": "RegisterModel", "DependsOn": ["TestStep"], "DisplayName": "RegisterModelStep", @@ -450,7 +461,7 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": "RegisterModel", "Description": "description", "Arguments": { @@ -526,7 +537,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": "RegisterModel", "Description": "description", "DependsOn": ["TestStep"], @@ -618,7 +629,7 @@ def test_register_model_with_model_repack_with_estimator( for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "RegisterModelStepRepackModel" + assert request_dict["Name"] == "RegisterModelStep-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -671,7 +682,7 @@ def test_register_model_with_model_repack_with_estimator( } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1 @@ -745,7 +756,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "modelNameRepackModel" + assert request_dict["Name"] == "modelName-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -798,7 +809,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1 @@ -874,7 +885,7 @@ def test_register_model_with_model_repack_with_pipeline_model( for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "modelNameRepackModel" + assert request_dict["Name"] == "modelName-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -927,7 +938,7 @@ def test_register_model_with_model_repack_with_pipeline_model( } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1