diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 15051dc04a..76a8443fba 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -90,7 +90,8 @@ def remote( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, + use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Decorator for running the annotated function as a SageMaker training job. @@ -207,7 +208,8 @@ def remote( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -284,6 +286,9 @@ def remote( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. This is defined automatically configured on the instance type. 
@@ -320,6 +325,7 @@ def _remote(func): use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) @@ -327,12 +333,13 @@ def _remote(func): def wrapper(*args, **kwargs): if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun) - or (spark_config is None and use_torchrun) + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) ): raise ValueError( "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun. " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -536,7 +543,8 @@ def __init__( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, + use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Constructor for RemoteExecutor @@ -650,7 +658,8 @@ def __init__( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -730,6 +739,9 @@ def __init__( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. 
This is defined automatically configured on the instance type. @@ -740,12 +752,13 @@ def __init__( raise ValueError("max_parallel_jobs must be greater than 0.") if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun) - or (spark_config is None and use_torchrun) + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) ): raise ValueError( "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun. " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -778,6 +791,7 @@ def __init__( use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 4e2e749bcb..f6c3a58ad6 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -81,6 +81,7 @@ # runtime script names BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" @@ -167,6 +168,99 @@ fi """ +ENTRYPOINT_MPIRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with mpirun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to 
'$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x 
NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +else + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: No conda env provided. Invoking remote function with mpirun\\n" + printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + + mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance 
$SM_CURRENT_HOST.\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +fi +""" + ENTRYPOINT_TORCHRUN_SCRIPT = f""" #!/bin/bash @@ -211,6 +305,7 @@ printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function "$@" @@ -218,6 +313,7 @@ printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n" + torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@" fi @@ -278,6 +374,7 @@ def __init__( use_spot_instances=False, max_wait_time_in_seconds=None, use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Initialize a _JobSettings instance which configures the remote job. @@ -464,6 +561,9 @@ def __init__( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. This is defined automatically configured on the instance type. 
@@ -626,6 +726,7 @@ def __init__( self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) self.use_torchrun = use_torchrun + self.use_mpirun = use_mpirun self.nproc_per_node = nproc_per_node @staticmethod @@ -874,6 +975,12 @@ def compile( ).to_string(), ] ) + if job_settings.use_torchrun: + container_args.extend(["--distribution", "torchrun"]) + elif job_settings.use_mpirun: + container_args.extend(["--distribution", "mpirun"]) + if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: + container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) if job_settings.s3_kms_key: container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) @@ -950,6 +1057,7 @@ def compile( request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) + extended_request = _extend_mpirun_to_request(extended_request, job_settings) extended_request = _extend_torchrun_to_request(extended_request, job_settings) return extended_request @@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts( s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, - nproc_per_node: Optional[int] = None, + use_mpirun: bool = False, ): """Copy runtime scripts to a folder and upload to S3. @@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts( use_torchrun (bool): Whether to use torchrun or not. + use_mpirun (bool): Whether to use mpirun or not. 
+ nproc_per_node (Optional[int]): Number of processes per node """ @@ -1075,10 +1185,8 @@ def _prepare_and_upload_runtime_scripts( if use_torchrun: entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT - if nproc_per_node is not None and nproc_per_node > 0: - entry_point_script = entry_point_script.replace( - "$SM_NPROC_PER_NODE", str(nproc_per_node) - ) + if use_mpirun: + entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT with open(entrypoint_script_path, "w", newline="\n") as file: file.writelines(entry_point_script) @@ -1086,12 +1194,16 @@ def _prepare_and_upload_runtime_scripts( bootstrap_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME ) + mpi_utils_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME + ) runtime_manager_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME ) # copy runtime scripts to tmpdir shutil.copy2(bootstrap_script_path, bootstrap_scripts) + shutil.copy2(mpi_utils_path, bootstrap_scripts) shutil.copy2(runtime_manager_script_path, bootstrap_scripts) upload_path = S3Uploader.upload( @@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): s3_kms_key=job_settings.s3_kms_key, sagemaker_session=job_settings.sagemaker_session, use_torchrun=job_settings.use_torchrun, - nproc_per_node=job_settings.nproc_per_node, + use_mpirun=job_settings.use_mpirun, ) input_data_config = [ @@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration( return config_file_s3_uri +def _extend_mpirun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with mpirun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. 
+ """ + use_mpirun = job_settings.use_mpirun + instance_count = job_settings.instance_count + + if not use_mpirun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + def _extend_torchrun_to_request( request_dict: Dict, job_settings: _JobSettings, diff --git a/src/sagemaker/remote_function/runtime_environment/__init__.py b/src/sagemaker/remote_function/runtime_environment/__init__.py index e69de29bb2..18557a2eb5 100644 --- a/src/sagemaker/remote_function/runtime_environment/__init__.py +++ b/src/sagemaker/remote_function/runtime_environment/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""SageMaker remote function runtime environment package.""" +from __future__ import absolute_import diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 0b0823da77..da7c493ae5 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -22,7 +22,7 @@ import shutil import subprocess import sys -from typing import Dict, Any +from typing import Any, Dict if __package__ is None or __package__ == "": from runtime_environment_manager import ( @@ -271,6 +271,8 @@ def _parse_args(sys_args): parser.add_argument("--pipeline_execution_id", type=str) parser.add_argument("--dependency_settings", type=str) parser.add_argument("--func_step_s3_dir", type=str) + parser.add_argument("--distribution", type=str, default=None) + parser.add_argument("--user_nproc_per_node", type=str, default=None) args, _ = parser.parse_known_args(sys_args) return args @@ -401,6 +403,8 @@ def safe_serialize(data): def set_env( resource_config: Dict[str, Any], + distribution: str = None, + user_nproc_per_node: bool = None, output_file: str = ENV_OUTPUT_FILE, ): """Set environment variables for the training job container. @@ -442,12 +446,15 @@ def set_env( # Misc. 
env_vars["SM_RESOURCE_CONFIG"] = resource_config - if int(env_vars["SM_NUM_GPUS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) - elif int(env_vars["SM_NUM_NEURONS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) else: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) + if int(env_vars["SM_NUM_GPUS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) + elif int(env_vars["SM_NUM_NEURONS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + else: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) # All Training Environment Variables env_vars["SM_TRAINING_ENV"] = { @@ -471,18 +478,45 @@ def set_env( "resource_config": env_vars["SM_RESOURCE_CONFIG"], } - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + if distribution and distribution == "torchrun": + logger.info("Distribution: torchrun") + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + env_vars["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" + env_vars["RDMAV_FORK_SAFE"] = "1" + env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + env_vars["NCCL_PROTO"] = "simple" + elif distribution and distribution == "mpirun": + logger.info("Distribution: mpirun") + + env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] + env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) + + host_list = [ + "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts + ] + env_vars["SM_HOSTS_LIST"] = 
",".join(host_list) + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + + if instance_type in SM_EFA_NCCL_INSTANCES: + env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" + env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" + else: + env_vars["SM_FI_PROVIDER"] = "" + env_vars["SM_NCCL_PROTO"] = "" - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - env_vars["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" - env_vars["RDMAV_FORK_SAFE"] = "1" - env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - env_vars["NCCL_PROTO"] = "simple" + if instance_type in SM_EFA_RDMA_INSTANCES: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" + else: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" with open(output_file, "w") as f: for key, value in env_vars.items(): @@ -499,12 +533,19 @@ def main(sys_args=None): try: args = _parse_args(sys_args) + + logger.info("Arguments:") + for arg in vars(args): + logger.info("%s=%s", arg, getattr(args, arg)) + client_python_version = args.client_python_version client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version job_conda_env = args.job_conda_env pipeline_execution_id = args.pipeline_execution_id dependency_settings = _DependencySettings.from_string(args.dependency_settings) func_step_workspace = args.func_step_s3_dir + distribution = args.distribution + user_nproc_per_node = args.user_nproc_per_node conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") @@ -539,7 +580,11 @@ def main(sys_args=None): logger.info("Found %s", RESOURCE_CONFIG) with open(RESOURCE_CONFIG, "r") as f: resource_config = json.load(f) - set_env(resource_config=resource_config) + set_env( + resource_config=resource_config, + distribution=distribution, + user_nproc_per_node=user_nproc_per_node, + ) except (json.JSONDecodeError, FileNotFoundError) as e: # Optionally, you 
might want to log this error logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e)) diff --git a/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py new file mode 100644 index 0000000000..6f3897fb0b --- /dev/null +++ b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py @@ -0,0 +1,252 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI utility functions for the runtime environment. This must be kept independent of SageMaker PySDK""" +from __future__ import absolute_import + +import argparse +import json +import os +import subprocess +import sys +import time +from typing import List + +import paramiko + +if __package__ is None or __package__ == "": + from runtime_environment_manager import ( + get_logger, + ) +else: + from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + get_logger, + ) + +SUCCESS_EXIT_CODE = 0 +DEFAULT_FAILURE_CODE = 1 + +FINISHED_STATUS_FILE = "/tmp/done.algo-1" +READY_FILE = "/tmp/ready.%s" +DEFAULT_SSH_PORT = 22 + +FAILURE_REASON_PATH = "/opt/ml/output/failure" +FINISHED_STATUS_FILE = "/tmp/done.algo-1" + +logger = get_logger() + + +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. 
+ + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. + + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + +def _parse_args(sys_args): + """Parses CLI arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--job_ended", type=str, default="0") + args, _ = parser.parse_known_args(sys_args) + return args + + +def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: + """Check if the connection to the provided host and port is possible.""" + try: + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True + except Exception as e: # pylint: disable=W0703 + logger.info("Cannot connect to host %s", host) + logger.debug("Connection failed with exception: %s", e) + return False + + +def _write_file_to_host(host: str, status_file: str) -> bool: + """Create an empty status file on the provided host over SSH.""" + try: + logger.info("Writing %s to %s", status_file, host) + subprocess.run( + ["ssh", host, "touch", f"{status_file}"], + capture_output=True, + text=True, + check=True, + ) + logger.info("Finished writing status file") + return True + except subprocess.CalledProcessError: + logger.info("Cannot connect 
to %s", host) + return False + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if bootstrap runtime env failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. + """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write("RuntimeEnvironmentError: " + failure_msg) + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info("Worker is attempting to connect to the master node %s...", master_host) + if _can_connect(master_host, port): + logger.info("Worker can connect to master node %s.", master_host) + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info("Waiting for status file %s", status_file) + while not os.path.exists(status_file): + time.sleep(30) + logger.info("Found status file %s", status_file) + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for 
workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def bootstrap_worker_node( + master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE +): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % current_host) + _wait_for_status_file(status_file) + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError("Timed out waiting for %s to be reachable." 
% worker) + logger.info("Retrying to write status file to %s", worker) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + try: + args = _parse_args(sys_args) + + job_ended = args.job_ended + + main_host = os.environ["SM_MASTER_ADDR"] + current_host = os.environ["SM_CURRENT_HOST"] + + if job_ended == "0": + logger.info("Job is running, bootstrapping nodes") + + start_sshd_daemon() + + if current_host != main_host: + bootstrap_worker_node(main_host, current_host) + else: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + bootstrap_master_node(worker_hosts) + else: + logger.info("Job ended, writing status file to workers") + + if current_host == main_host: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + write_status_file_to_workers(worker_hosts) + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + + sys.exit(DEFAULT_FAILURE_CODE) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 680bfc01df..fa55d7dfa7 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -825,6 +825,7 @@ def test_decorator_torchrun( dummy_container_without_error, gpu_instance_type, use_torchrun=False, + use_mpirun=False, ): @remote( role=ROLE, @@ -833,6 +834,7 @@ def test_decorator_torchrun( sagemaker_session=sagemaker_session, keep_alive_period_in_seconds=60, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, ) def divide(x, y): return x / y diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py 
b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 57f4a54f78..00bd3ca090 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -908,6 +908,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session): "max_wait_time_in_seconds", "custom_file_filter", "use_torchrun", + "use_mpirun", "nproc_per_node", } diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py new file mode 100644 index 0000000000..aa983141ae --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import os +from mock import patch + +import sagemaker.remote_function.runtime_environment.mpi_utils_remote as mpi_utils_remote # noqa: E402 + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_main_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_worker_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + mock_bootstrap_master_node.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") 
+@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_main_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_worker_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_not_called() diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 536bfdfca7..6c2a373dbc 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -1505,6 +1505,7 @@ def 
test_consistency_between_remote_and_step_decorator(): "s3_root_uri", "sagemaker_session", "use_torchrun", + "use_mpirun", "nproc_per_node", ] diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index c7d35b6481..671f091d02 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -96,8 +96,6 @@ export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' export SM_NPROC_PER_NODE='4' export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' -export NCCL_SOCKET_IFNAME='eth0' -export NCCL_PROTO='simple' """ # flake8: noqa @@ -154,6 +152,99 @@ export NCCL_PROTO='simple' """ +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export 
SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:4' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export 
SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:1,algo-2:1,algo-3:1,algo-4:1' +export 
SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='2' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 2, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' 
+export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:2' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + DESCRIBE_TRAINING_JOB_RESPONSE = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": "{}", @@ -478,7 +569,7 @@ def test_start( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -761,7 +852,7 @@ def test_start_with_complete_job_settings( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -933,7 +1024,7 @@ def test_get_train_args_under_pipeline_context( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -1109,7 +1200,7 @@ def test_start_with_spark( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1268,7 +1359,7 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() @@ -1288,7 +1379,7 @@ def test_prepare_and_upload_runtime_scripts_under_pipeline_context( ) # Bootstrap scripts are uploaded on the first call assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() mock_copy.reset_mock() @@ -1725,7 +1816,7 @@ def test_start_with_torchrun_single_node( instance_type="ml.g5.12xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) job = 
_Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1751,7 +1842,7 @@ def test_start_with_torchrun_single_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -1809,6 +1900,8 @@ def test_start_with_torchrun_single_node( mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', + "--distribution", + "torchrun", "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1853,7 +1946,7 @@ def test_start_with_torchrun_multi_node( instance_type="ml.g5.2xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1879,7 +1972,7 @@ def test_start_with_torchrun_multi_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -1939,6 +2032,8 @@ def test_start_with_torchrun_multi_node( mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', + "--distribution", + "torchrun", "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1969,7 +2064,7 @@ def test_start_with_torchrun_multi_node( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_single_node_cpu( @@ -1991,6 +2086,7 @@ def test_set_env_single_node_cpu( ], network_interface_name="eth0", ), + distribution=None, output_file=OUTPUT_FILE, ) @@ -2021,7 +2117,7 @@ def test_set_env_single_node_cpu( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + 
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_single_node_multi_gpu( @@ -2043,6 +2139,7 @@ def test_set_env_single_node_multi_gpu( ], network_interface_name="eth0", ), + distribution="torchrun", output_file=OUTPUT_FILE, ) @@ -2073,7 +2170,7 @@ def test_set_env_single_node_multi_gpu( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_multi_node_multi_gpu( @@ -2095,6 +2192,7 @@ def test_set_env_multi_node_multi_gpu( ], network_interface_name="eth0", ), + distribution="torchrun", output_file=OUTPUT_FILE, ) @@ -2112,6 +2210,432 @@ def test_set_env_multi_node_multi_gpu( assert not os.path.exists(OUTPUT_FILE) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=OUTPUT_FILE, + ) + + 
mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + 
+@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + 
pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) 
+@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_mpirun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=False, + use_mpirun=True, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=False, + use_mpirun=True, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + 
sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "mpirun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + 
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun_with_nproc_per_node( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + user_nproc_per_node=2, + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines( + EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE + ) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + def _remove_extra_lines(string): """Removes extra blank lines from a string.""" return "\n".join([line for line in string.splitlines() if line.strip()])