feat: Allow ModelTrainer to accept hyperparameters file #5059

Merged
merged 15 commits into from Mar 5, 2025
113 changes: 113 additions & 0 deletions src/sagemaker/modules/hyperparameters.py
@@ -0,0 +1,113 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Hyperparameters class module."""
from __future__ import absolute_import

import os
import json
import dataclasses
from typing import Any, Type, TypeVar

from sagemaker.modules import logger

T = TypeVar("T")


class DictConfig:
    """Class that supports both dict and dot notation access."""

    def __init__(self, **kwargs):
        # Store the original dict
        self._data = kwargs

        # Set all items as attributes for dot notation
        for key, value in kwargs.items():
            # Recursively convert nested dicts to DictConfig
            if isinstance(value, dict):
                value = DictConfig(**value)
            setattr(self, key, value)

    def __getitem__(self, key: str) -> Any:
        """Enable dictionary-style access: config['key']"""
        return self._data[key]

    def __setitem__(self, key: str, value: Any):
        """Enable dictionary-style assignment: config['key'] = value"""
        self._data[key] = value
        setattr(self, key, value)

    def __str__(self) -> str:
        """String representation"""
        return str(self._data)

    def __repr__(self) -> str:
        """Detailed string representation"""
        return f"DictConfig({self._data})"


class Hyperparameters:
    """Class to load hyperparameters in the training container."""

    @staticmethod
    def load() -> DictConfig:
        """Loads hyperparameters in the training container.

        Example:

        .. code:: python

            from sagemaker.modules.hyperparameters import Hyperparameters

            hps = Hyperparameters.load()
            print(hps.batch_size)

        Returns:
            DictConfig: hyperparameters as a DictConfig object
        """
        hps = json.loads(os.environ.get("SM_HPS", "{}"))
        if not hps:
            logger.warning("No hyperparameters found in SM_HPS environment variable.")
        return DictConfig(**hps)

    @staticmethod
    def load_structured(dataclass_type: Type[T]) -> T:
        """Loads hyperparameters as a structured dataclass.

        Example:

        .. code:: python

            from dataclasses import dataclass

            from sagemaker.modules.hyperparameters import Hyperparameters

            @dataclass
            class TrainingConfig:
                batch_size: int
                learning_rate: float

            config = Hyperparameters.load_structured(TrainingConfig)
            print(config.batch_size)  # typed int

        Args:
            dataclass_type: Dataclass type to structure the config

        Returns:
            dataclass_type: Instance of the provided dataclass type
        """
        if not dataclasses.is_dataclass(dataclass_type):
            raise ValueError(f"{dataclass_type} is not a dataclass type.")

        hps = json.loads(os.environ.get("SM_HPS", "{}"))
        if not hps:
            logger.warning("No hyperparameters found in SM_HPS environment variable.")

        # Convert hyperparameters to dataclass
        return dataclass_type(**hps)
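
A minimal usage sketch for the new module, assuming it runs inside (or simulates) a training container. In a real job SageMaker populates SM_HPS itself; here the variable is set by hand, and the hyperparameter names are illustrative only.

    import json
    import os
    from dataclasses import dataclass

    # Simulate the SM_HPS environment variable that SageMaker sets in the container
    os.environ["SM_HPS"] = json.dumps(
        {"batch_size": 32, "optimizer": {"name": "adam", "lr": 0.001}}
    )

    from sagemaker.modules.hyperparameters import Hyperparameters

    hps = Hyperparameters.load()
    print(hps.batch_size)      # dot access -> 32
    print(hps["batch_size"])   # dict access -> 32
    print(hps.optimizer.name)  # nested dicts become DictConfig -> "adam"

    @dataclass
    class TrainingConfig:
        batch_size: int
        optimizer: dict

    config = Hyperparameters.load_structured(TrainingConfig)
    print(config.batch_size)   # 32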
30 changes: 26 additions & 4 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
@@ -18,8 +18,8 @@
import json
import shutil
from tempfile import TemporaryDirectory

from typing import Optional, List, Union, Dict, Any, ClassVar
import yaml

from graphene.utils.str_converters import to_camel_case, to_snake_case

@@ -195,8 +195,9 @@ class ModelTrainer(BaseModel):
            Defaults to "File".
        environment (Optional[Dict[str, str]]):
            The environment variables for the training job.
        hyperparameters (Optional[Dict[str, Any]]):
            The hyperparameters for the training job.
        hyperparameters (Optional[Union[Dict[str, Any], str]]):
            The hyperparameters for the training job. Can be a dictionary of hyperparameters
            or a path to a JSON or YAML file containing them.
        tags (Optional[List[Tag]]):
            An array of key-value pairs. You can use tags to categorize your AWS resources
            in different ways, for example, by purpose, owner, or environment.
@@ -226,7 +227,7 @@ class ModelTrainer(BaseModel):
    checkpoint_config: Optional[CheckpointConfig] = None
    training_input_mode: Optional[str] = "File"
    environment: Optional[Dict[str, str]] = {}
    hyperparameters: Optional[Dict[str, Any]] = {}
    hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
    tags: Optional[List[Tag]] = None
    local_container_root: Optional[str] = os.getcwd()

@@ -470,6 +471,27 @@ def model_post_init(self, __context: Any):
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
)

if self.hyperparameters and isinstance(self.hyperparameters, str):
if not os.path.exists(self.hyperparameters):
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
with open(self.hyperparameters, "r") as f:
contents = f.read()
try:
self.hyperparameters = json.loads(contents)
logger.debug("Hyperparameters loaded as JSON")
except json.JSONDecodeError:
try:
self.hyperparameters = yaml.safe_load(contents)
if not isinstance(self.hyperparameters, dict):
raise ValueError("YAML content is not a valid mapping.")
logger.debug("Hyperparameters loaded as YAML")
except (yaml.YAMLError, ValueError):
raise ValueError(
f"Invalid hyperparameters file: {self.hyperparameters}. "
"Must be a valid JSON or YAML file."
)

        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
            session = self.sagemaker_session
            base_job_name = self.base_job_name
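
A sketch of the calling pattern this change enables. The image URI and source paths below are placeholders, and constructing a ModelTrainer still needs valid AWS credentials and configuration; the point is only that hyperparameters may now be a file path.

    import yaml

    # Write a hyperparameters file (JSON would work equally well)
    with open("hyperparameters.yaml", "w") as f:
        yaml.safe_dump({"batch_size": 32, "learning_rate": 0.01, "epochs": 3}, f)

    from sagemaker.modules.configs import SourceCode
    from sagemaker.modules.train import ModelTrainer

    trainer = ModelTrainer(
        training_image="<your-training-image-uri>",  # placeholder
        source_code=SourceCode(source_dir="src", entry_script="train.py"),  # placeholder paths
        hyperparameters="hyperparameters.yaml",  # new: a path instead of a dict
    )

    # model_post_init parses the file, so the attribute now holds a plain dict
    print(trainer.hyperparameters)  # {'batch_size': 32, 'learning_rate': 0.01, 'epochs': 3}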