Skip to content

Commit b88be2a

Browse files
benieric authored and sagemaker-bot committed
feat: Allow ModelTrainer to accept hyperparameters file (aws#5059)
* Allow ModelTrainer to accept hyperparameter file and create Hyperparameter class
* pylint
* Detect hyperparameters from contents rather than file extension
* pylint
* change: add integs
* change: add integs
* change: remove custom hyperparameter tooling
* Add tests for hp contracts
* change: add unit tests and remove unreachable condition
* fix integs
* doc check fix
* fix tests
* fix tox.ini
* add unit test
1 parent 805cb7a commit b88be2a

File tree

7 files changed

+285
-24
lines changed

7 files changed

+285
-24
lines changed

src/sagemaker/modules/train/model_trainer.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import json
1919
import shutil
2020
from tempfile import TemporaryDirectory
21-
2221
from typing import Optional, List, Union, Dict, Any, ClassVar
22+
import yaml
2323

2424
from graphene.utils.str_converters import to_camel_case, to_snake_case
2525

@@ -195,8 +195,9 @@ class ModelTrainer(BaseModel):
195195
Defaults to "File".
196196
environment (Optional[Dict[str, str]]):
197197
The environment variables for the training job.
198-
hyperparameters (Optional[Dict[str, Any]]):
199-
The hyperparameters for the training job.
198+
hyperparameters (Optional[Union[Dict[str, Any], str]]):
199+
The hyperparameters for the training job. Can be a dictionary of hyperparameters
200+
or a path to a hyperparameters JSON/YAML file.
200201
tags (Optional[List[Tag]]):
201202
An array of key-value pairs. You can use tags to categorize your AWS resources
202203
in different ways, for example, by purpose, owner, or environment.
@@ -226,7 +227,7 @@ class ModelTrainer(BaseModel):
226227
checkpoint_config: Optional[CheckpointConfig] = None
227228
training_input_mode: Optional[str] = "File"
228229
environment: Optional[Dict[str, str]] = {}
229-
hyperparameters: Optional[Dict[str, Any]] = {}
230+
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
230231
tags: Optional[List[Tag]] = None
231232
local_container_root: Optional[str] = os.getcwd()
232233

@@ -470,6 +471,29 @@ def model_post_init(self, __context: Any):
470471
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
471472
)
472473

