
SageMaker @remote function: Added multi-node functionality #4984

Merged: 13 commits, Jan 16, 2025
30 changes: 20 additions & 10 deletions src/sagemaker/remote_function/client.py
@@ -91,7 +91,7 @@ def remote(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
Review comment (Contributor): What does nproc stand for? Can we use the unabbreviated string for the parameter?

):
"""Decorator for running the annotated function as a SageMaker training job.

@@ -284,8 +284,9 @@ def remote(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

nproc_per_node (int): Specifies the number of processes per node for distributed training.
Defaults to ``1``.
nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
When ``None``, it is automatically configured based on the instance type.
"""

def _remote(func):
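For orientation, here is a minimal usage sketch of the new multi-node path. The instance type, count, and training body are illustrative placeholders, not values from this PR; `nproc_per_node` is left as ``None`` so it can be derived from the instance type at runtime.

```python
# Hypothetical usage of @remote with the multi-node torchrun support added here.
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.g5.12xlarge",  # placeholder multi-GPU instance type
    instance_count=2,                # multi-node is now allowed with use_torchrun=True
    use_torchrun=True,
    nproc_per_node=None,             # None: derived from the instance type at runtime
)
def train():
    import torch.distributed as dist

    # torchrun supplies the rendezvous environment on every node.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    # ... distributed training loop ...
    dist.destroy_process_group()
    return rank

train()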
@@ -325,9 +326,13 @@ def _remote(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):

if instance_count > 1 and not spark_config:
if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
):
raise ValueError(
"Remote function do not support training on multi instances. "
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "Please provide instance_count = 1"
)
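In other words, a multi-instance job is rejected unless exactly one of `spark_config` or `use_torchrun` is in effect. A standalone sketch of that guard, using a hypothetical helper name rather than the PR's literal code:

```python
# Sketch only: instance_count > 1 requires exactly one of spark_config / use_torchrun.
def _validate_multi_node(instance_count, spark_config=None, use_torchrun=False):
    exactly_one = (spark_config is not None) ^ bool(use_torchrun)
    if instance_count > 1 and not exactly_one:
        raise ValueError(
            "Remote function does not support training on multiple instances "
            "without spark_config or use_torchrun. Please provide instance_count = 1"
        )

_validate_multi_node(instance_count=2, use_torchrun=True)  # OK
# _validate_multi_node(instance_count=2)                   # would raise ValueError
```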

@@ -532,7 +537,7 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
):
"""Constructor for RemoteExecutor

@@ -725,17 +730,22 @@ def __init__(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

nproc_per_node (int): Specifies the number of processes per node.
Defaults to ``1``.
nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
When ``None``, it is automatically configured based on the instance type.
"""
self.max_parallel_jobs = max_parallel_jobs

if self.max_parallel_jobs <= 0:
raise ValueError("max_parallel_jobs must be greater than 0.")

if instance_count > 1 and not spark_config:
if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
):
raise ValueError(
"Remote function do not support training on multi instances. "
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "Please provide instance_count = 1"
)
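The same constraint applies to `RemoteExecutor`. A hedged usage sketch, where everything except `instance_count` and `use_torchrun` is a placeholder:

```python
# Hypothetical RemoteExecutor usage with the new multi-node torchrun support.
from sagemaker.remote_function import RemoteExecutor

def train(epochs):
    # ... distributed training body ...
    return epochs

with RemoteExecutor(
    instance_type="ml.g5.12xlarge",  # placeholder instance type
    instance_count=2,
    use_torchrun=True,
    max_parallel_jobs=1,
) as executor:
    future = executor.submit(train, epochs=3)
    print(future.result())
```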

6 changes: 0 additions & 6 deletions src/sagemaker/remote_function/core/stored_function.py
@@ -55,8 +55,6 @@ def __init__(
hmac_key: str,
s3_kms_key: str = None,
context: Context = Context(),
use_torchrun: bool = False,
nproc_per_node: int = 1,
):
"""Construct a StoredFunction object.

@@ -67,16 +65,12 @@
s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
hmac_key: Key used to encrypt serialized and deserialized function and arguments.
context: Build or run context of a pipeline step.
use_torchrun: Whether to use torchrun for distributed training.
nproc_per_node: Number of processes per node for distributed training.
"""
self.sagemaker_session = sagemaker_session
self.s3_base_uri = s3_base_uri
self.s3_kms_key = s3_kms_key
self.hmac_key = hmac_key
self.context = context
self.use_torchrun = use_torchrun
self.nproc_per_node = nproc_per_node

self.func_upload_path = s3_path_join(
s3_base_uri, context.step_name, context.func_step_s3_dir
75 changes: 64 additions & 11 deletions src/sagemaker/remote_function/job.py
@@ -130,9 +130,12 @@
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
cat /opt/ml/input/config/resourceconfig.json

printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
source /opt/ml/input/sm_training.env

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
@@ -155,9 +158,11 @@
fi

printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function\\n"
printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
python -m sagemaker.remote_function.invoke_function "$@"
fi
"""
@@ -175,9 +180,12 @@
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
cat /opt/ml/input/config/resourceconfig.json

printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
source /opt/ml/input/sm_training.env

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
@@ -200,11 +208,18 @@
fi

printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
fi
"""

@@ -262,8 +277,8 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
use_torchrun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Initialize a _JobSettings instance which configures the remote job.

@@ -445,6 +460,13 @@ def __init__(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot
training job to complete. Defaults to ``None``.

use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
When ``None``, it is automatically configured based on the instance type.
"""
self.sagemaker_session = sagemaker_session or Session()
self.environment_variables = resolve_value_from_config(
@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
)

logger.info("Creating job: %s", job_name)

job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

return _Job(
@@ -776,8 +799,6 @@ def compile(
s3_base_uri=s3_base_uri,
hmac_key=hmac_key,
s3_kms_key=job_settings.s3_kms_key,
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)
stored_function.save(func, *func_args, **func_kwargs)
else:
@@ -790,8 +811,6 @@
step_name=step_compilation_context.step_name,
func_step_s3_dir=step_compilation_context.pipeline_build_time,
),
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)

stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +950,7 @@ def compile(
request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
extended_request = _extend_torchrun_to_request(extended_request, job_settings)

return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
nproc_per_node: int = 1,
nproc_per_node: Optional[int] = None,
):
"""Copy runtime scripts to a folder and upload to S3.

@@ -1030,7 +1050,7 @@

use_torchrun (bool): Whether to use torchrun or not.

nproc_per_node (int): Number of processes per node.
nproc_per_node (Optional[int]): Number of processes per node.
"""

from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1074,11 @@

if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))

if nproc_per_node is not None and nproc_per_node > 0:
entry_point_script = entry_point_script.replace(
"$SM_NPROC_PER_NODE", str(nproc_per_node)
)

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)
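Put differently, an explicit `nproc_per_node` hard-codes the process count in the rendered script, while ``None`` leaves the `$SM_NPROC_PER_NODE` placeholder for the sourced runtime environment to resolve. A simplified sketch of that behavior:

```python
# Simplified sketch of the template substitution (not the full entrypoint script).
TEMPLATE = "torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE ..."

def render(template: str, nproc_per_node=None) -> str:
    if nproc_per_node is not None and nproc_per_node > 0:
        # User override: bake the value into the script.
        return template.replace("$SM_NPROC_PER_NODE", str(nproc_per_node))
    # No override: sm_training.env supplies SM_NPROC_PER_NODE at runtime.
    return template

print(render(TEMPLATE))      # keeps the $SM_NPROC_PER_NODE placeholder
print(render(TEMPLATE, 8))   # hard-codes 8 processes per node
```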
@@ -1435,6 +1459,35 @@ def _upload_serialized_spark_configuration(
return config_file_s3_uri


def _extend_torchrun_to_request(
request_dict: Dict,
job_settings: _JobSettings,
) -> Dict:
"""Extend the create training job request with torchrun configuration.

Args:
request_dict (Dict): create training job request dict.
job_settings (_JobSettings): the job settings.
"""
use_torchrun = job_settings.use_torchrun
instance_count = job_settings.instance_count

if not use_torchrun:
return request_dict

if instance_count == 1:
return request_dict

extended_request = request_dict.copy()

for input_channel in extended_request["InputDataConfig"]:
s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
if s3_data_source:
s3_data_source["S3DataDistributionType"] = "FullyReplicated"

return extended_request


def _extend_spark_config_to_request(
request_dict: Dict,
job_settings: _JobSettings,
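The effect of `_extend_torchrun_to_request` on the request is small but important: every S3 input channel is switched to full replication so each node receives the same bootstrap scripts and serialized function. A self-contained sketch of the same transformation, with a placeholder channel name and URI:

```python
# Standalone illustration of the S3DataDistributionType change; the channel is a placeholder.
request = {
    "InputDataConfig": [
        {
            "ChannelName": "sagemaker_remote_function_bootstrap",
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://example-bucket/bootstrap",
                    "S3DataType": "S3Prefix",
                }
            },
        }
    ]
}

for channel in request["InputDataConfig"]:
    s3_source = channel["DataSource"].get("S3DataSource")
    if s3_source:
        # Each node needs the same copy of the bootstrap artifacts, so replicate fully.
        s3_source["S3DataDistributionType"] = "FullyReplicated"

assert (
    request["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"]
    == "FullyReplicated"
)
```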