Skip to content

Commit 6953bcc

Browse files
committed
fix: HuggingFaceProcessor parameterized instance_type when image_uri is absent
1 parent a3f5874 commit 6953bcc

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/sagemaker/image_uris.py

+1
Original file line number | Diff line number | Diff line change
@@ -614,6 +614,7 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
614614
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
615615

616616

617+
@override_pipeline_parameter_var
617618
def get_training_image_uri(
618619
region,
619620
framework,

src/sagemaker/workflow/utilities.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -390,8 +390,10 @@ def override_pipeline_parameter_var(func):
390390
We should remove this decorator after the grace period.
391391
"""
392392
warning_msg_template = (
393-
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
394-
"The default_value of this Parameter object will be used to override it. "
393+
"The input argument %s of function (%s) is a pipeline variable (%s), "
394+
"which is interpreted in pipeline execution time only. "
395+
"As the function needs to evaluate the argument value in SDK compile time, "
396+
"the default_value of this Parameter object will be used to override it. "
395397
"Please make sure the default_value is valid."
396398
)
397399

tests/unit/sagemaker/workflow/test_processing_step.py

+16 -1
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@
6969
from tests.unit import DATA_DIR
7070
from tests.unit.sagemaker.workflow.conftest import ROLE, BUCKET, IMAGE_URI, INSTANCE_TYPE
7171

72+
HF_INSTANCE_TYPE = "ml.p3.2xlarge"
7273
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
7374
LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
7475
SPARK_APP_JAR_PATH = os.path.join(
@@ -142,7 +143,7 @@
142143
pytorch_version="1.7",
143144
role=ROLE,
144145
instance_count=1,
145-
instance_type="ml.p3.2xlarge",
146+
instance_type=HF_INSTANCE_TYPE,
146147
),
147148
{"code": DUMMY_S3_SCRIPT_PATH},
148149
),
@@ -446,8 +447,15 @@ def test_processing_step_with_framework_processor(
446447
):
447448

448449
processor, run_inputs = framework_processor
450+
default_instance_type = (
451+
HF_INSTANCE_TYPE if type(processor) is HuggingFaceProcessor else INSTANCE_TYPE
452+
)
453+
instance_type_param = ParameterString(
454+
name="ProcessingInstanceType", default_value=default_instance_type
455+
)
449456
processor.sagemaker_session = pipeline_session
450457
processor.role = ROLE
458+
processor.instance_type = instance_type_param
451459

452460
processor.volume_kms_key = "volume-kms-key"
453461
processor.network_config = network_config
@@ -465,6 +473,7 @@ def test_processing_step_with_framework_processor(
465473
name="MyPipeline",
466474
steps=[step],
467475
sagemaker_session=pipeline_session,
476+
parameters=[instance_type_param],
468477
)
469478

470479
step_args = get_step_args_helper(step_args, "Processing")
@@ -475,6 +484,12 @@ def test_processing_step_with_framework_processor(
475484
step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
476485
== processing_output.destination
477486
)
487+
assert (
488+
type(step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"]) is ParameterString
489+
)
490+
step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"] = step_args[
491+
"ProcessingResources"
492+
]["ClusterConfig"]["InstanceType"].expr
478493

479494
del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
480495
del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]

0 commit comments

Comments
 (0)