474+
if self.hyperparameters and isinstance(self.hyperparameters, str):
475+
if not os.path.exists(self.hyperparameters):
476+
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
477+
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
478+
with open(self.hyperparameters, "r") as f:
479+
contents = f.read()
480+
try:
481+
self.hyperparameters = json.loads(contents)
482+
logger.debug("Hyperparameters loaded as JSON")
483+
except json.JSONDecodeError:
484+
try:
485+
logger.info(f"contents: {contents}")
486+
self.hyperparameters = yaml.safe_load(contents)
487+
if not isinstance(self.hyperparameters, dict):
488+
raise ValueError("YAML contents must be a valid mapping")
489+
logger.info(f"hyperparameters: {self.hyperparameters}")
490+
logger.debug("Hyperparameters loaded as YAML")
491+
except (yaml.YAMLError, ValueError):
492+
raise ValueError(
493+
f"Invalid hyperparameters file: {self.hyperparameters}. "
494+
"Must be a valid JSON or YAML file."
495+
)
496+
473497
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
474498
session = self.sagemaker_session
475499
base_job_name = self.base_job_name
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"integer": 1,
3+
"boolean": true,
4+
"float": 3.14,
5+
"string": "Hello World",
6+
"list": [1, 2, 3],
7+
"dict": {
8+
"string": "value",
9+
"integer": 3,
10+
"float": 3.14,
11+
"list": [1, 2, 3],
12+
"dict": {"key": "value"},
13+
"boolean": true
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
integer: 1
2+
boolean: true
3+
float: 3.14
4+
string: "Hello World"
5+
list:
6+
- 1
7+
- 2
8+
- 3
9+
dict:
10+
string: value
11+
integer: 3
12+
float: 3.14
13+
list:
14+
- 1
15+
- 2
16+
- 3
17+
dict:
18+
key: value
19+
boolean: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
omegaconf

tests/data/modules/params_script/train.py

+94-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import argparse
1717
import json
1818
import os
19+
from typing import List, Dict, Any
20+
from dataclasses import dataclass
21+
from omegaconf import OmegaConf
1922

2023
EXPECTED_HYPERPARAMETERS = {
2124
"integer": 1,
@@ -26,6 +29,7 @@
2629
"dict": {
2730
"string": "value",
2831
"integer": 3,
32+
"float": 3.14,
2933
"list": [1, 2, 3],
3034
"dict": {"key": "value"},
3135
"boolean": True,
@@ -117,7 +121,7 @@ def main():
117121
assert isinstance(params["dict"], dict)
118122

119123
params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
120-
print(params)
124+
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
121125
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
122126
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
123127
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
@@ -132,9 +136,96 @@ def main():
132136
assert isinstance(params["float"], float)
133137
assert isinstance(params["list"], list)
134138
assert isinstance(params["dict"], dict)
135-
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
136139

137-
print("Test passed.")
140+
# Local JSON - DictConfig OmegaConf
141+
params = OmegaConf.load("hyperparameters.json")
142+
143+
print(f"Local hyperparameters.json: {params}")
144+
assert params.string == EXPECTED_HYPERPARAMETERS["string"]
145+
assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
146+
assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
147+
assert params.float == EXPECTED_HYPERPARAMETERS["float"]
148+
assert params.list == EXPECTED_HYPERPARAMETERS["list"]
149+
assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
150+
assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
151+
assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
152+
assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
153+
assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
154+
assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
155+
assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
156+
157+
@dataclass
158+
class DictConfig:
159+
string: str
160+
integer: int
161+
boolean: bool
162+
float: float
163+
list: List[int]
164+
dict: Dict[str, Any]
165+
166+
@dataclass
167+
class HPConfig:
168+
string: str
169+
integer: int
170+
boolean: bool
171+
float: float
172+
list: List[int]
173+
dict: DictConfig
174+
175+
# Local JSON - Structured OmegaConf
176+
hp_config: HPConfig = OmegaConf.merge(
177+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
178+
)
179+
print(f"Local hyperparameters.json - Structured: {hp_config}")
180+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
181+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
182+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
183+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
184+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
185+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
186+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
187+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
188+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
189+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
190+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
191+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
192+
193+
# Local YAML - Structured OmegaConf
194+
hp_config: HPConfig = OmegaConf.merge(
195+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
196+
)
197+
print(f"Local hyperparameters.yaml - Structured: {hp_config}")
198+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
199+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
200+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
201+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
202+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
203+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
204+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
205+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
206+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
207+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
208+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
209+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
210+
print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")
211+
212+
# HP Dict - Structured OmegaConf
213+
hp_dict = json.loads(os.environ["SM_HPS"])
214+
hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
215+
print(f"SM_HPS - Structured: {hp_config}")
216+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
217+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
218+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
219+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
220+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
221+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
222+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
223+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
224+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
225+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
226+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
227+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
228+
print(f"SM_HPS -> hyperparameters: {hp_config}")
138229

139230

140231
if __name__ == "__main__":

tests/integ/sagemaker/modules/train/test_model_trainer.py

+36-16
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,29 @@
2828
"dict": {
2929
"string": "value",
3030
"integer": 3,
31+
"float": 3.14,
3132
"list": [1, 2, 3],
3233
"dict": {"key": "value"},
3334
"boolean": True,
3435
},
3536
}
3637

38+
PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
39+
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
40+
source_dir=PARAM_SCRIPT_SOURCE_DIR,
41+
requirements="requirements.txt",
42+
entry_script="train.py",
43+
)
44+
3745
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
3846

3947

4048
def test_hp_contract_basic_py_script(modules_sagemaker_session):
41-
source_code = SourceCode(
42-
source_dir=f"{DATA_DIR}/modules/params_script",
43-
entry_script="train.py",
44-
)
45-
4649
model_trainer = ModelTrainer(
4750
sagemaker_session=modules_sagemaker_session,
4851
training_image=DEFAULT_CPU_IMAGE,
4952
hyperparameters=EXPECTED_HYPERPARAMETERS,
50-
source_code=source_code,
53+
source_code=PARAM_SCRIPT_SOURCE_CODE,
5154
base_job_name="hp-contract-basic-py-script",
5255
)
5356

@@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5760
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
5861
source_code = SourceCode(
5962
source_dir=f"{DATA_DIR}/modules/params_script",
63+
requirements="requirements.txt",
6064
entry_script="train.sh",
6165
)
6266
model_trainer = ModelTrainer(
@@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7175

7276

7377
def test_hp_contract_mpi_script(modules_sagemaker_session):
74-
source_code = SourceCode(
75-
source_dir=f"{DATA_DIR}/modules/params_script",
76-
entry_script="train.py",
77-
)
7878
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
7979
model_trainer = ModelTrainer(
8080
sagemaker_session=modules_sagemaker_session,
8181
training_image=DEFAULT_CPU_IMAGE,
8282
compute=compute,
8383
hyperparameters=EXPECTED_HYPERPARAMETERS,
84-
source_code=source_code,
84+
source_code=PARAM_SCRIPT_SOURCE_CODE,
8585
distributed=MPI(),
8686
base_job_name="hp-contract-mpi-script",
8787
)
@@ -90,19 +90,39 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9090

9191

9292
def test_hp_contract_torchrun_script(modules_sagemaker_session):
93-
source_code = SourceCode(
94-
source_dir=f"{DATA_DIR}/modules/params_script",
95-
entry_script="train.py",
96-
)
9793
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
9894
model_trainer = ModelTrainer(
9995
sagemaker_session=modules_sagemaker_session,
10096
training_image=DEFAULT_CPU_IMAGE,
10197
compute=compute,
10298
hyperparameters=EXPECTED_HYPERPARAMETERS,
103-
source_code=source_code,
99+
source_code=PARAM_SCRIPT_SOURCE_CODE,
104100
distributed=Torchrun(),
105101
base_job_name="hp-contract-torchrun-script",
106102
)
107103

108104
model_trainer.train()
105+
106+
107+
def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
108+
model_trainer = ModelTrainer(
109+
sagemaker_session=modules_sagemaker_session,
110+
training_image=DEFAULT_CPU_IMAGE,
111+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json",
112+
source_code=PARAM_SCRIPT_SOURCE_CODE,
113+
base_job_name="hp-contract-hyperparameter-json",
114+
)
115+
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
116+
model_trainer.train()
117+
118+
119+
def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
120+
model_trainer = ModelTrainer(
121+
sagemaker_session=modules_sagemaker_session,
122+
training_image=DEFAULT_CPU_IMAGE,
123+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml",
124+
source_code=PARAM_SCRIPT_SOURCE_CODE,
125+
base_job_name="hp-contract-hyperparameter-yaml",
126+
)
127+
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
128+
model_trainer.train()

0 commit comments

Comments
 (0)