Skip to content

Commit d0e03a6

Browse files
aoguo64Ao Guo
authored and
Namrata Madan
committed
pre-execution command support (aws#905)
Co-authored-by: Ao Guo <[email protected]>
1 parent a065b4d commit d0e03a6

File tree

15 files changed

+408
-48
lines changed

15 files changed

+408
-48
lines changed

src/sagemaker/config/config_schema.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
MODULES = "Modules"
4747
REMOTE_FUNCTION = "RemoteFunction"
4848
DEPENDENCIES = "Dependencies"
49+
PRE_EXECUTION_SCRIPT = "PreExecutionScript"
50+
PRE_EXECUTION_COMMANDS = "PreExecutionCommands"
4951
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
5052
IMAGE_URI = "ImageUri"
5153
INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir"
@@ -233,6 +235,12 @@ def _simple_path(*args: str):
233235
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
234236
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
235237
)
238+
REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS = _simple_path(
239+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_COMMANDS
240+
)
241+
REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT = _simple_path(
242+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_SCRIPT
243+
)
236244
REMOTE_FUNCTION_ENVIRONMENT_VARIABLES = _simple_path(
237245
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENVIRONMENT_VARIABLES
238246
)
@@ -266,9 +274,6 @@ def _simple_path(*args: str):
266274
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
267275
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
268276
)
269-
REMOTE_FUNCTION_ENABLE_NETWORK_ISOLATION = _simple_path(
270-
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_NETWORK_ISOLATION
271-
)
272277

