@@ -475,12 +475,22 @@ def model_post_init(self, __context: Any):
475
475
if not os .path .exists (self .hyperparameters ):
476
476
raise ValueError (f"Hyperparameters file not found: { self .hyperparameters } " )
477
477
logger .info (f"Loading hyperparameters from file: { self .hyperparameters } " )
478
- if self .hyperparameters .endswith (".json" ):
479
- with open (self .hyperparameters , "r" ) as f :
480
- self .hyperparameters = json .load (f )
481
- elif self .hyperparameters .endswith (".yaml" ):
482
- with open (self .hyperparameters , "r" ) as f :
483
- self .hyperparameters = yaml .safe_load (f )
478
+ with open (self .hyperparameters , "r" ) as f :
479
+ contents = f .read ()
480
+ try :
481
+ self .hyperparameters = json .loads (contents )
482
+ logger .debug ("Hyperparameters loaded as JSON" )
483
+ except json .JSONDecodeError :
484
+ try :
485
+ self .hyperparameters = yaml .safe_load (contents )
486
+ if not isinstance (self .hyperparameters , dict ):
487
+ raise ValueError ("YAML content is not a valid mapping." )
488
+ logger .debug ("Hyperparameters loaded as YAML" )
489
+ except (yaml .YAMLError , ValueError ) as e :
490
+ raise ValueError (
491
+ f"Invalid hyperparameters file: { self .hyperparameters } . "
492
+ "Must be a valid JSON or YAML file."
493
+ )
484
494
485
495
if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB and self .output_data_config is None :
486
496
session = self .sagemaker_session
0 commit comments