Skip to content

Commit 41df39c

Browse files
qidewenwhenDewen Qi
authored andcommitted
change: Add override_pipeline_parameter_var decorator to give grace period to update invalid pipeline var args (aws#3180)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 8807eb7 commit 41df39c

File tree

4 files changed

+101
-15
lines changed

4 files changed

+101
-15
lines changed

src/sagemaker/image_uris.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@
2424
from sagemaker.spark import defaults
2525
from sagemaker.jumpstart import artifacts
2626
from sagemaker.workflow import is_pipeline_variable
27+
from sagemaker.workflow.utilities import override_pipeline_parameter_var
2728

2829
logger = logging.getLogger(__name__)
2930

3031
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3132
HUGGING_FACE_FRAMEWORK = "huggingface"
3233

3334

35+
# TODO: we should remove this decorator later
36+
@override_pipeline_parameter_var
3437
def retrieve(
3538
framework,
3639
region,

src/sagemaker/processing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,9 @@ def __init__(
11951195
source (str): The source for the output.
11961196
destination (str): The destination of the output. If a destination
11971197
is not provided, one will be generated:
1198-
"s3://<default-bucket-name>/<job-name>/output/<output-name>".
1198+
"s3://<default-bucket-name>/<job-name>/output/<output-name>"
1199+
(Note: this does not apply when used with
1200+
:class:`~sagemaker.workflow.steps.ProcessingStep`).
11991201
output_name (str): The name of the output. If a name
12001202
is not provided, one will be generated (eg. "output-1").
12011203
s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".

src/sagemaker/workflow/utilities.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,38 @@
1313
"""Utilities to support workflow."""
1414
from __future__ import absolute_import
1515

16+
import inspect
17+
import logging
18+
from functools import wraps
1619
from pathlib import Path
17-
from typing import List, Sequence, Union, Set
20+
from typing import List, Sequence, Union, Set, TYPE_CHECKING
1821
import hashlib
1922
from urllib.parse import unquote, urlparse
2023
from _hashlib import HASH as Hash
2124

25+
from sagemaker.workflow.parameters import Parameter
2226
from sagemaker.workflow.pipeline_context import _StepArguments
23-
from sagemaker.workflow.step_collections import StepCollection
2427
from sagemaker.workflow.entities import (
2528
Entity,
2629
RequestType,
2730
)
2831

32+
if TYPE_CHECKING:
33+
from sagemaker.workflow.step_collections import StepCollection
34+
2935
BUF_SIZE = 65536 # 64KiB
3036

3137

32-
def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]:
38+
def list_to_request(entities: Sequence[Union[Entity, "StepCollection"]]) -> List[RequestType]:
3339
"""Get the request structure for list of entities.
3440
3541
Args:
3642
entities (Sequence[Entity]): A list of entities.
3743
Returns:
3844
list: A request structure for a workflow service call.
3945
"""
46+
from sagemaker.workflow.step_collections import StepCollection
47+
4048
request_dicts = []
4149
for entity in entities:
4250
if isinstance(entity, Entity):
@@ -151,3 +159,41 @@ def validate_step_args_input(
151159
raise TypeError(error_message)
152160
if step_args.caller_name not in expected_caller:
153161
raise ValueError(error_message)
162+
163+
164+
def override_pipeline_parameter_var(func):
165+
"""A decorator to override pipeline Parameters passed into a function
166+
167+
This is a temporary decorator to override pipeline Parameter objects with their default value
168+
and display warning information to instruct users to update their code.
169+
170+
This decorator can help to give a grace period for users to update their code when
171+
we make changes to explicitly prevent passing any pipeline variables to a function.
172+
173+
We should remove this decorator after the grace period.
174+
"""
175+
warning_msg_template = (
176+
"%s should not be a pipeline variable (%s). "
177+
"The default_value of this Parameter object will be used to override it. "
178+
"Please remove this pipeline variable and use python primitives instead."
179+
)
180+
181+
@wraps(func)
182+
def wrapper(*args, **kwargs):
183+
params = inspect.signature(func).parameters
184+
args = list(args)
185+
for i, (arg_name, _) in enumerate(params.items()):
186+
if i >= len(args):
187+
break
188+
if isinstance(args[i], Parameter):
189+
logging.warning(warning_msg_template, arg_name, type(args[i]))
190+
args[i] = args[i].default_value
191+
args = tuple(args)
192+
193+
for arg_name, value in kwargs.items():
194+
if isinstance(value, Parameter):
195+
logging.warning(warning_msg_template, arg_name, type(value))
196+
kwargs[arg_name] = value.default_value
197+
return func(*args, **kwargs)
198+
199+
return wrapper

tests/unit/sagemaker/image_uris/test_retrieve.py

+46-11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mock import patch
2020

2121
from sagemaker import image_uris
22+
from sagemaker.workflow.functions import Join
2223
from sagemaker.workflow.parameters import ParameterString
2324

2425
BASE_CONFIG = {
@@ -721,16 +722,50 @@ def test_retrieve_huggingface(config_for_framework):
721722

722723

723724
def test_retrieve_with_pipeline_variable():
725+
kwargs = dict(
726+
framework="tensorflow",
727+
version="1.15",
728+
py_version="py3",
729+
instance_type="ml.m5.xlarge",
730+
region="us-east-1",
731+
image_scope="training",
732+
)
733+
# instance_type is plain string which should not break anything
734+
image_uris.retrieve(**kwargs)
735+
736+
# instance_type is parameter string with not None default value
737+
# which should not break anything
738+
kwargs["instance_type"] = ParameterString(
739+
name="TrainingInstanceType",
740+
default_value="ml.m5.xlarge",
741+
)
742+
image_uris.retrieve(**kwargs)
743+
744+
# instance_type is parameter string without default value
745+
# (equivalent to pass in None to instance_type field)
746+
# which should fail due to empty instance type check
747+
kwargs["instance_type"] = ParameterString(name="TrainingInstanceType")
724748
with pytest.raises(Exception) as error:
725-
image_uris.retrieve(
726-
framework="tensorflow",
727-
version="1.15",
728-
py_version="py3",
729-
instance_type=ParameterString(
730-
name="TrainingInstanceType",
731-
default_value="ml.m5.xlarge",
732-
),
733-
region="us-east-1",
734-
image_scope="training",
735-
)
749+
image_uris.retrieve(**kwargs)
750+
assert "Empty SageMaker instance type" in str(error.value)
751+
752+
# instance_type is other types of pipeline variable
753+
# which should break loudly
754+
kwargs["instance_type"] = Join(on="", values=["a", "b"])
755+
with pytest.raises(Exception) as error:
756+
image_uris.retrieve(**kwargs)
736757
assert "instance_type should not be a pipeline variable" in str(error.value)
758+
759+
# instance_type (ParameterString) is given as args rather than kwargs
760+
# which should not break anything
761+
image_uris.retrieve(
762+
"tensorflow",
763+
"us-east-1",
764+
"1.15",
765+
"py3",
766+
ParameterString(
767+
name="TrainingInstanceType",
768+
default_value="ml.m5.xlarge",
769+
),
770+
image_scope="training",
771+
)

0 commit comments

Comments
 (0)