Skip to content

feat: Allow ModelTrainer to accept hyperparameters file #5059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 5, 2025
Merged
32 changes: 28 additions & 4 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import json
import shutil
from tempfile import TemporaryDirectory

from typing import Optional, List, Union, Dict, Any, ClassVar
import yaml

from graphene.utils.str_converters import to_camel_case, to_snake_case

Expand Down Expand Up @@ -195,8 +195,9 @@ class ModelTrainer(BaseModel):
Defaults to "File".
environment (Optional[Dict[str, str]]):
The environment variables for the training job.
hyperparameters (Optional[Dict[str, Any]]):
The hyperparameters for the training job.
hyperparameters (Optional[Union[Dict[str, Any], str]]):
The hyperparameters for the training job. Can be a dictionary of hyperparameters
or a path to hyperparameters json/yaml file.
tags (Optional[List[Tag]]):
An array of key-value pairs. You can use tags to categorize your AWS resources
in different ways, for example, by purpose, owner, or environment.
Expand Down Expand Up @@ -226,7 +227,7 @@ class ModelTrainer(BaseModel):
checkpoint_config: Optional[CheckpointConfig] = None
training_input_mode: Optional[str] = "File"
environment: Optional[Dict[str, str]] = {}
hyperparameters: Optional[Dict[str, Any]] = {}
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
tags: Optional[List[Tag]] = None
local_container_root: Optional[str] = os.getcwd()

Expand Down Expand Up @@ -470,6 +471,29 @@ def model_post_init(self, __context: Any):
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
)

if self.hyperparameters and isinstance(self.hyperparameters, str):
if not os.path.exists(self.hyperparameters):
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
with open(self.hyperparameters, "r") as f:
contents = f.read()
try:
self.hyperparameters = json.loads(contents)
logger.debug("Hyperparameters loaded as JSON")
except json.JSONDecodeError:
try:
logger.info(f"contents: {contents}")
self.hyperparameters = yaml.safe_load(contents)
if not isinstance(self.hyperparameters, dict):
raise ValueError("YAML contents must be a valid mapping")
logger.info(f"hyperparameters: {self.hyperparameters}")
logger.debug("Hyperparameters loaded as YAML")
except (yaml.YAMLError, ValueError):
raise ValueError(
f"Invalid hyperparameters file: {self.hyperparameters}. "
"Must be a valid JSON or YAML file."
)

if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
session = self.sagemaker_session
base_job_name = self.base_job_name
Expand Down
15 changes: 15 additions & 0 deletions tests/data/modules/params_script/hyperparameters.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"integer": 1,
"boolean": true,
"float": 3.14,
"string": "Hello World",
"list": [1, 2, 3],
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": true
}
}
19 changes: 19 additions & 0 deletions tests/data/modules/params_script/hyperparameters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
integer: 1
boolean: true
float: 3.14
string: "Hello World"
list:
- 1
- 2
- 3
dict:
string: value
integer: 3
float: 3.14
list:
- 1
- 2
- 3
dict:
key: value
boolean: true
1 change: 1 addition & 0 deletions tests/data/modules/params_script/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
omegaconf
97 changes: 94 additions & 3 deletions tests/data/modules/params_script/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import argparse
import json
import os
from typing import List, Dict, Any
from dataclasses import dataclass
from omegaconf import OmegaConf

