Skip to content

Commit 1c45f14

Browse files
committed
Detect hyperparameters from contents rather than file extension
1 parent e4474fa commit 1c45f14

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -475,12 +475,22 @@ def model_post_init(self, __context: Any):
475475
if not os.path.exists(self.hyperparameters):
476476
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
477477
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
478-
if self.hyperparameters.endswith(".json"):
479-
with open(self.hyperparameters, "r") as f:
480-
self.hyperparameters = json.load(f)
481-
elif self.hyperparameters.endswith(".yaml"):
482-
with open(self.hyperparameters, "r") as f:
483-
self.hyperparameters = yaml.safe_load(f)
478+
with open(self.hyperparameters, "r") as f:
479+
contents = f.read()
480+
try:
481+
self.hyperparameters = json.loads(contents)
482+
logger.debug("Hyperparameters loaded as JSON")
483+
except json.JSONDecodeError:
484+
try:
485+
self.hyperparameters = yaml.safe_load(contents)
486+
if not isinstance(self.hyperparameters, dict):
487+
raise ValueError("YAML content is not a valid mapping.")
488+
logger.debug("Hyperparameters loaded as YAML")
489+
except (yaml.YAMLError, ValueError) as e:
490+
raise ValueError(
491+
f"Invalid hyperparameters file: {self.hyperparameters}. "
492+
"Must be a valid JSON or YAML file."
493+
)
484494

485495
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
486496
session = self.sagemaker_session

0 commit comments

Comments
 (0)