|
36 | 36 | TrainingJob - {os.environ['TRAINING_JOB_NAME']}
|
37 | 37 | """
|
38 | 38 |
|
39 |
| -USER_CODE_PATH = "/opt/ml/input/data/code" |
40 |
| -SOURCE_CODE_CONFIG_JSON = "/opt/ml/input/data/sm_code/sourcecodeconfig.json" |
# Directory path constant; presumably where the user's code is mounted inside
# the training container -- TODO confirm against the code that consumes it.
| 39 | +USER_CODE_PATH = "/opt/ml/input/data/sm_code" |
# Default file path used by read_source_code_config_json().
| 40 | +SOURCE_CODE_CONFIG_JSON = "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json" |
# Default file path used by read_distribution_json().
| 41 | +DISTRIBUTION_JSON = "/opt/ml/input/data/sm_drivers/distribution.json" |
41 | 42 |
|
42 | 43 | SM_EFA_NCCL_INSTANCES = [
|
43 | 44 | "ml.g4dn.8xlarge",
|
@@ -67,19 +68,25 @@ def write_failure_file(message: str = DEFAULT_FAILURE_MESSAGE):
|
def read_source_code_config_json(
    source_code_config_file: str = SOURCE_CODE_CONFIG_JSON,
) -> Dict[str, Any]:
    """Read the source code config json file.

    Args:
        source_code_config_file: Path of the JSON file to read. Defaults to
            ``SOURCE_CODE_CONFIG_JSON``.

    Returns:
        The parsed JSON content (expected to be a JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange text is UTF-8; do not depend on the container locale.
    with open(source_code_config_file, "r", encoding="utf-8") as f:
        return json.load(f)
72 | 73 |
|
73 | 74 |
|
74 |
| -def get_process_count(source_code_config: Dict[str, Any]) -> int: |
def read_distribution_json(
    distribution_file: str = DISTRIBUTION_JSON,
) -> Dict[str, Any]:
    """Read the distribution json file.

    Args:
        distribution_file: Path of the JSON file to read. Defaults to
            ``DISTRIBUTION_JSON``.

    Returns:
        The parsed JSON content (expected to be a JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange text is UTF-8; do not depend on the container locale.
    with open(distribution_file, "r", encoding="utf-8") as f:
        return json.load(f)
| 80 | + |
| 81 | + |
def get_process_count(distribution: Dict[str, Any]) -> int:
    """Get the number of processes to run on each node in the training job.

    The first non-zero value wins, checked in this order:

    1. ``distribution["process_count_per_node"]``
    2. the ``SM_NUM_GPUS`` environment variable
    3. the ``SM_NUM_NEURONS`` environment variable
    4. ``1`` as the default

    Args:
        distribution: Distribution settings; only the optional
            ``process_count_per_node`` key is read.

    Returns:
        The number of processes per node.

    Raises:
        ValueError: If a consulted value is a non-empty string that is not
            an integer literal.
    """
    candidates = (
        distribution.get("process_count_per_node"),
        os.environ.get("SM_NUM_GPUS"),
        os.environ.get("SM_NUM_NEURONS"),
    )
    for raw in candidates:
        # `or 0` turns a missing key, an explicit JSON null and an empty
        # string into 0 so they fall through to the next source instead of
        # making int() raise.
        count = int(raw or 0)
        if count:
            return count
    return 1  # Default to 1 process per node
83 | 90 |
|
84 | 91 |
|
85 | 92 | def get_python_executable() -> str:
|
|
0 commit comments