Skip to content

Commit 775a627

Browse files
benieric authored and pintaoz-aws committed
Update ModelTrainer Interface Parameters (#1617)
1 parent c015e3f commit 775a627

File tree

16 files changed

+430
-634
lines changed

16 files changed

+430
-634
lines changed

.gitignore

+1-1
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ env/
3434
**/_repack_script_launcher.sh
3535
src/sagemaker/modules/train/container_drivers/sm_train.sh
3636
src/sagemaker/modules/train/container_drivers/sourcecode.json
37-
src/sagemaker/modules/train/container_drivers/distributed_runner.json
37+
src/sagemaker/modules/train/container_drivers/distributed.json
3838
tests/data/**/_repack_model.py
3939
tests/data/experiment/sagemaker-dev-1.0.tar.gz
4040
src/sagemaker/serve/tmp_workspace

src/sagemaker/config/config_schema.py

+14-56
Original file line number | Diff line number | Diff line change
@@ -664,58 +664,19 @@ def _simple_path(*args: str):
664664
"minLength": 20,
665665
"maxLength": 2048,
666666
},
667-
"baseJobName": {
668-
TYPE: OBJECT,
669-
ADDITIONAL_PROPERTIES: True
670-
},
671-
"sourceCode": {
672-
TYPE: OBJECT,
673-
ADDITIONAL_PROPERTIES: True
674-
},
675-
"distributed_runner": {
676-
TYPE: OBJECT,
677-
ADDITIONAL_PROPERTIES: True
678-
},
679-
"compute": {
680-
TYPE: OBJECT,
681-
ADDITIONAL_PROPERTIES: True
682-
},
683-
"networking": {
684-
TYPE: OBJECT,
685-
ADDITIONAL_PROPERTIES: True
686-
},
687-
"stoppingCondition": {
688-
TYPE: OBJECT,
689-
ADDITIONAL_PROPERTIES: True
690-
},
691-
"trainingImage": {
692-
TYPE: OBJECT,
693-
ADDITIONAL_PROPERTIES: True
694-
},
695-
"trainingImageConfig": {
696-
TYPE: OBJECT,
697-
ADDITIONAL_PROPERTIES: True
698-
},
699-
"algorithmName": {
700-
TYPE: OBJECT,
701-
ADDITIONAL_PROPERTIES: True
702-
},
703-
"outputDataConfig": {
704-
TYPE: OBJECT,
705-
ADDITIONAL_PROPERTIES: True
706-
},
707-
"trainingInputMode": {
708-
TYPE: OBJECT,
709-
ADDITIONAL_PROPERTIES: True
710-
},
711-
"environment": {
712-
TYPE: OBJECT,
713-
ADDITIONAL_PROPERTIES: True
714-
},
715-
"hyperparameters": {
716-
TYPE: OBJECT,
717-
ADDITIONAL_PROPERTIES: True
718-
},
667+
"baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
668+
"sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
669+
"distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
670+
"compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
671+
"networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
672+
"stoppingCondition": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
673+
"trainingImage": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
674+
"trainingImageConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
675+
"algorithmName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
676+
"outputDataConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
677+
"trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
678+
"environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
679+
"hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
719680
},
720681
PROPERTIES: {
721682
SCHEMA_VERSION: {
@@ -769,10 +730,7 @@ def _simple_path(*args: str):
769730
},
770731
},
771732
},
772-
MODEL_TRAINER: {
773-
TYPE: OBJECT,
774-
ADDITIONAL_PROPERTIES: True
775-
},
733+
MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
776734
ESTIMATOR: {
777735
TYPE: OBJECT,
778736
ADDITIONAL_PROPERTIES: False,

src/sagemaker/modules/configs.py

-14
Original file line number | Diff line number | Diff line change
@@ -39,14 +39,7 @@
3939
TrainingImageConfig,
4040
TrainingRepositoryAuthConfig,
4141
Tag,
42-
MetricDefinition,
43-
DebugHookConfig,
44-
CollectionConfiguration,
45-
DebugRuleConfiguration,
46-
ExperimentConfig,
4742
InfraCheckConfig,
48-
ProfilerConfig,
49-
ProfilerRuleConfiguration,
5043
RemoteDebugConfig,
5144
SessionChainingConfig,
5245
InstanceGroup,
@@ -69,14 +62,7 @@
6962
"TrainingImageConfig",
7063
"TrainingRepositoryAuthConfig",
7164
"Tag",
72-
"MetricDefinition",
73-
"DebugHookConfig",
74-
"CollectionConfiguration",
75-
"DebugRuleConfiguration",
76-
"ExperimentConfig",
7765
"InfraCheckConfig",
78-
"ProfilerConfig",
79-
"ProfilerRuleConfiguration",
8066
"RemoteDebugConfig",
8167
"SessionChainingConfig",
8268
"InstanceGroup",

src/sagemaker/modules/constants.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
)
2727

2828
SOURCE_CODE_JSON = "sourcecode.json"
29-
DISTRIBUTED_RUNNER_JSON = "distributed_runner.json"
29+
DISTRIBUTED_JSON = "distributed.json"
3030
TRAIN_SCRIPT = "sm_train.sh"
3131

3232
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]

src/sagemaker/modules/distributed.py

+9-8
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7272
return hyperparameters
7373

7474

75-
class DistributedRunner(BaseModel):
76-
"""Base class for DistributedRunner Class"""
75+
class DistributedConfig(BaseModel):
76+
"""Base class for distributed training configurations."""
7777