EXPECTED_HYPERPARAMETERS = {
"integer": 1,
Expand All @@ -26,6 +29,7 @@
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": True,
Expand Down Expand Up @@ -117,7 +121,7 @@ def main():
assert isinstance(params["dict"], dict)

params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
print(params)
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
Expand All @@ -132,9 +136,96 @@ def main():
assert isinstance(params["float"], float)
assert isinstance(params["list"], list)
assert isinstance(params["dict"], dict)
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")

print("Test passed.")
# Local JSON - DictConfig OmegaConf
params = OmegaConf.load("hyperparameters.json")

print(f"Local hyperparameters.json: {params}")
assert params.string == EXPECTED_HYPERPARAMETERS["string"]
assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert params.float == EXPECTED_HYPERPARAMETERS["float"]
assert params.list == EXPECTED_HYPERPARAMETERS["list"]
assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

@dataclass
class DictConfig:
    """Typed schema for the nested "dict" hyperparameter.

    Field names mirror the keys of EXPECTED_HYPERPARAMETERS["dict"] and
    intentionally shadow builtins (``float``, ``list``, ``dict``) so that
    OmegaConf structured merging validates each key's type.
    """

    string: str
    integer: int
    boolean: bool
    float: float
    list: List[int]
    dict: Dict[str, Any]

@dataclass
class HPConfig:
    """Top-level hyperparameter schema.

    Used with ``OmegaConf.structured`` to type-check the hyperparameters
    loaded from JSON/YAML files and from the SM_HPS environment variable.
    """

    string: str
    integer: int
    boolean: bool
    float: float
    list: List[int]
    dict: DictConfig

# Local JSON - Structured OmegaConf
hp_config: HPConfig = OmegaConf.merge(
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
)
print(f"Local hyperparameters.json - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

# Local YAML - Structured OmegaConf
hp_config: HPConfig = OmegaConf.merge(
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
)
print(f"Local hyperparameters.yaml - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")

# HP Dict - Structured OmegaConf
hp_dict = json.loads(os.environ["SM_HPS"])
hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
print(f"SM_HPS - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
print(f"SM_HPS -> hyperparameters: {hp_config}")


if __name__ == "__main__":
Expand Down
52 changes: 36 additions & 16 deletions tests/integ/sagemaker/modules/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,29 @@
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": True,
},
}

PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
source_dir=PARAM_SCRIPT_SOURCE_DIR,
requirements="requirements.txt",
entry_script="train.py",
)

DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"


def test_hp_contract_basic_py_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
entry_script="train.py",
)

model_trainer = ModelTrainer(
sagemaker_session=modules_sagemaker_session,
training_image=DEFAULT_CPU_IMAGE,
hyperparameters=EXPECTED_HYPERPARAMETERS,
source_code=source_code,
source_code=PARAM_SCRIPT_SOURCE_CODE,
base_job_name="hp-contract-basic-py-script",
)

Expand All @@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
requirements="requirements.txt",
entry_script="train.sh",
)
model_trainer = ModelTrainer(
Expand All @@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):


def test_hp_contract_mpi_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
entry_script="train.py",
)
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
model_trainer = ModelTrainer(
sagemaker_session=modules_sagemaker_session,
training_image=DEFAULT_CPU_IMAGE,
compute=compute,
hyperparameters=EXPECTED_HYPERPARAMETERS,
source_code=source_code,
source_code=PARAM_SCRIPT_SOURCE_CODE,
distributed=MPI(),
base_job_name="hp-contract-mpi-script",
)
Expand All @@ -90,19 +90,39 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):


def test_hp_contract_torchrun_script(modules_sagemaker_session):
    """Verify the hyperparameter contract for a torchrun-distributed training job.

    The rendered diff left both the removed ``source_code=source_code,`` line and
    its replacement in place, producing a duplicate keyword argument (a
    SyntaxError); this is the reconstructed post-change code using the shared
    PARAM_SCRIPT_SOURCE_CODE constant.
    """
    # Two instances so the params script runs under a multi-node torchrun launch.
    compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        compute=compute,
        hyperparameters=EXPECTED_HYPERPARAMETERS,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        distributed=Torchrun(),
        base_job_name="hp-contract-torchrun-script",
    )

    model_trainer.train()


def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
    """Hyperparameters given as a JSON file path are parsed into the expected dict."""
    json_path = f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json"
    trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=json_path,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-json",
    )
    # The file path must have been replaced by the parsed hyperparameter dict.
    assert trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    trainer.train()


def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
    """Hyperparameters given as a YAML file path are parsed into the expected dict."""
    yaml_path = f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml"
    trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=yaml_path,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-yaml",
    )
    # The file path must have been replaced by the parsed hyperparameter dict.
    assert trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    trainer.train()
Loading