273278
# Paths for reference elsewhere in the SDK.
274279
# Names include the schema version since the paths could change with other schema versions
@@ -440,6 +445,8 @@ def _simple_path(*args: str):
440445
},
441446
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
442447
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
448+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
449+
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
443450
},
444451
PROPERTIES: {
445452
SCHEMA_VERSION: {
@@ -475,10 +482,14 @@ def _simple_path(*args: str):
475482
ADDITIONAL_PROPERTIES: False,
476483
PROPERTIES: {
477484
DEPENDENCIES: {TYPE: "string"},
485+
PRE_EXECUTION_COMMANDS: {
486+
TYPE: "array",
487+
"items": {"$ref": "#/definitions/preExecutionCommand"},
488+
},
489+
PRE_EXECUTION_SCRIPT: {TYPE: "string"},
478490
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
479491
TYPE: "boolean"
480492
},
481-
ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"},
482493
ENVIRONMENT_VARIABLES: {
483494
"$ref": "#/definitions/environmentVariables"
484495
},

src/sagemaker/remote_function/client.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def remote(
5959
_func=None,
6060
*,
6161
dependencies: str = None,
62+
pre_execution_commands: List[str] = None,
63+
pre_execution_script: str = None,
6264
environment_variables: Dict[str, str] = None,
6365
image_uri: str = None,
6466
include_local_workdir: bool = False,
@@ -79,14 +81,19 @@ def remote(
7981
volume_kms_key: str = None,
8082
volume_size: int = 30,
8183
encrypt_inter_container_traffic: bool = None,
82-
enable_network_isolation: bool = None,
8384
):
8485
"""Function that starts a new SageMaker job synchronously with overridden runtime settings.
8586
8687
Args:
8788
_func (Optional): Python function to be executed on the SageMaker job runtime environment.
8889
dependencies (str): Path to dependencies file or a reserved keyword
8990
``auto_capture``. Defaults to None.
91+
pre_execution_commands (List[str]): List of commands to be executed prior to executing
92+
remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
93+
can be specified at the same time. Defaults to None.
94+
pre_execution_script (str): Path to script file to be executed prior to executing
95+
remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
96+
can be specified at the same time. Defaults to None.
9097
environment_variables (Dict): environment variables
9198
image_uri (str): Docker image URI on ECR.
9299
include_local_workdir (bool): Set to ``True`` if the remote function code imports local
@@ -118,11 +125,6 @@ def remote(
118125
data. Default is 30.
119126
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
120127
containers is encrypted for the training job. (default: ``False``).
121-
enable_network_isolation (bool): Specifies whether container will
122-
run in network isolation mode (default: ``False``). Network
123-
isolation mode restricts the container access to outside networks
124-
(such as the Internet). The container does not make any inbound or
125-
outbound network calls. Also known as Internet-free mode.
126128
"""
127129

128130
def _remote(func):
@@ -133,6 +135,8 @@ def wrapper(*args, **kwargs):
133135

134136
job_settings = _JobSettings(
135137
dependencies=dependencies,
138+
pre_execution_commands=pre_execution_commands,
139+
pre_execution_script=pre_execution_script,
136140
environment_variables=environment_variables,
137141
image_uri=image_uri,
138142
include_local_workdir=include_local_workdir,
@@ -153,7 +157,6 @@ def wrapper(*args, **kwargs):
153157
volume_kms_key=volume_kms_key,
154158
volume_size=volume_size,
155159
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
156-
enable_network_isolation=enable_network_isolation,
157160
)
158161
job = _Job.start(job_settings, func, args, kwargs)
159162

@@ -322,6 +325,8 @@ def __init__(
322325
self,
323326
*,
324327
dependencies: str = None,
328+
pre_execution_commands: List[str] = None,
329+
pre_execution_script: str = None,
325330
environment_variables: Dict[str, str] = None,
326331
image_uri: str = None,
327332
include_local_workdir: bool = False,
@@ -343,13 +348,18 @@ def __init__(
343348
volume_kms_key: str = None,
344349
volume_size: int = 30,
345350
encrypt_inter_container_traffic: bool = None,
346-
enable_network_isolation: bool = None,
347351
):
348352
"""Initiates a ``RemoteExecutor`` instance.
349353
350354
Args:
351355
dependencies (str): Path to dependencies file or a reserved keyword
352356
``auto_capture``. Defaults to None.
357+
pre_execution_commands (List[str]): List of commands to be executed prior to executing
358+
remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
359+
can be specified at the same time. Defaults to None.
360+
pre_execution_script (str): Path to script file to be executed prior to executing
361+
remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
362+
can be specified at the same time. Defaults to None.
353363
environment_variables (Dict): Environment variables passed to the underlying sagemaker
354364
job. Defaults to None
355365
image_uri (str): Docker image URI on ECR. Defaults to base Python image.
@@ -388,11 +398,6 @@ def __init__(
388398
data. Defaults to 30.
389399
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
390400
containers is encrypted for the training job. (default: ``False``).
391-
enable_network_isolation (bool): Specifies whether container will
392-
run in network isolation mode (default: ``False``). Network
393-
isolation mode restricts the container access to outside networks
394-
(such as the Internet). The container does not make any inbound or
395-
outbound network calls. Also known as Internet-free mode.
396401
"""
397402
self.max_parallel_jobs = max_parallel_jobs
398403

@@ -401,6 +406,8 @@ def __init__(
401406

402407
self.job_settings = _JobSettings(
403408
dependencies=dependencies,
409+
pre_execution_commands=pre_execution_commands,
410+
pre_execution_script=pre_execution_script,
404411
environment_variables=environment_variables,
405412
image_uri=image_uri,
406413
include_local_workdir=include_local_workdir,
@@ -421,7 +428,6 @@ def __init__(
421428
volume_kms_key=volume_kms_key,
422429
volume_size=volume_size,
423430
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
424-
enable_network_isolation=enable_network_isolation,
425431
)
426432

427433
self._state_condition = threading.Condition()

src/sagemaker/remote_function/core/serialization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
106106
e, NotImplementedError
107107
) and "Instance of Run type is not allowed to be pickled." in str(e):
108108
raise SerializationError(
109-
"""You are trying to reference to a sagemaker.experiments.run.Run instance from within the function
110-
or passing it as a function argument.
111-
Instantiate a Run in the function or use load_run instead."""
109+
"""You are trying to pass a sagemaker.experiments.run.Run object to a remote function
110+
or are trying to access a global sagemaker.experiments.run.Run object from within the function.
111+
This is not supported. You must use `load_run` to load an existing Run in the remote function
112+
or instantiate a new Run in the function."""
112113
) from e
113114

114115
raise SerializationError(

src/sagemaker/remote_function/job.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
REMOTE_FUNCTION_ENVIRONMENT_VARIABLES,
2626
REMOTE_FUNCTION_IMAGE_URI,
2727
REMOTE_FUNCTION_DEPENDENCIES,
28+
REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS,
29+
REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT,
2830
REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR,
2931
REMOTE_FUNCTION_INSTANCE_TYPE,
3032
REMOTE_FUNCTION_JOB_CONDA_ENV,
@@ -36,7 +38,6 @@
3638
REMOTE_FUNCTION_VPC_CONFIG_SUBNETS,
3739
REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS,
3840
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
39-
REMOTE_FUNCTION_ENABLE_NETWORK_ISOLATION,
4041
)
4142
from sagemaker.experiments._run_context import _RunContext
4243
from sagemaker.experiments.run import Run
@@ -53,6 +54,7 @@
5354
# runtime script names
5455
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
5556
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
57+
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
5658
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"
5759

5860
# training channel names
@@ -121,6 +123,8 @@ def __init__(
121123
self,
122124
*,
123125
dependencies: str = None,
126+
pre_execution_commands: List[str] = None,
127+
pre_execution_script: str = None,
124128
environment_variables: Dict[str, str] = None,
125129
image_uri: str = None,
126130
include_local_workdir: bool = None,
@@ -141,7 +145,6 @@ def __init__(
141145
volume_kms_key: str = None,
142146
volume_size: int = 30,
143147
encrypt_inter_container_traffic: bool = None,
144-
enable_network_isolation: bool = None,
145148
):
146149

147150
self.sagemaker_session = sagemaker_session or Session()
@@ -171,6 +174,24 @@ def __init__(
171174
config_path=REMOTE_FUNCTION_DEPENDENCIES,
172175
sagemaker_session=self.sagemaker_session,
173176
)
177+
178+
self.pre_execution_commands = resolve_value_from_config(
179+
direct_input=pre_execution_commands,
180+
config_path=REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS,
181+
sagemaker_session=self.sagemaker_session,
182+
)
183+
184+
self.pre_execution_script = resolve_value_from_config(
185+
direct_input=pre_execution_script,
186+
config_path=REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT,
187+
sagemaker_session=self.sagemaker_session,
188+
)
189+
190+
if self.pre_execution_commands is not None and self.pre_execution_script is not None:
191+
raise ValueError(
192+
"Only one of pre_execution_commands or pre_execution_script can be specified!"
193+
)
194+
174195
self.include_local_workdir = resolve_value_from_config(
175196
direct_input=include_local_workdir,
176197
config_path=REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR,
@@ -202,12 +223,7 @@ def __init__(
202223
default_value=False,
203224
sagemaker_session=self.sagemaker_session,
204225
)
205-
self.enable_network_isolation = resolve_value_from_config(
206-
direct_input=enable_network_isolation,
207-
config_path=REMOTE_FUNCTION_ENABLE_NETWORK_ISOLATION,
208-
default_value=False,
209-
sagemaker_session=self.sagemaker_session,
210-
)
226+
self.enable_network_isolation = False
211227

212228
_role = resolve_value_from_config(
213229
direct_input=role,
@@ -326,6 +342,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
326342
user_dependencies_s3uri = _prepare_and_upload_dependencies(
327343
local_dependencies_path=dependencies_list_path,
328344
include_local_workdir=job_settings.include_local_workdir,
345+
pre_execution_commands=job_settings.pre_execution_commands,
346+
pre_execution_script_local_path=job_settings.pre_execution_script,
329347
s3_base_uri=s3_base_uri,
330348
s3_kms_key=job_settings.s3_kms_key,
331349
sagemaker_session=job_settings.sagemaker_session,
@@ -540,13 +558,20 @@ def _prepare_and_upload_runtime_scripts(
540558
def _prepare_and_upload_dependencies(
541559
local_dependencies_path: str,
542560
include_local_workdir: bool,
561+
pre_execution_commands: List[str],
562+
pre_execution_script_local_path: str,
543563
s3_base_uri: str,
544564
s3_kms_key: str,
545565
sagemaker_session: Session,
546566
) -> str:
547567
"""Upload the job dependencies to S3 if present"""
548568

549-
if not local_dependencies_path and not include_local_workdir:
569+
if not (
570+
local_dependencies_path
571+
or include_local_workdir
572+
or pre_execution_commands
573+
or pre_execution_script_local_path
574+
):
550575
return None
551576

552577
with _tmpdir() as tmp_workspace:
@@ -570,6 +595,25 @@ def _prepare_and_upload_dependencies(
570595
"Copied dependencies file at '%s' to '%s'", local_dependencies_path, dst_path
571596
)
572597

598+
if pre_execution_commands or pre_execution_script_local_path:
599+
if not os.path.isdir(tmp_workspace):
600+
os.mkdir(tmp_workspace)
601+
pre_execution_script = os.path.join(tmp_workspace, PRE_EXECUTION_SCRIPT_NAME)
602+
if pre_execution_commands:
603+
with open(pre_execution_script, "w") as target_script:
604+
commands = [cmd + "\n" for cmd in pre_execution_commands]
605+
target_script.writelines(commands)
606+
logger.info(
607+
"Generated pre-execution script from commands to '%s'", pre_execution_script
608+
)
609+
else:
610+
shutil.copy(pre_execution_script_local_path, pre_execution_script)
611+
logger.info(
612+
"Copied pre-execution commands from script at '%s' to '%s'",
613+
pre_execution_script_local_path,
614+
pre_execution_script,
615+
)
616+
573617
workspace_archive_path = os.path.join(tmp_workspace, "workspace")
574618
workspace_archive_path = shutil.make_archive(workspace_archive_path, "zip", tmp_workspace)
575619
logger.info("Successfully created workdir archive at '%s'", workspace_archive_path)

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
FAILURE_REASON_PATH = "/opt/ml/output/failure"
3636
SAGEMAKER_WHL_FILE_NAME = "sagemaker-2.132.1.dev0-py2.py3-none-any.whl"
3737
SAGEMAKER_WHL_CHANNEL = "sagemaker_whl_file"
38+
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
3839

3940

4041
logger = get_logger()
@@ -54,8 +55,6 @@ def main():
5455

5556
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
5657

57-
_execute_pre_exec_cmds()
58-
5958
_bootstrap_runtime_environment(client_python_version, conda_env)
6059

6160
exit_code = SUCCESS_EXIT_CODE
@@ -67,12 +66,6 @@ def main():
6766
sys.exit(exit_code)
6867

6968

70-
def _execute_pre_exec_cmds():
71-
"""Execute pre-flight commands before invkoing remote function"""
72-
# TODO: complete me
73-
pass # pylint: disable=W0107
74-
75-
7669
def _bootstrap_runtime_environment(
7770
client_python_version: str,
7871
conda_env: str = None,
@@ -104,6 +97,10 @@ def _bootstrap_runtime_environment(
10497
shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir)
10598
logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir)
10699

100+
# Handle pre-execution commands
101+
path_to_pre_exec_script = os.path.join(workspace_unpack_dir, PRE_EXECUTION_SCRIPT_NAME)
102+
RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=path_to_pre_exec_script)
103+
107104
# Handle dependencies file.
108105
dependencies_file = None
109106
for file in os.listdir(workspace_unpack_dir):

0 commit comments

Comments
 (0)