35
35
},
36
36
}
37
37
38
- DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38
+ PARAM_SCRIPT_SOURCE_DIR = f"{ DATA_DIR } /modules/params_script"
39
+ PARAM_SCRIPT_SOURCE_CODE = SourceCode (
40
+ source_dir = PARAM_SCRIPT_SOURCE_DIR ,
41
+ requirements = "requirements.txt" ,
42
+ entry_script = "train.py" ,
43
+ )
44
+
45
+ DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
39
46
40
47
41
48
def test_hp_contract_basic_py_script (modules_sagemaker_session ):
@@ -59,6 +66,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
59
66
def test_hp_contract_basic_sh_script (modules_sagemaker_session ):
60
67
source_code = SourceCode (
61
68
source_dir = f"{ DATA_DIR } /modules/params_script" ,
69
+ requirements = "requirements.txt" ,
62
70
entry_script = "train.sh" ,
63
71
)
64
72
model_trainer = ModelTrainer (
@@ -73,17 +81,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
73
81
74
82
75
83
def test_hp_contract_mpi_script (modules_sagemaker_session ):
76
- source_code = SourceCode (
77
- source_dir = f"{ DATA_DIR } /modules/params_script" ,
78
- entry_script = "train.py" ,
79
- )
80
84
compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
81
85
model_trainer = ModelTrainer (
82
86
sagemaker_session = modules_sagemaker_session ,
83
87
training_image = DEFAULT_CPU_IMAGE ,
84
88
compute = compute ,
85
89
hyperparameters = EXPECTED_HYPERPARAMETERS ,
86
- source_code = source_code ,
90
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
87
91
distributed = MPI (),
88
92
base_job_name = "hp-contract-mpi-script" ,
89
93
)
@@ -92,17 +96,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
92
96
93
97
94
98
def test_hp_contract_torchrun_script (modules_sagemaker_session ):
95
- source_code = SourceCode (
96
- source_dir = f"{ DATA_DIR } /modules/params_script" ,
97
- entry_script = "train.py" ,
98
- )
99
99
compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
100
100
model_trainer = ModelTrainer (
101
101
sagemaker_session = modules_sagemaker_session ,
102
102
training_image = DEFAULT_CPU_IMAGE ,
103
103
compute = compute ,
104
104
hyperparameters = EXPECTED_HYPERPARAMETERS ,
105
- source_code = source_code ,
105
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
106
106
distributed = Torchrun (),
107
107
base_job_name = "hp-contract-torchrun-script" ,
108
108
)
@@ -111,33 +111,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111
111
112
112
113
113
def test_hp_contract_hyperparameter_json (modules_sagemaker_session ):
114
- source_dir = f"{ DATA_DIR } /modules/params_script"
115
- source_code = SourceCode (
116
- source_dir = source_dir ,
117
- entry_script = "train.py" ,
118
- )
119
114
model_trainer = ModelTrainer (
120
115
sagemaker_session = modules_sagemaker_session ,
121
116
training_image = DEFAULT_CPU_IMAGE ,
122
- hyperparameters = f"{ source_dir } /hyperparameters.json" ,
123
- source_code = source_code ,
117
+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.json" ,
118
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
124
119
base_job_name = "hp-contract-hyperparameter-json" ,
125
120
)
126
121
assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
127
122
model_trainer .train ()
128
123
129
124
130
125
def test_hp_contract_hyperparameter_yaml (modules_sagemaker_session ):
131
- source_dir = f"{ DATA_DIR } /modules/params_script"
132
- source_code = SourceCode (
133
- source_dir = source_dir ,
134
- entry_script = "train.py" ,
135
- )
136
126
model_trainer = ModelTrainer (
137
127
sagemaker_session = modules_sagemaker_session ,
138
128
training_image = DEFAULT_CPU_IMAGE ,
139
- hyperparameters = f"{ source_dir } /hyperparameters.yaml" ,
140
- source_code = source_code ,
129
+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.yaml" ,
130
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
141
131
base_job_name = "hp-contract-hyperparameter-yaml" ,
142
132
)
143
133
assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
0 commit comments