Skip to content

Commit 869b75f

Browse files
beniericpintaoz-aws
authored andcommitted
Fix bug in script mode setup ModelTrainer (#1575)
1 parent 18d3cda commit 869b75f

File tree

7 files changed

+43
-15
lines changed

7 files changed

+43
-15
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ env/
3434
**/_repack_script_launcher.sh
3535
src/sagemaker/modules/train/container_drivers/sm_train.sh
3636
src/sagemaker/modules/train/container_drivers/sourcecodeconfig.json
37+
src/sagemaker/modules/train/container_drivers/distribution.json
3738
tests/data/**/_repack_model.py
3839
tests/data/experiment/sagemaker-dev-1.0.tar.gz
3940
src/sagemaker/serve/tmp_workspace

src/sagemaker/modules/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727

2828
SOURCE_CODE_CONFIG_JSON = "sourcecodeconfig.json"
29+
DISTRIBUTION_JSON = "distribution.json"
2930
TRAIN_SCRIPT = "sm_train.sh"
3031

3132
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]

src/sagemaker/modules/templates.py

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@
7777
cat /opt/ml/input/data/sm_drivers/sourcecodeconfig.json
7878
echo
7979
80+
echo "/opt/ml/input/data/sm_drivers/distribution.json"
81+
cat /opt/ml/input/data/sm_drivers/distribution.json
82+
echo
83+
8084
echo "Setting up environment variables"
8185
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py
8286
source /opt/ml/input/data/sm_drivers/scripts/sm_training.env

src/sagemaker/modules/train/container_drivers/mpi_driver.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from utils import (
2020
logger,
2121
read_source_code_config_json,
22+
read_distribution_json,
2223
get_process_count,
2324
execute_commands,
2425
write_failure_file,
@@ -55,7 +56,7 @@ def main():
5556
5657
"""
5758
source_code_config = read_source_code_config_json()
58-
distribution = source_code_config.get("distribution", {})
59+
distribution = read_distribution_json()
5960
sm_distributed_settings = distribution.get("smdistributed_settings", {})
6061

6162
sm_current_host = os.environ["SM_CURRENT_HOST"]
@@ -73,7 +74,7 @@ def main():
7374

7475
host_list = json.loads(os.environ["SM_HOSTS"])
7576
host_count = int(os.environ["SM_HOST_COUNT"])
76-
process_count = get_process_count(source_code_config)
77+
process_count = get_process_count(distribution)
7778

7879
if process_count > 1:
7980
host_list = ["{}:{}".format(host, process_count) for host in host_list]

src/sagemaker/modules/train/container_drivers/pytorch_driver.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from utils import (
2121
logger,
2222
read_source_code_config_json,
23+
read_distribution_json,
2324
get_process_count,
2425
get_python_executable,
2526
SM_EFA_NCCL_INSTANCES,
@@ -62,8 +63,9 @@ def setup_env():
6263
def create_commands():
6364
"""Create the Torch Distributed command to execute"""
6465
source_code_config = read_source_code_config_json()
66+
distribution = read_distribution_json()
6567

66-
process_count = get_process_count(source_code_config)
68+
process_count = get_process_count(distribution)
6769
host_count = int(os.environ["SM_HOST_COUNT"])
6870

6971
torch_cmd = []

src/sagemaker/modules/train/container_drivers/utils.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
TrainingJob - {os.environ['TRAINING_JOB_NAME']}
3737
"""
3838

39-
USER_CODE_PATH = "/opt/ml/input/data/code"
40-
SOURCE_CODE_CONFIG_JSON = "/opt/ml/input/data/sm_code/sourcecodeconfig.json"
39+
USER_CODE_PATH = "/opt/ml/input/data/sm_code"
40+
SOURCE_CODE_CONFIG_JSON = "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
41+
DISTRIBUTION_JSON = "/opt/ml/input/data/sm_drivers/distribution.json"
4142

4243
SM_EFA_NCCL_INSTANCES = [
4344
"ml.g4dn.8xlarge",
@@ -67,19 +68,25 @@ def write_failure_file(message: str = DEFAULT_FAILURE_MESSAGE):
6768
def read_source_code_config_json(source_code_config_file: Dict[str, Any] = SOURCE_CODE_CONFIG_JSON):
6869
"""Read the source code config json file."""
6970
with open(source_code_config_file, "r") as f:
70-
distribution_config = json.load(f)
71-
return distribution_config
71+
source_code_config_json = json.load(f)
72+
return source_code_config_json
7273

7374

74-
def get_process_count(source_code_config: Dict[str, Any]) -> int:
75+
def read_distribution_json(distribution_file: Dict[str, Any] = DISTRIBUTION_JSON):
76+
"""Read the distribution json file."""
77+
with open(distribution_file, "r") as f:
78+
distribution_json = json.load(f)
79+
return distribution_json
80+
81+
82+
def get_process_count(distribution: Dict[str, Any]) -> int:
7583
"""Get the number of processes to run on each node in the training job."""
76-
if source_code_config.get("distribution", {}).get("process_count_per_node") is not None:
77-
return int(source_code_config["distribution"]["process_count_per_node"])
78-
if os.environ.get("SM_NUM_GPUS") is not None:
79-
return int(os.environ["SM_NUM_GPUS"])
80-
if os.environ.get("SM_NUM_NEURONS") is not None:
81-
return int(os.environ["SM_NUM_NEURONS"])
82-
return 1 # Default to 1 process per node
84+
return (
85+
int(distribution.get("process_count_per_node", 0))
86+
or int(os.environ.get("SM_NUM_GPUS", 0))
87+
or int(os.environ.get("SM_NUM_NEURONS", 0))
88+
or 1
89+
)
8390

8491

8592
def get_python_executable() -> str:

src/sagemaker/modules/train/model_trainer.py

+12
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
DEFAULT_CONTAINER_ENTRYPOINT,
7171
DEFAULT_CONTAINER_ARGUMENTS,
7272
SOURCE_CODE_CONFIG_JSON,
73+
DISTRIBUTION_JSON,
7374
)
7475
from sagemaker.modules.templates import (
7576
TRAIN_SCRIPT_TEMPLATE,
@@ -385,6 +386,7 @@ def train(
385386

386387
self._prepare_train_script(
387388
source_code_config=self.source_code_config,
389+
distribution_config=self.distribution_config,
388390
)
389391
if self.distribution_config:
390392
smd_modelparallel_parameters = getattr(
@@ -397,6 +399,8 @@ def train(
397399
smd_modelparallel_parameters
398400
)
399401
self._write_source_code_config_json(self.source_code_config)
402+
if self.distribution_config:
403+
self._write_distribution_config_json(self.distribution_config)
400404

401405
# Create an input channel for drivers packaged by the sdk
402406
sm_drivers_channel = self.create_input_data_channel(SM_DRIVERS, SM_DRIVERS_LOCAL_PATH)
@@ -555,6 +559,14 @@ def _write_source_code_config_json(self, source_code_config: SourceCodeConfig):
555559
with open(file_path, "w") as f:
556560
f.write(source_code_config.model_dump_json())
557561

562+
def _write_distribution_config_json(
563+
self, distribution: Union[MPIDistributionConfig, TorchDistributionConfig]
564+
):
565+
"""Write the distribution configuration to a JSON file."""
566+
file_path = os.path.join(SM_DRIVERS_LOCAL_PATH, DISTRIBUTION_JSON)
567+
with open(file_path, "w") as f:
568+
f.write(distribution.model_dump_json())
569+
558570
def _prepare_train_script(
559571
self,
560572
source_code_config: SourceCodeConfig,

0 commit comments

Comments
 (0)