diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py
index e4c035dff7..4d0f06fbfb 100644
--- a/src/sagemaker/image_uris.py
+++ b/src/sagemaker/image_uris.py
@@ -619,6 +619,7 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
     return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
 
 
+@override_pipeline_parameter_var
 def get_training_image_uri(
     region,
     framework,
diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py
index dd27918107..4682572b96 100644
--- a/src/sagemaker/workflow/utilities.py
+++ b/src/sagemaker/workflow/utilities.py
@@ -390,8 +390,10 @@ def override_pipeline_parameter_var(func):
     We should remove this decorator after the grace period.
     """
     warning_msg_template = (
-        "The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
-        "The default_value of this Parameter object will be used to override it. "
+        "The input argument %s of function (%s) is a pipeline variable (%s), "
+        "which is interpreted in pipeline execution time only. "
+        "As the function needs to evaluate the argument value in SDK compile time, "
+        "the default_value of this Parameter object will be used to override it. "
         "Please make sure the default_value is valid."
     )
 
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index 2ecc4e7ec7..4491c1f6ee 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -69,6 +69,7 @@ from tests.unit import DATA_DIR
 from tests.unit.sagemaker.workflow.conftest import ROLE, BUCKET, IMAGE_URI, INSTANCE_TYPE
 
+HF_INSTANCE_TYPE = "ml.p3.2xlarge"
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
 SPARK_APP_JAR_PATH = os.path.join(
@@ -142,7 +143,7 @@
             pytorch_version="1.7",
             role=ROLE,
             instance_count=1,
-            instance_type="ml.p3.2xlarge",
+            instance_type=HF_INSTANCE_TYPE,
         ),
         {"code": DUMMY_S3_SCRIPT_PATH},
     ),
@@ -446,8 +447,15 @@ def test_processing_step_with_framework_processor(
 ):
     processor, run_inputs = framework_processor
 
+    default_instance_type = (
+        HF_INSTANCE_TYPE if type(processor) is HuggingFaceProcessor else INSTANCE_TYPE
+    )
+    instance_type_param = ParameterString(
+        name="ProcessingInstanceType", default_value=default_instance_type
+    )
     processor.sagemaker_session = pipeline_session
     processor.role = ROLE
+    processor.instance_type = instance_type_param
     processor.volume_kms_key = "volume-kms-key"
     processor.network_config = network_config
 
@@ -465,6 +473,7 @@ def test_processing_step_with_framework_processor(
         name="MyPipeline",
         steps=[step],
         sagemaker_session=pipeline_session,
+        parameters=[instance_type_param],
     )
 
     step_args = get_step_args_helper(step_args, "Processing")
@@ -475,6 +484,12 @@ def test_processing_step_with_framework_processor(
         step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
         == processing_output.destination
     )
+    assert (
+        type(step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"]) is ParameterString
+    )
+    step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"] = step_args[
+        "ProcessingResources"
+    ]["ClusterConfig"]["InstanceType"].expr
 
     del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
     del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
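
Taken together, the first two hunks mean that `get_training_image_uri` now substitutes a pipeline Parameter's `default_value` when a pipeline variable is passed where a concrete value is needed at SDK compile time, emitting the updated warning instead of rejecting the argument. Below is a minimal, self-contained sketch of how such a decorator behaves; the `ParameterString` stand-in, the keyword-only argument scan, and the toy URI format are assumptions for illustration only, not the SDK's actual implementation (which handles positional arguments and the full PipelineVariable hierarchy).

```python
import logging
from functools import wraps

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# The warning template from the diff above; it has three %s placeholders.
warning_msg_template = (
    "The input argument %s of function (%s) is a pipeline variable (%s), "
    "which is interpreted in pipeline execution time only. "
    "As the function needs to evaluate the argument value in SDK compile time, "
    "the default_value of this Parameter object will be used to override it. "
    "Please make sure the default_value is valid."
)


class ParameterString:
    """Hypothetical stand-in for sagemaker.workflow.parameters.ParameterString."""

    def __init__(self, name, default_value=None):
        self.name = name
        self.default_value = default_value


def override_pipeline_parameter_var(func):
    """Swap pipeline Parameter kwargs for their default_value before calling func."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        for name, value in list(kwargs.items()):
            if isinstance(value, ParameterString):
                logger.warning(warning_msg_template, name, func.__name__, type(value))
                kwargs[name] = value.default_value
        return func(*args, **kwargs)

    return wrapper


@override_pipeline_parameter_var
def get_training_image_uri(region, framework):
    # With the decorator applied, framework is a plain string by the time
    # this body runs, so the URI can be resolved at SDK compile time.
    return f"{region}/{framework}:latest"


framework_param = ParameterString(name="Framework", default_value="pytorch")
# Logs the warning, then resolves the URI using the default value "pytorch".
print(get_training_image_uri("us-west-2", framework=framework_param))
```

This is also why the test hunks pass a `ParameterString` as the processor's `instance_type` and then compare against its `.expr`: functions that only embed the value in the pipeline definition keep the Parameter reference, while functions decorated with `override_pipeline_parameter_var` fall back to the default.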