7878
_type: str = PrivateAttr()
7979

@@ -84,11 +84,11 @@ def model_dump(self, *args, **kwargs):
8484
return result
8585

8686

87-
class Torchrun(DistributedRunner):
88-
"""TorchDistributed.
87+
class Torchrun(DistributedConfig):
88+
"""Torchrun.
8989
90-
The Torchrun runner uses `torchrun` or `torch.distributed.launch` in the backend to
91-
launch distributed training.
90+
The Torchrun class configures a job that uses `torchrun` or
91+
`torch.distributed.launch` in the backend to launch distributed training.
9292
9393
Attributes:
9494
process_count_per_node (int):
@@ -104,10 +104,11 @@ class Torchrun(DistributedRunner):
104104
smp: Optional["SMP"] = None
105105

106106

107-
class MPI(DistributedRunner):
107+
class MPI(DistributedConfig):
108108
"""MPI.
109109
110-
The MPI runner uses `mpirun` in the backend to launch distributed training.
110+
The MPI class configures a job that uses `mpirun` in the backend to launch
111+
distributed training.
111112
112113
Attributes:
113114
process_count_per_node (int):

src/sagemaker/modules/templates.py

-8
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,6 @@
7676
cat /opt/ml/input/config/inputdataconfig.json
7777
echo
7878
79-
echo "/opt/ml/input/data/sm_drivers/sourcecode.json"
80-
cat /opt/ml/input/data/sm_drivers/sourcecode.json
81-
echo
82-
83-
echo "/opt/ml/input/data/sm_drivers/distributed_runner.json"
84-
cat /opt/ml/input/data/sm_drivers/distributed_runner.json
85-
echo
86-
8779
echo "Setting up environment variables"
8880
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py
8981
source /opt/ml/input/sm_training.env

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+2-2
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,7 @@
402402
" compute=compute,\n",
403403
" hyperparameters=hyperparameters,\n",
404404
" source_code=source_code,\n",
405-
" distributed_runner=torchrun,\n",
405+
" distributed=torchrun,\n",
406406
" base_job_name=f\"{alias}-distributed-case-2\",\n",
407407
")"
408408
]
@@ -498,7 +498,7 @@
498498
" hyperparameters=hyperparameters,\n",
499499
" environment=env,\n",
500500
" source_code=source_code,\n",
501-
" distributed_runner=mpi,\n",
501+
" distributed=mpi,\n",
502502
" base_job_name=f\"{alias}-distributed-case-3\",\n",
503503
")"
504504
]

src/sagemaker/modules/train/container_drivers/mpi_driver.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from utils import (
2121
logger,
2222
read_source_code_json,
23-
read_distributed_runner_json,
23+
read_distributed_json,
2424
read_hyperparameters_json,
2525
hyperparameters_to_cli_args,
2626
get_process_count,
@@ -59,7 +59,7 @@ def main():
5959
6060
"""
6161
source_code = read_source_code_json()
62-
distribution = read_distributed_runner_json()
62+
distribution = read_distributed_json()
6363
hyperparameters = read_hyperparameters_json()
6464

6565
sm_current_host = os.environ["SM_CURRENT_HOST"]

src/sagemaker/modules/train/container_drivers/torchrun_driver.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from utils import (
2222
logger,
2323
read_source_code_json,
24-
read_distributed_runner_json,
24+
read_distributed_json,
2525
read_hyperparameters_json,
2626
hyperparameters_to_cli_args,
2727
get_process_count,
@@ -66,7 +66,7 @@ def setup_env():
6666
def create_commands():
6767
"""Create the Torch Distributed command to execute"""
6868
source_code = read_source_code_json()
69-
distribution = read_distributed_runner_json()
69+
distribution = read_distributed_json()
7070
hyperparameters = read_hyperparameters_json()
7171

7272
process_count = get_process_count(distribution)

src/sagemaker/modules/train/container_drivers/utils.py

+7-7
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838

3939
USER_CODE_PATH = "/opt/ml/input/data/sm_code"
4040
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
41-
DISTRIBUTED_RUNNER_JSON = "/opt/ml/input/data/sm_drivers/distributed_runner.json"
41+
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
4242

4343
HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
4444

@@ -79,14 +79,14 @@ def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON):
7979
return source_code_dict
8080

8181

82-
def read_distributed_runner_json(distributed_json: Dict[str, Any] = DISTRIBUTED_RUNNER_JSON):
82+
def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON):
8383
"""Read the distribution config json file."""
8484
try:
8585
with open(distributed_json, "r") as f:
86-
distributed_runner_dict = json.load(f) or {}
86+
distributed_dict = json.load(f) or {}
8787
except FileNotFoundError:
88-
distributed_runner_dict = {}
89-
return distributed_runner_dict
88+
distributed_dict = {}
89+
return distributed_dict
9090

9191

9292
def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON):
@@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME
9999
return hyperparameters_dict
100100

101101

102-
def get_process_count(distributed_runner_dict: Dict[str, Any]) -> int:
102+
def get_process_count(distributed_dict: Dict[str, Any]) -> int:
103103
"""Get the number of processes to run on each node in the training job."""
104104
return (
105-
int(distributed_runner_dict.get("process_count_per_node", 0))
105+
int(distributed_dict.get("process_count_per_node", 0))
106106
or int(os.environ.get("SM_NUM_GPUS", 0))
107107
or int(os.environ.get("SM_NUM_NEURONS", 0))
108108
or 1

0 commit comments

Comments
 (0)