
Commit 1ad75c9

beniericpintaoz-aws authored and committed Dec 4, 2024
Simplify Config Class Names and DistributedRunner structures (#1573)
1 parent ce55d45 commit 1ad75c9


22 files changed: +693, -413 lines changed

 

‎.gitignore

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ env/
3333
*.html
3434
**/_repack_script_launcher.sh
3535
src/sagemaker/modules/train/container_drivers/sm_train.sh
36-
src/sagemaker/modules/train/container_drivers/sourcecodeconfig.json
37-
src/sagemaker/modules/train/container_drivers/distribution.json
36+
src/sagemaker/modules/train/container_drivers/sourcecode.json
37+
src/sagemaker/modules/train/container_drivers/distributed_runner.json
3838
tests/data/**/_repack_model.py
3939
tests/data/experiment/sagemaker-dev-1.0.tar.gz
4040
src/sagemaker/serve/tmp_workspace

‎src/sagemaker/modules/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -16,3 +16,7 @@
1616
from sagemaker_core.main.utils import logger as sagemaker_core_logger
1717

1818
logger = sagemaker_core_logger
19+
20+
from sagemaker.modules.train.model_trainer import ( # noqa: F401 E402 # pylint: disable=C0413
21+
ModelTrainer,
22+
)

‎src/sagemaker/modules/configs.py

Lines changed: 15 additions & 111 deletions
@@ -21,7 +21,7 @@
2121

2222
from __future__ import absolute_import
2323

24-
from typing import Optional, Union, Dict, Any, List
24+
from typing import Optional, Union
2525
from pydantic import BaseModel, model_validator
2626

2727
import sagemaker_core.shapes as shapes
@@ -54,15 +54,10 @@
5454
CheckpointConfig,
5555
)
5656

57-
from sagemaker.modules import logger
5857
from sagemaker.modules.utils import convert_unassigned_to_none
5958

6059
__all__ = [
61-
"SourceCodeConfig",
62-
"TorchDistributionConfig",
63-
"MPIDistributionConfig",
64-
"SMDistributedSettings",
65-
"DistributionConfig",
60+
"SourceCode",
6661
"StoppingCondition",
6762
"RetryStrategy",
6863
"OutputDataConfig",
@@ -87,107 +82,16 @@
8782
"InstanceGroup",
8883
"TensorBoardOutputConfig",
8984
"CheckpointConfig",
90-
"ComputeConfig",
91-
"NetworkingConfig",
85+
"Compute",
86+
"Networking",
9287
"InputData",
9388
]
9489

9590

96-
class SMDistributedSettings(BaseModel):
97-
"""SMDistributedSettings.
91+
class SourceCode(BaseModel):
92+
"""SourceCode.
9893
99-
The SMDistributedSettings is used to configure distributed training when
100-
using the smdistributed library.
101-
102-
Attributes:
103-
enable_dataparallel (Optional[bool]):
104-
Whether to enable data parallelism.
105-
enable_modelparallel (Optional[bool]):
106-
Whether to enable model parallelism.
107-
modelparallel_parameters (Optional[Dict[str, Any]]):
108-
The parameters for model parallelism.
109-
"""
110-
111-
enable_dataparallel: Optional[bool] = False
112-
enable_modelparallel: Optional[bool] = False
113-
modelparallel_parameters: Optional[Dict[str, Any]] = None
114-
115-
116-
class DistributionConfig(BaseModel):
117-
"""Base class for distribution configurations."""
118-
119-
_distribution_type: str
120-
121-
122-
class TorchDistributionConfig(DistributionConfig):
123-
"""TorchDistributionConfig.
124-
125-
The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
126-
launch distributed training.
127-
128-
SMDistributed Library Information:
129-
- `TorchDistributionConfig` can be used for SMModelParallel V2.
130-
- For SMDataParallel or SMModelParallel V1, it is recommended to use the
131-
`MPIDistributionConfig.`
132-
133-
134-
Attributes:
135-
smdistributed_settings (Optional[SMDistributedSettings]):
136-
The settings for smdistributed library.
137-
process_count_per_node (int):
138-
The number of processes to run on each node in the training job.
139-
Will default to the number of CPUs or GPUs available in the container.
140-
"""
141-
142-
_distribution_type: str = "torch_distributed"
143-
144-
smdistributed_settings: Optional[SMDistributedSettings] = None
145-
process_count_per_node: Optional[int] = None
146-
147-
@model_validator(mode="after")
148-
def _validate_model(cls, model): # pylint: disable=E0213
149-
"""Validate the model."""
150-
if (
151-
getattr(model, "smddistributed_settings", None)
152-
and model.smddistributed_settings.enable_dataparallel
153-
):
154-
logger.warning(
155-
"For smdistributed data parallelism, it is recommended to use "
156-
+ "MPIDistributionConfig."
157-
)
158-
return model
159-
160-
161-
class MPIDistributionConfig(DistributionConfig):
162-
"""MPIDistributionConfig.
163-
164-
The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
165-
166-
SMDistributed Library Information:
167-
- `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
168-
- For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
169-
170-
Attributes:
171-
smdistributed_settings (Optional[SMDistributedSettings]):
172-
The settings for smdistributed library.
173-
process_count_per_node (int):
174-
The number of processes to run on each node in the training job.
175-
Will default to the number of CPUs or GPUs available in the container.
176-
mpi_additional_options (Optional[str]):
177-
The custom MPI options to use for the training job.
178-
"""
179-
180-
_distribution_type: str = "mpi"
181-
182-
smdistributed_settings: Optional[SMDistributedSettings] = None
183-
process_count_per_node: Optional[int] = None
184-
mpi_additional_options: Optional[List[str]] = None
185-
186-
187-
class SourceCodeConfig(BaseModel):
188-
"""SourceCodeConfig.
189-
190-
This config allows the user to specify the source code location, dependencies,
94+
The SourceCode class allows the user to specify the source code location, dependencies,
19195
entry script, or commands to be executed in the training job container.
19296
19397
Attributes:
@@ -210,10 +114,10 @@ class SourceCodeConfig(BaseModel):
210114
command: Optional[str] = None
211115

212116

213-
class ComputeConfig(shapes.ResourceConfig):
214-
"""ComputeConfig.
117+
class Compute(shapes.ResourceConfig):
118+
"""Compute.
215119
216-
The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
120+
The Compute class is a subclass of `sagemaker_core.shapes.ResourceConfig`
217121
and allows the user to specify the compute resources for the training job.
218122
219123
Attributes:
@@ -245,7 +149,7 @@ class ComputeConfig(shapes.ResourceConfig):
245149
enable_managed_spot_training: Optional[bool] = None
246150

247151
@model_validator(mode="after")
248-
def _model_validator(self) -> "ComputeConfig":
152+
def _model_validator(self) -> "Compute":
249153
"""Convert Unassigned values to None."""
250154
return convert_unassigned_to_none(self)
251155

@@ -259,10 +163,10 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
259163
return shapes.ResourceConfig(**filtered_dict)
260164

261165

262-
class NetworkingConfig(shapes.VpcConfig):
263-
"""NetworkingConfig.
166+
class Networking(shapes.VpcConfig):
167+
"""Networking.
264168
265-
The NetworkingConifg is a subclass of `sagemaker_core.shapes.VpcConfig ` and
169+
The Networking class is a subclass of `sagemaker_core.shapes.VpcConfig ` and
266170
allows the user to specify the networking configuration for the training job.
267171
268172
Attributes:
@@ -290,7 +194,7 @@ class NetworkingConfig(shapes.VpcConfig):
290194
enable_inter_container_traffic_encryption: Optional[bool] = None
291195

292196
@model_validator(mode="after")
293-
def _model_validator(self) -> "NetworkingConfig":
197+
def _model_validator(self) -> "Networking":
294198
"""Convert Unassigned values to None."""
295199
return convert_unassigned_to_none(self)
296200
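For context on the rename (SourceCodeConfig -> SourceCode, ComputeConfig -> Compute, NetworkingConfig -> Networking), a minimal construction sketch using the new class names; the field values are illustrative and mirror the example notebook later in this commit, and the objects are passed to ModelTrainer via the renamed `source_code` and `compute` parameters.

    # Minimal sketch using the renamed config classes (values are illustrative).
    from sagemaker.modules.configs import SourceCode, Compute

    source_code = SourceCode(
        source_dir="basic-script-mode",      # directory containing the training code
        requirements="requirements.txt",     # optional pip requirements to install
        entry_script="custom_script.py",     # script executed inside the container
    )

    compute = Compute(
        instance_type="ml.p4d.24xlarge",
        instance_count=2,
        volume_size_in_gb=96,
        keep_alive_period_in_seconds=3600,   # warm-pool keep-alive, as in the notebook
    )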

‎src/sagemaker/modules/constants.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@
2525
os.path.dirname(os.path.abspath(__file__)), "train/container_drivers"
2626
)
2727

28-
SOURCE_CODE_CONFIG_JSON = "sourcecodeconfig.json"
29-
DISTRIBUTION_JSON = "distribution.json"
28+
SOURCE_CODE_JSON = "sourcecode.json"
29+
DISTRIBUTED_RUNNER_JSON = "distributed_runner.json"
3030
TRAIN_SCRIPT = "sm_train.sh"
3131

3232
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]

‎src/sagemaker/modules/distributed.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Distributed module."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional, Dict, Any, List
17+
from pydantic import BaseModel, PrivateAttr
18+
19+
20+
class DistributedRunner(BaseModel):
21+
"""Base class for DistributedRunner Class"""
22+
23+
_type: str = PrivateAttr()
24+
25+
def model_dump(self, *args, **kwargs):
26+
"""Dump the model to a dictionary."""
27+
result = super().model_dump(*args, **kwargs)
28+
result["_type"] = self._type
29+
return result
30+
31+
32+
class Torchrun(DistributedRunner):
33+
"""TorchDistribution.
34+
35+
The TorchDistribution runner uses `torchrun` or `torch.distributed.launch` in the backend to
36+
launch distributed training.
37+
38+
Attributes:
39+
process_count_per_node (int):
40+
The number of processes to run on each node in the training job.
41+
Will default to the number of GPUs available in the container.
42+
"""
43+
44+
_type: str = PrivateAttr(default="torchrun")
45+
46+
process_count_per_node: Optional[int] = None
47+
48+
49+
class TorchrunSMP(DistributedRunner):
50+
"""TorchrunSMP.
51+
52+
The TorchrunSMP runner uses `torchrun` or `torch.distributed.launch` in the backend
53+
to launch distributed training. This strategy is used for a PyTorch job using the SageMaker
54+
Model Parallelism library v2. For more information on the model parallelism parameters, see:
55+
https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config
56+
57+
Attributes:
58+
process_count_per_node (int):
59+
The number of processes to run on each node in the training job.
60+
Will default to the number of GPUs available in the container.
61+
hybrid_shard_degree (Optional[int]):
62+
Specifies a sharded parallelism degree for the model.
63+
sm_activation_offloading (Optional[bool]):
64+
Specifies whether to enable the SMP activation offloading implementation.
65+
activation_loading_horizon (Optional[int]):
66+
An integer specifying the activation offloading horizon type for FSDP. This is the
67+
maximum number of checkpointed or offloaded layers whose inputs can be in the GPU
68+
memory simultaneously.
69+
fsdp_cache_flush_warnings (Optional[bool]):
70+
Detects and warns if cache flushes happen in the PyTorch memory manager, because they
71+
can degrade computational performance.
72+
allow_empty_shards (Optional[bool]):
73+
Whether to allow empty shards when sharding tensors if tensor is not divisible. This is
74+
an experimental fix for crash during checkpointing in certain scenarios. Disabling this
75+
falls back to the original PyTorch behavior.
76+
tensor_parallel_degree (Optional[int]):
77+
Specifies a tensor parallelism degree. The value must be between 1 and world_size.
78+
context_parallel_degree (Optional[int]):
79+
Specifies the context parallelism degree. The value must be between 1 and world_size ,
80+
and must be <= hybrid_shard_degree.
81+
expert_parallel_degree (Optional[int]):
82+
Specifies a expert parallelism degree. The value must be between 1 and world_size.
83+
random_seed (Optional[int]):
84+
A seed number for the random operations in distributed modules by SMP tensor
85+
parallelism or expert parallelism.
86+
"""
87+
88+
_type: str = PrivateAttr(default="torchrun")
89+
90+
process_count_per_node: Optional[int] = None
91+
hybrid_shard_degree: Optional[int] = None
92+
sm_activation_offloading: Optional[bool] = None
93+
activation_loading_horizon: Optional[int] = None
94+
fsdp_cache_flush_warnings: Optional[bool] = None
95+
allow_empty_shards: Optional[bool] = None
96+
tensor_parallel_degree: Optional[int] = None
97+
context_parallel_degree: Optional[int] = None
98+
expert_parallel_degree: Optional[int] = None
99+
random_seed: Optional[int] = None
100+
101+
def _to_mp_parameters_dict(self) -> Dict[str, Any]:
102+
"""Convert to a dictionary of MP parameters."""
103+
mp_parameters = self.model_dump(exclude_none=True)
104+
mp_parameters.pop("_type")
105+
if mp_parameters.get("process_count_per_node") is not None:
106+
mp_parameters.pop("process_count_per_node")
107+
return mp_parameters
108+
109+
110+
class MPI(DistributedRunner):
111+
"""MPI.
112+
113+
The MPI runner uses `mpirun` in the backend to launch distributed training.
114+
115+
Attributes:
116+
process_count_per_node (int):
117+
The number of processes to run on each node in the training job.
118+
Will default to the number of GPUs available in the container.
119+
mpi_additional_options (Optional[str]):
120+
The custom MPI options to use for the training job.
121+
"""
122+
123+
_type: str = PrivateAttr(default="mpi")
124+
125+
process_count_per_node: Optional[int] = None
126+
mpi_additional_options: Optional[List[str]] = None
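A hedged usage sketch of the new DistributedRunner subclasses defined above; the option values are illustrative, and the shape of model_dump() follows from the override in the base class.

    from sagemaker.modules.distributed import Torchrun, TorchrunSMP, MPI

    # Plain torchrun-based launch.
    torchrun = Torchrun(process_count_per_node=8)

    # torchrun with SageMaker Model Parallelism v2 options.
    smp = TorchrunSMP(
        hybrid_shard_degree=3,
        sm_activation_offloading=True,
        tensor_parallel_degree=5,
    )

    # mpirun-based launch with extra mpirun flags.
    mpi = MPI(
        process_count_per_node=8,
        mpi_additional_options=["-x", "MASTER_ADDR=algo-1", "-x", "MASTER_PORT=7777"],
    )

    # model_dump() appends the private _type marker, which is what ends up in the
    # distributed_runner.json consumed by the container drivers.
    print(torchrun.model_dump(exclude_none=True))
    # -> {'process_count_per_node': 8, '_type': 'torchrun'}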

‎src/sagemaker/modules/templates.py

Lines changed: 8 additions & 8 deletions
@@ -19,13 +19,13 @@
1919
eval $CMD
2020
"""
2121

22-
EXECUTE_PYTORCH_DRIVER = """
23-
echo "Running PyTorch training driver"
24-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/pytorch_driver.py
22+
EXEUCTE_TORCHRUN_DRIVER = """
23+
echo "Running Torchrun driver"
24+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py
2525
"""
2626

2727
EXECUTE_MPI_DRIVER = """
28-
echo "Running MPI training driver"
28+
echo "Running MPI driver"
2929
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
3030
"""
3131

@@ -73,12 +73,12 @@
7373
cat /opt/ml/input/config/inputdataconfig.json
7474
echo
7575
76-
echo "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
77-
cat /opt/ml/input/data/sm_drivers/sourcecodeconfig.json
76+
echo "/opt/ml/input/data/sm_drivers/sourcecode.json"
77+
cat /opt/ml/input/data/sm_drivers/sourcecode.json
7878
echo
7979
80-
echo "/opt/ml/input/data/sm_drivers/distribution.json"
81-
cat /opt/ml/input/data/sm_drivers/distribution.json
80+
echo "/opt/ml/input/data/sm_drivers/distributed_runner.json"
81+
cat /opt/ml/input/data/sm_drivers/distributed_runner.json
8282
echo
8383
8484
echo "Setting up environment variables"

‎src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

Lines changed: 41 additions & 22 deletions
@@ -27,16 +27,16 @@
2727
"outputs": [],
2828
"source": [
2929
"from sagemaker.modules.train import ModelTrainer\n",
30-
"from sagemaker.modules.configs import SourceCodeConfig\n",
30+
"from sagemaker.modules.configs import SourceCode\n",
3131
"\n",
3232
"pytorch_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310\"\n",
3333
"\n",
34-
"source_code_config = SourceCodeConfig(\n",
34+
"source_code = SourceCode(\n",
3535
" command=\"echo 'Hello World' && env\",\n",
3636
")\n",
3737
"model_trainer = ModelTrainer(\n",
3838
" training_image=pytorch_image,\n",
39-
" source_code_config=source_code_config,\n",
39+
" source_code=source_code,\n",
4040
")"
4141
]
4242
},
@@ -70,11 +70,11 @@
7070
"outputs": [],
7171
"source": [
7272
"from sagemaker.modules.train import ModelTrainer\n",
73-
"from sagemaker.modules.configs import SourceCodeConfig\n",
73+
"from sagemaker.modules.configs import SourceCode\n",
7474
"\n",
7575
"pytorch_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310\"\n",
7676
"\n",
77-
"source_code_config = SourceCodeConfig(\n",
77+
"source_code = SourceCode(\n",
7878
" source_dir=\"basic-script-mode\",\n",
7979
" command=\"python custom_script.py\",\n",
8080
")\n",
@@ -89,7 +89,7 @@
8989
"\n",
9090
"model_trainer = ModelTrainer(\n",
9191
" training_image=pytorch_image,\n",
92-
" source_code_config=source_code_config,\n",
92+
" source_code=source_code,\n",
9393
" hyperparameters=hyperparameters,\n",
9494
" environment=env_vars,\n",
9595
")\n",
@@ -117,17 +117,17 @@
117117
"metadata": {},
118118
"outputs": [],
119119
"source": [
120-
"from sagemaker.modules.configs import SourceCodeConfig\n",
120+
"from sagemaker.modules.configs import SourceCode\n",
121121
"\n",
122-
"source_code_config = SourceCodeConfig(\n",
122+
"source_code = SourceCode(\n",
123123
" source_dir=\"basic-script-mode\",\n",
124124
" requirements=\"requirements.txt\",\n",
125125
" entry_script=\"custom_script.py\",\n",
126126
")\n",
127127
"\n",
128128
"model_trainer = ModelTrainer(\n",
129129
" training_image=pytorch_image,\n",
130-
" source_code_config=source_code_config,\n",
130+
" source_code=source_code,\n",
131131
")"
132132
]
133133
},
@@ -296,7 +296,7 @@
296296
"outputs": [],
297297
"source": [
298298
"from sagemaker.modules.train import ModelTrainer\n",
299-
"from sagemaker.modules.configs import ComputeConfig, SourceCodeConfig, InputData\n",
299+
"from sagemaker.modules.configs import Compute, SourceCode, InputData\n",
300300
"\n",
301301
"env = {}\n",
302302
"env[\"FI_PROVIDER\"] = \"efa\"\n",
@@ -307,10 +307,11 @@
307307
"env[\"FI_EFA_USE_DEVICE_RDMA\"] = \"1\"\n",
308308
"env[\"RDMAV_FORK_SAFE\"] = \"1\"\n",
309309
"\n",
310-
"compute_config = ComputeConfig(\n",
310+
"compute = Compute(\n",
311311
" instance_count=2,\n",
312312
" instance_type=\"ml.p4d.24xlarge\",\n",
313313
" volume_size_in_gb=96,\n",
314+
" keep_alive_period_in_seconds=3600\n",
314315
")\n",
315316
"\n",
316317
"hugging_face_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04\"\n",
@@ -335,7 +336,7 @@
335336
"metadata": {},
336337
"outputs": [],
337338
"source": [
338-
"source_code_config = SourceCodeConfig(\n",
339+
"source_code = SourceCode(\n",
339340
" source_dir=\"distributed-training/scripts\",\n",
340341
" requirements=\"requirements.txt\",\n",
341342
" command=\"torchrun --nnodes 2 \\\n",
@@ -348,10 +349,10 @@
348349
"\n",
349350
"model_trainer = ModelTrainer(\n",
350351
" training_image=hugging_face_image,\n",
351-
" compute_config=compute_config,\n",
352+
" compute=compute,\n",
352353
" environment=env,\n",
353354
" hyperparameters=hyperparameters,\n",
354-
" source_code_config=source_code_config,\n",
355+
" source_code=source_code,\n",
355356
")"
356357
]
357358
},
@@ -365,7 +366,7 @@
365366
" channel_name=\"dataset\",\n",
366367
" data_source=training_input_path,\n",
367368
")\n",
368-
"model_trainer.train(input_data_config=[test_data])"
369+
"model_trainer.train(input_data_config=[test_data], wait=False)"
369370
]
370371
},
371372
{
@@ -383,13 +384,18 @@
383384
"source": [
384385
"from sagemaker.modules.train import ModelTrainer\n",
385386
"from sagemaker.modules.configs import (\n",
386-
" ComputeConfig, SourceCodeConfig, TorchDistributionConfig, InputData\n",
387+
" Compute, SourceCode, InputData\n",
388+
")\n",
389+
"from sagemaker.modules.distributed import (\n",
390+
" Torchrun,\n",
391+
" MPI\n",
387392
")\n",
388393
"\n",
389-
"compute_config = ComputeConfig(\n",
394+
"compute = Compute(\n",
390395
" instance_count=2,\n",
391396
" instance_type=\"ml.p4d.24xlarge\",\n",
392397
" volume_size_in_gb=96,\n",
398+
" keep_alive_period_in_seconds=3600\n",
393399
")\n",
394400
"\n",
395401
"hugging_face_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04\"\n",
@@ -414,18 +420,31 @@
414420
"metadata": {},
415421
"outputs": [],
416422
"source": [
417-
"source_code_config = SourceCodeConfig(\n",
423+
"source_code = SourceCode(\n",
418424
" source_dir=\"distributed-training/scripts\",\n",
419425
" requirements=\"requirements.txt\",\n",
420426
" entry_script=\"run_clm_no_trainer.py\",\n",
421427
")\n",
422428
"\n",
429+
"# Run using Torchrun\n",
430+
"torchrun = Torchrun()\n",
431+
"\n",
432+
"# Run using MPI\n",
433+
"mpi = MPI(\n",
434+
" mpi_additional_options=[\n",
435+
" \"-x\",\n",
436+
" \"MASTER_ADDR=algo-1\",\n",
437+
" \"-x\",\n",
438+
" \"MASTER_PORT=7777\",\n",
439+
" ]\n",
440+
")\n",
441+
"\n",
423442
"model_trainer = ModelTrainer(\n",
424443
" training_image=hugging_face_image,\n",
425-
" compute_config=compute_config,\n",
444+
" compute=compute,\n",
426445
" hyperparameters=hyperparameters,\n",
427-
" source_code_config=source_code_config,\n",
428-
" distribution_config=TorchDistributionConfig(),\n",
446+
" source_code=source_code,\n",
447+
" distributed_runner=mpi,\n",
429448
")"
430449
]
431450
},
@@ -439,7 +458,7 @@
439458
" channel_name=\"dataset\",\n",
440459
" data_source=training_input_path,\n",
441460
")\n",
442-
"model_trainer.train(input_data_config=[test_data])"
461+
"model_trainer.train(input_data_config=[test_data], wait=False)"
443462
]
444463
},
445464
{
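As a compact summary of the renamed ModelTrainer arguments exercised in the notebook cells above (compute_config -> compute, source_code_config -> source_code, distribution_config -> distributed_runner), a hedged sketch; the source paths are illustrative.

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.configs import Compute, SourceCode
    from sagemaker.modules.distributed import Torchrun

    model_trainer = ModelTrainer(
        training_image="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310",
        compute=Compute(instance_count=2, instance_type="ml.p4d.24xlarge"),
        source_code=SourceCode(source_dir="scripts", entry_script="train.py"),  # illustrative paths
        distributed_runner=Torchrun(),
    )
    # wait=False returns without streaming logs, as in the notebook cells above:
    # model_trainer.train(input_data_config=[...], wait=False)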

‎src/sagemaker/modules/train/container_drivers/mpi_driver.py

Lines changed: 9 additions & 10 deletions
@@ -14,12 +14,13 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
import sys
1718
import json
1819

1920
from utils import (
2021
logger,
21-
read_source_code_config_json,
22-
read_distribution_json,
22+
read_source_code_json,
23+
read_distributed_runner_json,
2324
get_process_count,
2425
execute_commands,
2526
write_failure_file,
@@ -55,9 +56,8 @@ def main():
5556
5. Exit
5657
5758
"""
58-
source_code_config = read_source_code_config_json()
59-
distribution = read_distribution_json()
60-
sm_distributed_settings = distribution.get("smdistributed_settings", {})
59+
source_code = read_source_code_json()
60+
distribution = read_distributed_runner_json()
6161

6262
sm_current_host = os.environ["SM_CURRENT_HOST"]
6363
sm_hosts = json.loads(os.environ["SM_HOSTS"])
@@ -83,18 +83,17 @@ def main():
8383
host_count=host_count,
8484
host_list=host_list,
8585
num_processes=process_count,
86-
smdataparallel_enabled=sm_distributed_settings.get("enable_dataparallel", False),
87-
smmodelparallel_enabled=sm_distributed_settings.get("enable_modelparallel", False),
8886
additional_options=distribution.get("mpi_additional_options", []),
89-
entry_script_path=os.path.join(USER_CODE_PATH, source_code_config["entry_script"]),
87+
entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]),
9088
)
9189

9290
logger.info(f"Executing command: {mpi_command}")
9391
exit_code, error_traceback = execute_commands(mpi_command)
92+
write_status_file_to_workers(worker_hosts)
93+
9494
if exit_code != 0:
9595
write_failure_file(error_traceback)
96-
97-
write_status_file_to_workers(worker_hosts)
96+
sys.exit(exit_code)
9897

9998

10099
if __name__ == "__main__":

‎src/sagemaker/modules/train/container_drivers/mpi_utils.py

Lines changed: 1 addition & 35 deletions
@@ -16,7 +16,6 @@
1616
import os
1717
import time
1818
import subprocess
19-
import json
2019

2120
from typing import List
2221

@@ -29,7 +28,7 @@
2928
def _write_status_file(host: str, status_file: str) -> bool:
3029
"""Write the status file to the provided host."""
3130
try:
32-
logger.info(f"Start writing mpirun finished status to {host}")
31+
logger.info("Writing finished status file (%s) to %s", status_file, host)
3332
subprocess.run(
3433
["ssh", host, "touch", f"{status_file}"],
3534
capture_output=True,
@@ -188,8 +187,6 @@ def get_mpirun_command(
188187
host_count: int,
189188
host_list: List[str],
190189
num_processes: int,
191-
smdataparallel_enabled: bool,
192-
smmodelparallel_enabled: bool,
193190
additional_options: List[str],
194191
entry_script_path: str,
195192
):
@@ -258,37 +255,6 @@ def get_mpirun_command(
258255
if credential in os.environ:
259256
mpirun_command.extend(["-x", credential])
260257

261-
if smdataparallel_enabled:
262-
if host_count == 1:
263-
smdataparallel_flag = "SMDATAPARALLEL_USE_HOMOGENEOUS=1"
264-
mpirun_command.extend(["-x", smdataparallel_flag])
265-
else:
266-
smdataparallel_flag = "SMDATAPARALLEL_USE_SINGLENODE=1"
267-
smdataparallel_server_port = 7592
268-
smdataparallel_server_addr = "algo-1"
269-
270-
mpirun_command.extend(["-x", smdataparallel_flag])
271-
mpirun_command.extend(
272-
[
273-
"-x",
274-
f"SMDATAPARALLEL_SERVER_ADDR={smdataparallel_server_addr}",
275-
"-x",
276-
f"SMDATAPARALLEL_SERVER_PORT={smdataparallel_server_port}",
277-
"-x",
278-
f"SAGEMAKER_INSTANCE_TYPE={instance_type}",
279-
]
280-
)
281-
282-
if validate_smddprun():
283-
mpirun_command.extend(["smddprun"])
284-
285-
if smmodelparallel_enabled:
286-
mp_parameters = json.loads(os.environ.get("SM_HP_MP_PARAMETERS", "{}"))
287-
ddp_dist_backend = mp_parameters.get("ddp_dist_backend", "auto")
288-
if ddp_dist_backend == "auto":
289-
if validate_smddpmprun():
290-
mpirun_command.extend(["smddpmprun", "-i", instance_type, "--allow-bypass"])
291-
292258
mpirun_command.extend([get_python_executable()])
293259
mpirun_command.extend(["-m", "mpi4py", entry_script_path])
294260
return mpirun_command

‎src/sagemaker/modules/train/container_drivers/pytorch_driver.py renamed to ‎src/sagemaker/modules/train/container_drivers/torchrun_driver.py

Lines changed: 8 additions & 6 deletions
@@ -10,17 +10,18 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module is the entry point for the PyTorch driver script."""
13+
"""This module is the entry point for the Torchrun driver script."""
1414
from __future__ import absolute_import
1515

1616
import os
17+
import sys
1718

1819
from typing import List, Tuple
1920

2021
from utils import (
2122
logger,
22-
read_source_code_config_json,
23-
read_distribution_json,
23+
read_source_code_json,
24+
read_distributed_runner_json,
2425
get_process_count,
2526
get_python_executable,
2627
SM_EFA_NCCL_INSTANCES,
@@ -62,8 +63,8 @@ def setup_env():
6263

6364
def create_commands():
6465
"""Create the Torch Distributed command to execute"""
65-
source_code_config = read_source_code_config_json()
66-
distribution = read_distribution_json()
66+
source_code = read_source_code_json()
67+
distribution = read_distributed_runner_json()
6768

6869
process_count = get_process_count(distribution)
6970
host_count = int(os.environ["SM_HOST_COUNT"])
@@ -90,7 +91,7 @@ def create_commands():
9091
]
9192
)
9293

93-
torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code_config["entry_script"])])
94+
torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])])
9495
return torch_cmd
9596

9697

@@ -113,6 +114,7 @@ def main():
113114
exit_code, traceback = execute_commands(torch_cmd)
114115
if exit_code != 0:
115116
write_failure_file(traceback)
117+
sys.exit(exit_code)
116118

117119

118120
if __name__ == "__main__":

‎src/sagemaker/modules/train/container_drivers/utils.py

Lines changed: 20 additions & 13 deletions
@@ -37,8 +37,9 @@
3737
"""
3838

3939
USER_CODE_PATH = "/opt/ml/input/data/sm_code"
40-
SOURCE_CODE_CONFIG_JSON = "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
41-
DISTRIBUTION_JSON = "/opt/ml/input/data/sm_drivers/distribution.json"
40+
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
41+
DISTRIBUTED_RUNNER_JSON = "/opt/ml/input/data/sm_drivers/distributed_runner.json"
42+
4243

4344
SM_EFA_NCCL_INSTANCES = [
4445
"ml.g4dn.8xlarge",
@@ -65,24 +66,30 @@ def write_failure_file(message: str = DEFAULT_FAILURE_MESSAGE):
6566
f.write(message)
6667

6768

68-
def read_source_code_config_json(source_code_config_file: Dict[str, Any] = SOURCE_CODE_CONFIG_JSON):
69+
def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON):
6970
"""Read the source code config json file."""
70-
with open(source_code_config_file, "r") as f:
71-
source_code_config_json = json.load(f)
72-
return source_code_config_json
71+
try:
72+
with open(source_code_json, "r") as f:
73+
source_code_dict = json.load(f) or {}
74+
except FileNotFoundError:
75+
source_code_dict = {}
76+
return source_code_dict
7377

7478

75-
def read_distribution_json(distribution_file: Dict[str, Any] = DISTRIBUTION_JSON):
76-
"""Read the distribution json file."""
77-
with open(distribution_file, "r") as f:
78-
distribution_json = json.load(f)
79-
return distribution_json
79+
def read_distributed_runner_json(distributed_json: Dict[str, Any] = DISTRIBUTED_RUNNER_JSON):
80+
"""Read the distribution config json file."""
81+
try:
82+
with open(distributed_json, "r") as f:
83+
distributed_runner_dict = json.load(f) or {}
84+
except FileNotFoundError:
85+
distributed_runner_dict = {}
86+
return distributed_runner_dict
8087

8188

82-
def get_process_count(distribution: Dict[str, Any]) -> int:
89+
def get_process_count(distributed_runner_dict: Dict[str, Any]) -> int:
8390
"""Get the number of processes to run on each node in the training job."""
8491
return (
85-
int(distribution.get("process_count_per_node", 0))
92+
int(distributed_runner_dict.get("process_count_per_node", 0))
8693
or int(os.environ.get("SM_NUM_GPUS", 0))
8794
or int(os.environ.get("SM_NUM_NEURONS", 0))
8895
or 1
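A self-contained illustration of the process-count precedence implemented above (the function body is copied from this diff; the environment values exist only for the example):

    import os

    def get_process_count(distributed_runner_dict):
        """Explicit process_count_per_node wins, then SM_NUM_GPUS, then SM_NUM_NEURONS, then 1."""
        return (
            int(distributed_runner_dict.get("process_count_per_node", 0))
            or int(os.environ.get("SM_NUM_GPUS", 0))
            or int(os.environ.get("SM_NUM_NEURONS", 0))
            or 1
        )

    os.environ["SM_NUM_GPUS"] = "4"
    assert get_process_count({}) == 4                             # falls back to GPU count
    assert get_process_count({"process_count_per_node": 2}) == 2  # explicit setting wins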

‎src/sagemaker/modules/train/model_trainer.py

Lines changed: 99 additions & 130 deletions
Large diffs are not rendered by default.
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
1+
# flake8: noqa
2+
import argparse
3+
import numpy as np
4+
import os
5+
import sys
6+
import logging
7+
import json
8+
import shutil
9+
import torch
10+
import torch.nn as nn
11+
from torch.utils.data import DataLoader, TensorDataset
12+
from pytorch_model_def import get_model
13+
14+
logger = logging.getLogger(__name__)
15+
logger.setLevel(logging.DEBUG)
16+
logger.addHandler(logging.StreamHandler(sys.stdout))
17+
current_dir = os.path.dirname(os.path.abspath(__file__))
18+
19+
20+
def get_train_data(train_dir):
21+
"""
22+
Get the training data and convert to tensors
23+
"""
24+
25+
x_train = np.load(os.path.join(train_dir, "x_train.npy"))
26+
y_train = np.load(os.path.join(train_dir, "y_train.npy"))
27+
logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}")
28+
29+
return torch.from_numpy(x_train), torch.from_numpy(y_train)
30+
31+
32+
def get_test_data(test_dir):
33+
"""
34+
Get the testing data and convert to tensors
35+
"""
36+
37+
x_test = np.load(os.path.join(test_dir, "x_test.npy"))
38+
y_test = np.load(os.path.join(test_dir, "y_test.npy"))
39+
logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}")
40+
41+
return torch.from_numpy(x_test), torch.from_numpy(y_test)
42+
43+
44+
def model_fn(model_dir):
45+
"""
46+
Load the model for inference
47+
"""
48+
49+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50+
model = get_model()
51+
model.load_state_dict(torch.load(model_dir + "/model.pth"))
52+
model.eval()
53+
return model.to(device)
54+
55+
56+
def input_fn(request_body, request_content_type):
57+
"""
58+
Deserialize and prepare the prediction input
59+
"""
60+
61+
if request_content_type == "application/json":
62+
request = json.loads(request_body)
63+
train_inputs = torch.tensor(request)
64+
return train_inputs
65+
66+
67+
def predict_fn(input_data, model):
68+
"""
69+
Apply model to the incoming request
70+
"""
71+
72+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73+
model.to(device)
74+
model.eval()
75+
with torch.no_grad():
76+
return model(input_data.float()).numpy()[0]
77+
78+
79+
def train():
80+
"""
81+
Train the PyTorch model
82+
"""
83+
# Directories: train, test and model
84+
train_dir = os.path.join(current_dir, "data/train")
85+
test_dir = os.path.join(current_dir, "data/test")
86+
model_dir = os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model"))
87+
88+
# Load the training and testing data
89+
x_train, y_train = get_train_data(train_dir)
90+
x_test, y_test = get_test_data(test_dir)
91+
train_ds = TensorDataset(x_train, y_train)
92+
93+
# Training parameters - used to configure the training loop
94+
batch_size = 64
95+
epochs = 1
96+
learning_rate = 0.1
97+
logger.info(
98+
"batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate)
99+
)
100+
101+
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
102+
103+
# Define the model, loss function and optimizer
104+
model = get_model()
105+
model = model.to(device)
106+
criterion = nn.MSELoss()
107+
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
108+
109+
# Train the model
110+
for epoch in range(epochs):
111+
for x_train_batch, y_train_batch in train_dl:
112+
y = model(x_train_batch.float())
113+
loss = criterion(y.flatten(), y_train_batch.float())
114+
optimizer.zero_grad()
115+
loss.backward()
116+
optimizer.step()
117+
epoch += 1
118+
logger.info(f"epoch: {epoch} -> loss: {loss}")
119+
120+
# Test the model
121+
with torch.no_grad():
122+
y = model(x_test.float()).flatten()
123+
mse = ((y - y_test) ** 2).sum() / y_test.shape[0]
124+
print("\nTest MSE:", mse.numpy())
125+
126+
# Save the model
127+
os.makedirs(model_dir, exist_ok=True)
128+
torch.save(model.state_dict(), model_dir + "/model.pth")
129+
inference_code_path = model_dir + "/code/"
130+
131+
if not os.path.exists(inference_code_path):
132+
os.mkdir(inference_code_path)
133+
logger.info("Created a folder at {}!".format(inference_code_path))
134+
135+
code_dir = os.environ.get("SM_CHANNEL_CODE", current_dir)
136+
shutil.copy(os.path.join(code_dir, "custom_script.py"), inference_code_path)
137+
shutil.copy(os.path.join(code_dir, "pytorch_model_def.py"), inference_code_path)
138+
logger.info("Saving models files to {}".format(inference_code_path))
139+
140+
141+
if __name__ == "__main__":
142+
print("Running the training job ...\n")
143+
144+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
145+
146+
train()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
1+
# flake8: noqa
2+
import torch
3+
import torch.nn as nn
4+
5+
6+
class NeuralNet(nn.Module):
7+
def __init__(self):
8+
super().__init__()
9+
self.fc1 = nn.Linear(8, 8)
10+
self.fc2 = nn.Linear(8, 6)
11+
self.fc3 = nn.Linear(6, 1)
12+
13+
def forward(self, x):
14+
x = torch.tanh(self.fc1(x))
15+
x = torch.sigmoid(self.fc2(x))
16+
x = self.fc3(x)
17+
return x
18+
19+
20+
def get_model():
21+
22+
model = NeuralNet()
23+
return model
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
1+
numpy
2+
-f https://download.pytorch.org/whl/torch_stable.html
3+
torch==2.0.1+cpu

‎tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py

Lines changed: 25 additions & 16 deletions
@@ -30,25 +30,28 @@
3030
"algo-1,algo-2",
3131
"-np",
3232
"2",
33+
"--verbose",
34+
"-x",
35+
"ENV_VAR1",
3336
"python",
3437
"-m",
3538
"mpi4py",
3639
"-m",
3740
"script.py",
3841
]
3942

40-
DUMMY_SOURCE_CODE_CONFIG = {
43+
DUMMY_SOURCE_CODE = {
44+
"source_code": "source_code",
4145
"entry_script": "script.py",
42-
"distribution": {
43-
"process_count_per_node": 2,
44-
"sm_distributed_settings": {
45-
"enable_dataparallel": True,
46-
},
47-
"mpi_additional_options": [
48-
"-x",
49-
"AWS_REGION",
50-
],
51-
},
46+
}
47+
DUMMY_DISTRIBUTED_RUNNER = {
48+
"_type": "mpi",
49+
"process_count_per_node": 2,
50+
"mpi_additional_options": [
51+
"--verbose",
52+
"-x",
53+
"ENV_VAR1",
54+
],
5255
}
5356

5457

@@ -61,7 +64,8 @@
6164
"SM_HOST_COUNT": "2",
6265
},
6366
)
64-
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_config_json")
67+
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_runner_json")
68+
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json")
6569
@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file")
6670
@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon")
6771
@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node")
@@ -75,9 +79,11 @@ def test_mpi_driver_worker(
7579
mock_bootstrap_master_node,
7680
mock_start_sshd_daemon,
7781
mock_write_env_vars_to_file,
78-
mock_read_source_code_config_json,
82+
mock_read_source_code_json,
83+
mock_read_distributed_runner_json,
7984
):
80-
mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE_CONFIG
85+
mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE
86+
mock_read_distributed_runner_json.return_value = DUMMY_DISTRIBUTED_RUNNER
8187

8288
mpi_driver.main()
8389

@@ -99,7 +105,8 @@ def test_mpi_driver_worker(
99105
"SM_HOST_COUNT": "2",
100106
},
101107
)
102-
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_config_json")
108+
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_runner_json")
109+
@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json")
103110
@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file")
104111
@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon")
105112
@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node")
@@ -118,8 +125,10 @@ def test_mpi_driver_master(
118125
mock_start_sshd_daemon,
119126
mock_write_env_vars_to_file,
120127
mock_read_source_code_config_json,
128+
mock_read_distributed_runner_json,
121129
):
122-
mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE_CONFIG
130+
mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE
131+
mock_read_distributed_runner_json.return_value = DUMMY_DISTRIBUTED_RUNNER
123132
mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND
124133
mock_get_process_count.return_value = 2
125134
mock_execute_commands.return_value = (0, "")

‎tests/unit/sagemaker/modules/train/container_drivers/test_pytorch_driver.py renamed to ‎tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py

Lines changed: 42 additions & 35 deletions
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Pytorch Driver Unit Tests."""
13+
"""Torchrun Driver Unit Tests."""
1414
from __future__ import absolute_import
1515

1616
import os
@@ -20,45 +20,38 @@
2020

2121
sys.modules["utils"] = MagicMock()
2222

23-
from sagemaker.modules.train.container_drivers import pytorch_driver # noqa: E402
23+
from sagemaker.modules.train.container_drivers import torchrun_driver # noqa: E402
2424

25-
DUMMY_SOURCE_CODE_CONFIG = {
25+
DUMMY_SOURCE_CODE = {
26+
"source_code": "source_code",
2627
"entry_script": "script.py",
27-
"distribution": {
28-
"process_count_per_node": 2,
29-
"sm_distributed_settings": {
30-
"enable_dataparallel": True,
31-
},
32-
"mpi_additional_options": [
33-
"-x",
34-
"AWS_REGION",
35-
],
36-
},
3728
}
3829

30+
DUMMY_DISTRIBUTED_RUNNER = {"_type": "torchrun", "process_count_per_node": 2}
31+
3932

4033
@patch(
41-
"sagemaker.modules.train.container_drivers.pytorch_driver.get_python_executable",
34+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable",
4235
return_value="python3",
4336
)
4437
@patch(
45-
"sagemaker.modules.train.container_drivers.pytorch_driver.pytorch_version", return_value=(2, 0)
38+
"sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
4639
)
4740
def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable):
48-
assert pytorch_driver.get_base_pytorch_command() == ["torchrun"]
41+
assert torchrun_driver.get_base_pytorch_command() == ["torchrun"]
4942

5043

5144
@patch(
52-
"sagemaker.modules.train.container_drivers.pytorch_driver.get_python_executable",
45+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable",
5346
return_value="python3",
5447
)
5548
@patch(
56-
"sagemaker.modules.train.container_drivers.pytorch_driver.pytorch_version", return_value=(1, 8)
49+
"sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8)
5750
)
5851
def test_get_base_pytorch_command_torch_distributed_launch(
5952
mock_pytorch_version, mock_get_python_executable
6053
):
61-
assert pytorch_driver.get_base_pytorch_command() == (
54+
assert torchrun_driver.get_base_pytorch_command() == (
6255
["python3", "-m", "torch.distributed.launch"]
6356
)
6457

@@ -72,23 +65,30 @@ def test_get_base_pytorch_command_torch_distributed_launch(
7265
},
7366
)
7467
@patch(
75-
"sagemaker.modules.train.container_drivers.pytorch_driver.USER_CODE_PATH",
68+
"sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH",
7669
"/opt/ml/input/data/code",
7770
)
78-
@patch("sagemaker.modules.train.container_drivers.pytorch_driver.get_process_count", return_value=2)
7971
@patch(
80-
"sagemaker.modules.train.container_drivers.pytorch_driver.pytorch_version", return_value=(2, 0)
72+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2
8173
)
8274
@patch(
83-
"sagemaker.modules.train.container_drivers.pytorch_driver.get_base_pytorch_command",
75+
"sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
76+
)
77+
@patch(
78+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command",
8479
return_value=["torchrun"],
8580
)
8681
@patch(
87-
"sagemaker.modules.train.container_drivers.pytorch_driver.read_source_code_config_json",
88-
return_value=DUMMY_SOURCE_CODE_CONFIG,
82+
"sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json",
83+
return_value=DUMMY_SOURCE_CODE,
84+
)
85+
@patch(
86+
"sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_runner_json",
87+
return_value=DUMMY_DISTRIBUTED_RUNNER,
8988
)
9089
def test_create_commands_single_node(
91-
mock_read_source_code_config_json,
90+
mock_read_distributed_runner_json,
91+
mock_read_source_code_json,
9292
mock_get_base_pytorch_command,
9393
mock_pytorch_version,
9494
mock_get_process_count,
@@ -100,7 +100,7 @@ def test_create_commands_single_node(
100100
"/opt/ml/input/data/code/script.py",
101101
]
102102

103-
command = pytorch_driver.create_commands()
103+
command = torchrun_driver.create_commands()
104104
assert command == expected_command
105105

106106

@@ -116,23 +116,30 @@ def test_create_commands_single_node(
116116
},
117117
)
118118
@patch(
119-
"sagemaker.modules.train.container_drivers.pytorch_driver.USER_CODE_PATH",
119+
"sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH",
120120
"/opt/ml/input/data/code",
121121
)
122-
@patch("sagemaker.modules.train.container_drivers.pytorch_driver.get_process_count", return_value=2)
123122
@patch(
124-
"sagemaker.modules.train.container_drivers.pytorch_driver.pytorch_version", return_value=(2, 0)
123+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2
124+
)
125+
@patch(
126+
"sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
125127
)
126128
@patch(
127-
"sagemaker.modules.train.container_drivers.pytorch_driver.get_base_pytorch_command",
129+
"sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command",
128130
return_value=["torchrun"],
129131
)
130132
@patch(
131-
"sagemaker.modules.train.container_drivers.pytorch_driver.read_source_code_config_json",
132-
return_value=DUMMY_SOURCE_CODE_CONFIG,
133+
"sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json",
134+
return_value=DUMMY_SOURCE_CODE,
135+
)
136+
@patch(
137+
"sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_runner_json",
138+
return_value=DUMMY_DISTRIBUTED_RUNNER,
133139
)
134140
def test_create_commands_multi_node(
135-
mock_read_source_code_config_json,
141+
mock_read_distributed_runner_json,
142+
mock_read_source_code_json,
136143
mock_get_base_pytorch_command,
137144
mock_pytorch_version,
138145
mock_get_process_count,
@@ -147,5 +154,5 @@ def test_create_commands_multi_node(
147154
"/opt/ml/input/data/code/script.py",
148155
]
149156

150-
command = pytorch_driver.create_commands()
157+
command = torchrun_driver.create_commands()
151158
assert command == expected_command

‎tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 119 additions & 23 deletions
@@ -13,18 +13,26 @@
1313
"""ModelTrainer Tests."""
1414
from __future__ import absolute_import
1515

16+
import json
17+
import os
1618
import pytest
1719
from unittest.mock import patch, MagicMock
1820

1921
from sagemaker.session import Session
2022
from sagemaker.modules.train.model_trainer import ModelTrainer
21-
from sagemaker.modules.constants import DEFAULT_INSTANCE_TYPE
23+
from sagemaker.modules.constants import (
24+
DEFAULT_INSTANCE_TYPE,
25+
SM_DRIVERS_LOCAL_PATH,
26+
DISTRIBUTED_RUNNER_JSON,
27+
SOURCE_CODE_JSON,
28+
TRAIN_SCRIPT,
29+
)
2230
from sagemaker.modules.configs import (
23-
ComputeConfig,
31+
Compute,
2432
StoppingCondition,
2533
RetryStrategy,
2634
OutputDataConfig,
27-
SourceCodeConfig,
35+
SourceCode,
2836
S3DataSource,
2937
FileSystemDataSource,
3038
MetricDefinition,
@@ -39,13 +47,15 @@
3947
SessionChainingConfig,
4048
InputData,
4149
)
50+
from sagemaker.modules.distributed import Torchrun, TorchrunSMP, MPI
51+
from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER
4252
from tests.unit import DATA_DIR
4353

4454
DEFAULT_BASE_NAME = "dummy-image-job"
4555
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
4656
DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000"
4757
DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role"
48-
DEFAULT_COMPUTE_CONFIG = ComputeConfig(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1)
58+
DEFAULT_COMPUTE_CONFIG = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1)
4959
DEFAULT_OUTPUT_DATA_CONFIG = OutputDataConfig(
5060
s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}",
5161
compression_type="GZIP",
@@ -56,11 +66,11 @@
5666
max_pending_time_in_seconds=None,
5767
max_wait_time_in_seconds=None,
5868
)
59-
DEFAULT_SOURCE_CODE_CONFIG = SourceCodeConfig(
60-
source_dir="test-data",
61-
entry_point="train.py",
69+
DEFAULT_SOURCE_CODE = SourceCode(
70+
source_dir=f"{DATA_DIR}/modules/script_mode",
71+
entry_script="custom_script.py",
6272
)
63-
UNSUPPORTED_SOURCE_CODE_CONFIG = SourceCodeConfig(
73+
UNSUPPORTED_SOURCE_CODE = SourceCode(
6474
entry_script="train.py",
6575
)
6676

@@ -80,7 +90,7 @@ def model_trainer():
8090
trainer = ModelTrainer(
8191
training_image=DEFAULT_IMAGE,
8292
role=DEFAULT_ROLE,
83-
compute_config=DEFAULT_COMPUTE_CONFIG,
93+
compute=DEFAULT_COMPUTE_CONFIG,
8494
stopping_condition=DEFAULT_STOPPING_CONDITION,
8595
output_data_config=DEFAULT_OUTPUT_DATA_CONFIG,
8696
)
@@ -110,14 +120,14 @@ def model_trainer():
110120
{
111121
"init_params": {
112122
"training_image": DEFAULT_IMAGE,
113-
"source_code_config": UNSUPPORTED_SOURCE_CODE_CONFIG,
123+
"source_code": UNSUPPORTED_SOURCE_CODE,
114124
},
115125
"should_throw": True,
116126
},
117127
{
118128
"init_params": {
119129
"training_image": DEFAULT_IMAGE,
120-
"source_code_config": DEFAULT_SOURCE_CODE_CONFIG,
130+
"source_code": DEFAULT_SOURCE_CODE,
121131
},
122132
"should_throw": False,
123133
},
@@ -126,8 +136,8 @@ def model_trainer():
126136
"no_params",
127137
"training_image_and_algorithm_name",
128138
"only_training_image",
129-
"unsupported_source_code_config",
130-
"supported_source_code_config",
139+
"unsupported_source_code",
140+
"supported_source_code",
131141
],
132142
)
133143
def test_model_trainer_param_validation(test_case, modules_session):
@@ -138,7 +148,7 @@ def test_model_trainer_param_validation(test_case, modules_session):
138148
trainer = ModelTrainer(**test_case["init_params"], session=modules_session)
139149
assert trainer is not None
140150
assert trainer.training_image == DEFAULT_IMAGE
141-
assert trainer.compute_config == DEFAULT_COMPUTE_CONFIG
151+
assert trainer.compute == DEFAULT_COMPUTE_CONFIG
142152
assert trainer.output_data_config == DEFAULT_OUTPUT_DATA_CONFIG
143153
assert trainer.stopping_condition == DEFAULT_STOPPING_CONDITION
144154
assert trainer.base_job_name == DEFAULT_BASE_NAME
@@ -282,9 +292,6 @@ def test_debugger_settings(mock_training_job, modules_session):
282292
rule_evaluator_image=image_uri,
283293
rule_parameters={"parameter": "value"},
284294
)
285-
remote_debug_config = RemoteDebugConfig(
286-
enable_remote_debug=True,
287-
)
288295
profiler_config = ProfilerConfig(s3_output_path="s3://dummy-bucket/dummy-prefix")
289296
profiler_rule_config = ProfilerRuleConfiguration(
290297
rule_configuration_name="rule-name",
@@ -301,15 +308,13 @@ def test_debugger_settings(mock_training_job, modules_session):
301308
).with_debugger_settings(
302309
debug_hook_config=debug_hook_config,
303310
debug_rule_configurations=debug_rule_config,
304-
remote_debug_config=remote_debug_config,
305311
profiler_config=profiler_config,
306312
profiler_rule_configurations=profiler_rule_config,
307313
tensor_board_output_config=tensor_board_output_config,
308314
)
309315

310316
assert model_trainer._debug_hook_config == debug_hook_config
311317
assert model_trainer._debug_rule_configurations == debug_rule_config
312-
assert model_trainer._remote_debug_config == remote_debug_config
313318
assert model_trainer._profiler_config == profiler_config
314319
assert model_trainer._profiler_rule_configurations == profiler_rule_config
315320
assert model_trainer._tensor_board_output_config == tensor_board_output_config
@@ -324,9 +329,6 @@ def test_debugger_settings(mock_training_job, modules_session):
324329
mock_training_job.create.call_args.kwargs["debug_rule_configurations"]
325330
== debug_rule_config
326331
)
327-
assert (
328-
mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config
329-
)
330332
assert mock_training_job.create.call_args.kwargs["profiler_config"] == profiler_config
331333
assert (
332334
mock_training_job.create.call_args.kwargs["profiler_rule_configurations"]
@@ -346,7 +348,9 @@ def test_additional_settings(mock_training_job, modules_session):
346348
retry_strategy = RetryStrategy(
347349
maximum_retry_attempts=3,
348350
)
349-
351+
remote_debug_config = RemoteDebugConfig(
352+
enable_remote_debug=True,
353+
)
350354
experiment_config = ExperimentConfig(
351355
experiment_name="experiment-name",
352356
trial_name="trial-name",
@@ -364,6 +368,7 @@ def test_additional_settings(mock_training_job, modules_session):
364368
).with_additional_settings(
365369
retry_strategy=retry_strategy,
366370
experiment_config=experiment_config,
371+
remote_debug_config=remote_debug_config,
367372
infra_check_config=infra_check_config,
368373
session_chaining_config=session_chaining_config,
369374
)
@@ -372,6 +377,7 @@ def test_additional_settings(mock_training_job, modules_session):
372377
assert model_trainer._experiment_config == experiment_config
373378
assert model_trainer._infra_check_config == infra_check_config
374379
assert model_trainer._session_chaining_config == session_chaining_config
380+
assert model_trainer._remote_debug_config == remote_debug_config
375381

376382
with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data:
377383
mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
@@ -386,3 +392,93 @@ def test_additional_settings(mock_training_job, modules_session):
386392
mock_training_job.create.call_args.kwargs["session_chaining_config"]
387393
== session_chaining_config
388394
)
395+
assert (
396+
mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config
397+
)
398+
399+
400+
@pytest.mark.parametrize(
401+
"test_case",
402+
[
403+
{
404+
"source_code": DEFAULT_SOURCE_CODE,
405+
"distributed_runner": Torchrun(),
406+
"expected_template": EXEUCTE_TORCHRUN_DRIVER,
407+
"expected_hyperparameters": {},
408+
},
409+
{
410+
"source_code": DEFAULT_SOURCE_CODE,
411+
"distributed_runner": TorchrunSMP(
412+
hybrid_shard_degree=3,
413+
sm_activation_offloading=True,
414+
allow_empty_shards=True,
415+
tensor_parallel_degree=5,
416+
),
417+
"expected_template": EXEUCTE_TORCHRUN_DRIVER,
418+
"expected_hyperparameters": {
419+
"mp_parameters": json.dumps(
420+
{
421+
"hybrid_shard_degree": 3,
422+
"sm_activation_offloading": True,
423+
"allow_empty_shards": True,
424+
"tensor_parallel_degree": 5,
425+
}
426+
),
427+
},
428+
},
429+
{
430+
"source_code": DEFAULT_SOURCE_CODE,
431+
"distributed_runner": MPI(
432+
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
433+
),
434+
"expected_template": EXECUTE_MPI_DRIVER,
435+
"expected_hyperparameters": {},
436+
},
437+
],
438+
ids=[
439+
"torchrun",
440+
"torchrun_smp",
441+
"mpi",
442+
],
443+
)
444+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
445+
def test_train_with_distributed_runner(mock_training_job, test_case, modules_session):
446+
modules_session.upload_data.return_value = (
447+
f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test"
448+
)
449+
450+
expected_train_script_path = f"{SM_DRIVERS_LOCAL_PATH}/{TRAIN_SCRIPT}"
451+
expected_runner_json_path = f"{SM_DRIVERS_LOCAL_PATH}/{DISTRIBUTED_RUNNER_JSON}"
452+
expected_source_code_json_path = f"{SM_DRIVERS_LOCAL_PATH}/{SOURCE_CODE_JSON}"
453+
454+
model_trainer = ModelTrainer(
455+
session=modules_session,
456+
training_image=DEFAULT_IMAGE,
457+
source_code=test_case["source_code"],
458+
distributed_runner=test_case["distributed_runner"],
459+
)
460+
461+
model_trainer.train()
462+
mock_training_job.create.assert_called_once()
463+
assert mock_training_job.create.call_args.kwargs["hyper_parameters"] == (
464+
test_case["expected_hyperparameters"]
465+
)
466+
467+
assert os.path.exists(expected_train_script_path)
468+
with open(expected_train_script_path, "r") as f:
469+
train_script_content = f.read()
470+
assert test_case["expected_template"] in train_script_content
471+
472+
assert os.path.exists(expected_runner_json_path)
473+
with open(expected_runner_json_path, "r") as f:
474+
runner_json_content = f.read()
475+
assert test_case["distributed_runner"].model_dump(exclude_none=True) == (
476+
json.loads(runner_json_content)
477+
)
478+
479+
assert os.path.exists(expected_source_code_json_path)
480+
with open(expected_source_code_json_path, "r") as f:
481+
source_code_json_content = f.read()
482+
assert test_case["source_code"].model_dump(exclude_none=True) == (
483+
json.loads(source_code_json_content)
484+
)
