Skip to content

Commit abe46cb

Browse files
committed
fix integs
1 parent b139afb commit abe46cb

File tree

1 file changed

+15
-25
lines changed

1 file changed

+15
-25
lines changed

tests/integ/sagemaker/modules/train/test_model_trainer.py

+15-25
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,14 @@
3535
},
3636
}
3737

38-
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38+
PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
39+
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
40+
source_dir=PARAM_SCRIPT_SOURCE_DIR,
41+
requirements="requirements.txt",
42+
entry_script="train.py",
43+
)
44+
45+
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
3946

4047

4148
def test_hp_contract_basic_py_script(modules_sagemaker_session):
@@ -59,6 +66,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5966
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
6067
source_code = SourceCode(
6168
source_dir=f"{DATA_DIR}/modules/params_script",
69+
requirements="requirements.txt",
6270
entry_script="train.sh",
6371
)
6472
model_trainer = ModelTrainer(
@@ -73,17 +81,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7381

7482

7583
def test_hp_contract_mpi_script(modules_sagemaker_session):
76-
source_code = SourceCode(
77-
source_dir=f"{DATA_DIR}/modules/params_script",
78-
entry_script="train.py",
79-
)
8084
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
8185
model_trainer = ModelTrainer(
8286
sagemaker_session=modules_sagemaker_session,
8387
training_image=DEFAULT_CPU_IMAGE,
8488
compute=compute,
8589
hyperparameters=EXPECTED_HYPERPARAMETERS,
86-
source_code=source_code,
90+
source_code=PARAM_SCRIPT_SOURCE_CODE,
8791
distributed=MPI(),
8892
base_job_name="hp-contract-mpi-script",
8993
)
@@ -92,17 +96,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9296

9397

9498
def test_hp_contract_torchrun_script(modules_sagemaker_session):
95-
source_code = SourceCode(
96-
source_dir=f"{DATA_DIR}/modules/params_script",
97-
entry_script="train.py",
98-
)
9999
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
100100
model_trainer = ModelTrainer(
101101
sagemaker_session=modules_sagemaker_session,
102102
training_image=DEFAULT_CPU_IMAGE,
103103
compute=compute,
104104
hyperparameters=EXPECTED_HYPERPARAMETERS,
105-
source_code=source_code,
105+
source_code=PARAM_SCRIPT_SOURCE_CODE,
106106
distributed=Torchrun(),
107107
base_job_name="hp-contract-torchrun-script",
108108
)
@@ -111,33 +111,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111111

112112

113113
def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
114-
source_dir = f"{DATA_DIR}/modules/params_script"
115-
source_code = SourceCode(
116-
source_dir=source_dir,
117-
entry_script="train.py",
118-
)
119114
model_trainer = ModelTrainer(
120115
sagemaker_session=modules_sagemaker_session,
121116
training_image=DEFAULT_CPU_IMAGE,
122-
hyperparameters=f"{source_dir}/hyperparameters.json",
123-
source_code=source_code,
117+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json",
118+
source_code=PARAM_SCRIPT_SOURCE_CODE,
124119
base_job_name="hp-contract-hyperparameter-json",
125120
)
126121
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
127122
model_trainer.train()
128123

129124

130125
def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
131-
source_dir = f"{DATA_DIR}/modules/params_script"
132-
source_code = SourceCode(
133-
source_dir=source_dir,
134-
entry_script="train.py",
135-
)
136126
model_trainer = ModelTrainer(
137127
sagemaker_session=modules_sagemaker_session,
138128
training_image=DEFAULT_CPU_IMAGE,
139-
hyperparameters=f"{source_dir}/hyperparameters.yaml",
140-
source_code=source_code,
129+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml",
130+
source_code=PARAM_SCRIPT_SOURCE_CODE,
141131
base_job_name="hp-contract-hyperparameter-yaml",
142132
)
143133
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS

0 commit comments

Comments
 (0)