Skip to content

fix: HuggingFaceProcessor parameterized instance_type when image_uri is absent #4072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


@override_pipeline_parameter_var
def get_training_image_uri(
region,
framework,
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,10 @@ def override_pipeline_parameter_var(func):
We should remove this decorator after the grace period.
"""
warning_msg_template = (
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
"The default_value of this Parameter object will be used to override it. "
"The input argument %s of function (%s) is a pipeline variable (%s), "
"which is interpreted in pipeline execution time only. "
"As the function needs to evaluate the argument value in SDK compile time, "
"the default_value of this Parameter object will be used to override it. "
"Please make sure the default_value is valid."
)

Expand Down
17 changes: 16 additions & 1 deletion tests/unit/sagemaker/workflow/test_processing_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from tests.unit import DATA_DIR
from tests.unit.sagemaker.workflow.conftest import ROLE, BUCKET, IMAGE_URI, INSTANCE_TYPE

HF_INSTANCE_TYPE = "ml.p3.2xlarge"
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
SPARK_APP_JAR_PATH = os.path.join(
Expand Down Expand Up @@ -142,7 +143,7 @@
pytorch_version="1.7",
role=ROLE,
instance_count=1,
instance_type="ml.p3.2xlarge",
instance_type=HF_INSTANCE_TYPE,
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
Expand Down Expand Up @@ -446,8 +447,15 @@ def test_processing_step_with_framework_processor(
):

processor, run_inputs = framework_processor
default_instance_type = (
HF_INSTANCE_TYPE if type(processor) is HuggingFaceProcessor else INSTANCE_TYPE
)
instance_type_param = ParameterString(
name="ProcessingInstanceType", default_value=default_instance_type
)
processor.sagemaker_session = pipeline_session
processor.role = ROLE
processor.instance_type = instance_type_param

processor.volume_kms_key = "volume-kms-key"
processor.network_config = network_config
Expand All @@ -465,6 +473,7 @@ def test_processing_step_with_framework_processor(
name="MyPipeline",
steps=[step],
sagemaker_session=pipeline_session,
parameters=[instance_type_param],
)

step_args = get_step_args_helper(step_args, "Processing")
Expand All @@ -475,6 +484,12 @@ def test_processing_step_with_framework_processor(
step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
== processing_output.destination
)
assert (
type(step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"]) is ParameterString
)
step_args["ProcessingResources"]["ClusterConfig"]["InstanceType"] = step_args[
"ProcessingResources"
]["ClusterConfig"]["InstanceType"].expr

del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
Expand Down