Skip to content

Commit 6a4fd6a

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
fix: Support parameterized source code input for TrainingStep (#3202)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 920d660 commit 6a4fd6a

File tree

22 files changed

+605
-162
lines changed

22 files changed

+605
-162
lines changed

src/sagemaker/amazon/hyperparameter.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import json
1717

18+
from sagemaker.workflow import is_pipeline_variable
19+
1820

1921
class Hyperparameter(object):
2022
"""An algorithm hyperparameter with optional validation.
@@ -98,8 +100,14 @@ def serialize_all(obj):
98100
"""
99101
if "_hyperparameters" not in dir(obj):
100102
return {}
101-
return {
102-
k: json.dumps(v) if isinstance(v, list) else str(v)
103-
for k, v in obj._hyperparameters.items()
104-
if v is not None
105-
}
103+
hps = {}
104+
for k, v in obj._hyperparameters.items():
105+
if v is not None:
106+
if isinstance(v, list):
107+
v = json.dumps(v)
108+
elif is_pipeline_variable(v):
109+
v = v.to_string()
110+
else:
111+
v = str(v)
112+
hps[k] = v
113+
return hps

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Union, Optional
1718

1819
from sagemaker.estimator import Framework, EstimatorBase
1920
from sagemaker.fw_utils import (
@@ -25,6 +26,7 @@
2526
from sagemaker.chainer import defaults
2627
from sagemaker.chainer.model import ChainerModel
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
29+
from sagemaker.workflow.entities import PipelineVariable
2830

2931
logger = logging.getLogger("sagemaker")
3032

@@ -42,12 +44,12 @@ class Chainer(Framework):
4244

4345
def __init__(
4446
self,
45-
entry_point,
47+
entry_point: Union[str, PipelineVariable],
4648
use_mpi=None,
4749
num_processes=None,
4850
process_slots_per_host=None,
4951
additional_mpi_options=None,
50-
source_dir=None,
52+
source_dir: Optional[Union[str, PipelineVariable]] = None,
5153
hyperparameters=None,
5254
framework_version=None,
5355
py_version=None,

src/sagemaker/estimator.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
get_mp_parameters,
4848
tar_and_upload_dir,
4949
validate_source_dir,
50+
validate_source_code_input_against_pipeline_variables,
5051
)
5152
from sagemaker.inputs import TrainingInput, FileSystemInput
5253
from sagemaker.job import _Job
@@ -140,12 +141,12 @@ def __init__(
140141
disable_profiler: bool = False,
141142
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
142143
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
143-
source_dir: Optional[str] = None,
144+
source_dir: Optional[Union[str, PipelineVariable]] = None,
144145
git_config: Optional[Dict[str, str]] = None,
145146
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
146147
container_log_level: Union[int, PipelineVariable] = logging.INFO,
147148
code_location: Optional[str] = None,
148-
entry_point: Optional[str] = None,
149+
entry_point: Optional[Union[str, PipelineVariable]] = None,
149150
dependencies: Optional[List[Union[str]]] = None,
150151
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
151152
**kwargs,
@@ -461,6 +462,13 @@ def __init__(
461462
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs
462463
)
463464

465+
validate_source_code_input_against_pipeline_variables(
466+
entry_point=entry_point,
467+
source_dir=source_dir,
468+
git_config=git_config,
469+
enable_network_isolation=enable_network_isolation,
470+
)
471+
464472
self.role = role
465473
self.instance_count = instance_count
466474
self.instance_type = instance_type
@@ -663,7 +671,11 @@ def _prepare_for_training(self, job_name=None):
663671
# validate source dir will raise a ValueError if there is something wrong with
664672
# the source directory. We are intentionally not handling it because this is a
665673
# critical error.
666-
if self.source_dir and not self.source_dir.lower().startswith("s3://"):
674+
if (
675+
self.source_dir
676+
and not is_pipeline_variable(self.source_dir)
677+
and not self.source_dir.lower().startswith("s3://")
678+
):
667679
validate_source_dir(self.entry_point, self.source_dir)
668680

669681
# if we are in local mode with local_code=True. We want the container to just
@@ -2151,11 +2163,11 @@ def __init__(
21512163
disable_profiler: bool = False,
21522164
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
21532165
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
2154-
source_dir: Optional[str] = None,
2166+
source_dir: Optional[Union[str, PipelineVariable]] = None,
21552167
git_config: Optional[Dict[str, str]] = None,
21562168
container_log_level: Union[int, PipelineVariable] = logging.INFO,
21572169
code_location: Optional[str] = None,
2158-
entry_point: Optional[str] = None,
2170+
entry_point: Optional[Union[str, PipelineVariable]] = None,
21592171
dependencies: Optional[List[str]] = None,
21602172
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
21612173
**kwargs,
@@ -2603,8 +2615,8 @@ class Framework(EstimatorBase):
26032615

26042616
def __init__(
26052617
self,
2606-
entry_point: str,
2607-
source_dir: Optional[str] = None,
2618+
entry_point: Union[str, PipelineVariable],
2619+
source_dir: Optional[Union[str, PipelineVariable]] = None,
26082620
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
26092621
container_log_level: Union[int, PipelineVariable] = logging.INFO,
26102622
code_location: Optional[str] = None,
@@ -2783,7 +2795,14 @@ def __init__(
27832795
"""
27842796
super(Framework, self).__init__(enable_network_isolation=enable_network_isolation, **kwargs)
27852797
image_uri = renamed_kwargs("image_name", "image_uri", image_uri, kwargs)
2786-
if entry_point.startswith("s3://"):
2798+
2799+
validate_source_code_input_against_pipeline_variables(
2800+
entry_point=entry_point,
2801+
source_dir=source_dir,
2802+
git_config=git_config,
2803+
enable_network_isolation=enable_network_isolation,
2804+
)
2805+
if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"):
27872806
raise ValueError(
27882807
"Invalid entry point script: {}. Must be a path to a local file.".format(
27892808
entry_point

src/sagemaker/fw_utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
import shutil
2121
import tempfile
2222
from collections import namedtuple
23-
from typing import Optional
23+
from typing import Optional, Union, Dict
2424

2525
import sagemaker.image_uris
2626
from sagemaker.session_settings import SessionSettings
2727
import sagemaker.utils
2828
from sagemaker.workflow import is_pipeline_variable
2929

3030
from sagemaker.deprecations import renamed_warning, renamed_kwargs
31+
from sagemaker.workflow.entities import PipelineVariable
3132

3233
logger = logging.getLogger(__name__)
3334

@@ -124,6 +125,58 @@ def validate_source_dir(script, directory):
124125
return True
125126

126127

128+
def validate_source_code_input_against_pipeline_variables(
    entry_point: Optional[Union[str, PipelineVariable]] = None,
    source_dir: Optional[Union[str, PipelineVariable]] = None,
    git_config: Optional[Dict[str, str]] = None,
    enable_network_isolation: Union[bool, PipelineVariable] = False,
):
    """Check that pipeline-variable source code inputs are used in a supported way.

    Combinations of ``entry_point`` / ``source_dir`` pipeline variables with the
    other arguments that this function rejects raise ``TypeError``. Accepted
    pipeline variables only trigger a warning describing what the value has to
    be at pipeline execution time.

    Args:
        entry_point (str, PipelineVariable): The path to the local Python source file that
            should be executed as the entry point to training (default: None).
        source_dir (str, PipelineVariable): The path to a directory with any other
            training source code dependencies aside from the entry point file (default: None).
        git_config (Dict[str, str]): Git configurations used for cloning files (default: None).
        enable_network_isolation (bool, PipelineVariable): Specifies whether container will run
            in network isolation mode (default: False).

    Raises:
        TypeError: If ``entry_point`` or ``source_dir`` is a pipeline variable while
            ``enable_network_isolation`` is a pipeline variable or is ``True``, or while
            ``git_config`` is given; or if ``entry_point`` is a pipeline variable and
            ``source_dir`` is missing or is a plain string that is not an S3 URI.
    """
    entry_point_is_variable = is_pipeline_variable(entry_point)
    source_dir_is_variable = is_pipeline_variable(source_dir)
    any_input_is_variable = entry_point_is_variable or source_dir_is_variable

    # Network isolation counts as "possibly on" when it is itself a pipeline
    # variable, since its value is only known at execution time.
    isolation_possible = (
        is_pipeline_variable(enable_network_isolation) or enable_network_isolation is True
    )
    if isolation_possible and any_input_is_variable:
        raise TypeError(
            "entry_point, source_dir should not be pipeline variables "
            "when enable_network_isolation is a pipeline variable or it is set to True."
        )

    if git_config and any_input_is_variable:
        raise TypeError(
            "entry_point, source_dir should not be pipeline variables when git_config is given."
        )

    if entry_point_is_variable:
        if not source_dir:
            raise TypeError(
                "The entry_point should not be a pipeline variable when source_dir is missing."
            )
        # A concrete (non-variable) source_dir must be an S3 URI in this case.
        if not source_dir_is_variable and not source_dir.lower().startswith("s3://"):
            raise TypeError(
                "The entry_point should not be a pipeline variable when source_dir is a local path."
            )
        logger.warning(
            "The entry_point is a pipeline variable: %s. During pipeline execution, "
            "the interpreted value of entry_point has to be a local path in the container "
            "pointing to a Python source file which is located at the root of source_dir.",
            type(entry_point),
        )

    if source_dir_is_variable:
        logger.warning(
            "The source_dir is a pipeline variable: %s. During pipeline execution, "
            "the interpreted value of source_dir has to be an S3 URI and "
            "must point to a tar.gz file",
            type(source_dir),
        )
178+
179+
127180
def get_mp_parameters(distribution):
128181
"""Get the model parallelism parameters provided by the user.
129182
@@ -265,7 +318,7 @@ def tar_and_upload_dir(
265318
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
266319
script name.
267320
"""
268-
if directory and directory.lower().startswith("s3://"):
321+
if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):
269322
return UploadedCode(s3_prefix=directory, script_name=script)
270323

271324
script_name = script if directory else os.path.basename(script)

src/sagemaker/huggingface/estimator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import re
18+
from typing import Optional, Union, Dict
1819

1920
from sagemaker.deprecations import renamed_kwargs
2021
from sagemaker.estimator import Framework, EstimatorBase
@@ -27,6 +28,7 @@
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2829

2930
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
31+
from sagemaker.workflow.entities import PipelineVariable
3032

3133
logger = logging.getLogger("sagemaker")
3234

@@ -38,16 +40,16 @@ class HuggingFace(Framework):
3840

3941
def __init__(
4042
self,
41-
py_version,
42-
entry_point,
43-
transformers_version=None,
44-
tensorflow_version=None,
45-
pytorch_version=None,
46-
source_dir=None,
47-
hyperparameters=None,
48-
image_uri=None,
49-
distribution=None,
50-
compiler_config=None,
43+
py_version: str,
44+
entry_point: Union[str, PipelineVariable],
45+
transformers_version: Optional[str] = None,
46+
tensorflow_version: Optional[str] = None,
47+
pytorch_version: Optional[str] = None,
48+
source_dir: Optional[Union[str, PipelineVariable]] = None,
49+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
50+
image_uri: Optional[Union[str, PipelineVariable]] = None,
51+
distribution: Optional[Dict] = None,
52+
compiler_config: Optional[TrainingCompilerConfig] = None,
5153
**kwargs,
5254
):
5355
"""This estimator runs a Hugging Face training script in a SageMaker training environment.

src/sagemaker/job.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sagemaker.inputs import FileSystemInput, TrainingInput
2020
from sagemaker.local import file_input
21+
from sagemaker.workflow import is_pipeline_variable
2122

2223

2324
class _Job(object):
@@ -168,14 +169,14 @@ def _format_string_uri_input(
168169
target_attribute_name=None,
169170
):
170171
"""Placeholder docstring"""
172+
s3_input_result = TrainingInput(
173+
uri_input,
174+
content_type=content_type,
175+
input_mode=input_mode,
176+
compression=compression,
177+
target_attribute_name=target_attribute_name,
178+
)
171179
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
172-
s3_input_result = TrainingInput(
173-
uri_input,
174-
content_type=content_type,
175-
input_mode=input_mode,
176-
compression=compression,
177-
target_attribute_name=target_attribute_name,
178-
)
179180
return s3_input_result
180181
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
181182
return file_input(uri_input)
@@ -185,16 +186,11 @@ def _format_string_uri_input(
185186
'"file://"'.format(uri_input)
186187
)
187188
if isinstance(uri_input, str):
188-
s3_input_result = TrainingInput(
189-
uri_input,
190-
content_type=content_type,
191-
input_mode=input_mode,
192-
compression=compression,
193-
target_attribute_name=target_attribute_name,
194-
)
195189
return s3_input_result
196190
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
197191
return uri_input
192+
if is_pipeline_variable(uri_input):
193+
return s3_input_result
198194

199195
raise ValueError(
200196
"Cannot format input {}. Expecting one of str, TrainingInput, file_input or "

src/sagemaker/jumpstart/utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
JumpStartModelSpecs,
3131
JumpStartVersionedModelId,
3232
)
33-
33+
from sagemaker.workflow import is_pipeline_variable
3434

3535
LOGGER = logging.getLogger(__name__)
3636

@@ -271,26 +271,41 @@ def add_jumpstart_tags(
271271
training_script_uri (Optional[str]): S3 URI for training script tarball.
272272
(Default: None).
273273
"""
274-
274+
warn_msg = (
275+
"The URI (%s) is a pipeline variable which is only interpreted at execution time. "
276+
"As a result, the JumpStart resources will not be tagged."
277+
)
275278
if inference_model_uri:
276-
tags = add_single_jumpstart_tag(
277-
inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags
278-
)
279+
if is_pipeline_variable(inference_model_uri):
280+
logging.warning(warn_msg, "inference_model_uri")
281+
else:
282+
tags = add_single_jumpstart_tag(
283+
inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags
284+
)
279285

280286
if inference_script_uri:
281-
tags = add_single_jumpstart_tag(
282-
inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags
283-
)
287+
if is_pipeline_variable(inference_script_uri):
288+
logging.warning(warn_msg, "inference_script_uri")
289+
else:
290+
tags = add_single_jumpstart_tag(
291+
inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags
292+
)
284293

285294
if training_model_uri:
286-
tags = add_single_jumpstart_tag(
287-
training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags
288-
)
295+
if is_pipeline_variable(training_model_uri):
296+
logging.warning(warn_msg, "training_model_uri")
297+
else:
298+
tags = add_single_jumpstart_tag(
299+
training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags
300+
)
289301

290302
if training_script_uri:
291-
tags = add_single_jumpstart_tag(
292-
training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags
293-
)
303+
if is_pipeline_variable(training_script_uri):
304+
logging.warning(warn_msg, "training_script_uri")
305+
else:
306+
tags = add_single_jumpstart_tag(
307+
training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags
308+
)
294309

295310
return tags
296311

0 commit comments

Comments
 (0)