|
13 | 13 | """Utilities to support workflow."""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
| 16 | +import inspect |
| 17 | +import logging |
| 18 | +from functools import wraps |
16 | 19 | from pathlib import Path
|
17 |
| -from typing import List, Sequence, Union, Set |
| 20 | +from typing import List, Sequence, Union, Set, TYPE_CHECKING |
18 | 21 | import hashlib
|
19 | 22 | from urllib.parse import unquote, urlparse
|
20 | 23 | from _hashlib import HASH as Hash
|
21 | 24 |
|
| 25 | +from sagemaker.workflow.parameters import Parameter |
22 | 26 | from sagemaker.workflow.pipeline_context import _StepArguments
|
23 |
| -from sagemaker.workflow.step_collections import StepCollection |
24 | 27 | from sagemaker.workflow.entities import (
|
25 | 28 | Entity,
|
26 | 29 | RequestType,
|
27 | 30 | )
|
28 | 31 |
|
| 32 | +if TYPE_CHECKING: |
| 33 | + from sagemaker.workflow.step_collections import StepCollection |
| 34 | + |
29 | 35 | BUF_SIZE = 65536 # 64KiB
|
30 | 36 |
|
31 | 37 |
|
32 |
| -def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]: |
| 38 | +def list_to_request(entities: Sequence[Union[Entity, "StepCollection"]]) -> List[RequestType]: |
33 | 39 | """Get the request structure for list of entities.
|
34 | 40 |
|
35 | 41 | Args:
|
36 | 42 | entities (Sequence[Entity]): A list of entities.
|
37 | 43 | Returns:
|
38 | 44 | list: A request structure for a workflow service call.
|
39 | 45 | """
|
| 46 | + from sagemaker.workflow.step_collections import StepCollection |
| 47 | + |
40 | 48 | request_dicts = []
|
41 | 49 | for entity in entities:
|
42 | 50 | if isinstance(entity, Entity):
|
@@ -151,3 +159,41 @@ def validate_step_args_input(
|
151 | 159 | raise TypeError(error_message)
|
152 | 160 | if step_args.caller_name not in expected_caller:
|
153 | 161 | raise ValueError(error_message)
|
| 162 | + |
| 163 | + |
| 164 | +def override_pipeline_parameter_var(func): |
| 165 | + """A decorator to override pipeline Parameters passed into a function |
| 166 | +
|
| 167 | + This is a temporary decorator to override pipeline Parameter objects with their default value |
| 168 | + and display warning information to instruct users to update their code. |
| 169 | +
|
| 170 | + This decorator can help to give a grace period for users to update their code when |
| 171 | + we make changes to explicitly prevent passing any pipeline variables to a function. |
| 172 | +
|
| 173 | + We should remove this decorator after the grace period. |
| 174 | + """ |
| 175 | + warning_msg_template = ( |
| 176 | + "%s should not be a pipeline variable (%s). " |
| 177 | + "The default_value of this Parameter object will be used to override it. " |
| 178 | + "Please remove this pipeline variable and use python primitives instead." |
| 179 | + ) |
| 180 | + |
| 181 | + @wraps(func) |
| 182 | + def wrapper(*args, **kwargs): |
| 183 | + params = inspect.signature(func).parameters |
| 184 | + args = list(args) |
| 185 | + for i, (arg_name, _) in enumerate(params.items()): |
| 186 | + if i >= len(args): |
| 187 | + break |
| 188 | + if isinstance(args[i], Parameter): |
| 189 | + logging.warning(warning_msg_template, arg_name, type(args[i])) |
| 190 | + args[i] = args[i].default_value |
| 191 | + args = tuple(args) |
| 192 | + |
| 193 | + for arg_name, value in kwargs.items(): |
| 194 | + if isinstance(value, Parameter): |
| 195 | + logging.warning(warning_msg_template, arg_name, type(value)) |
| 196 | + kwargs[arg_name] = value.default_value |
| 197 | + return func(*args, **kwargs) |
| 198 | + |
| 199 | + return wrapper |
0 commit comments