@@ -54,6 +54,13 @@ class PyTorch(Framework):
    # to retrieve the image uri below before GA.
    SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
    SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
+    SM_TRAINING_RECIPE_GPU_IMG = (
+        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
+    )
+    SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
+    SM_NEURONX_DIST_IMG = (
+        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
+    )

    def __init__(
        self,
@@ -66,6 +73,7 @@ def __init__(
        distribution: Optional[Dict] = None,
        compiler_config: Optional[TrainingCompilerConfig] = None,
        training_recipe: Optional[str] = None,
+        recipe_overrides: Optional[Dict] = None,
        **kwargs,
    ):
        """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -262,6 +270,9 @@ def __init__(
                a url to fetch, or a recipe provided by SageMaker
                training.
+            recipe_overrides (Dict): Dictionary specifying key values to override in the
+                training_recipe.
+
            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                constructor.
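For orientation, a hedged sketch of how a caller might exercise the new recipe_overrides argument once this change lands; the recipe name, role ARN, and override keys below are illustrative assumptions, not values taken from this diff:

from sagemaker.pytorch import PyTorch

# Sketch only: recipe name, role, and override keys are hypothetical examples.
estimator = PyTorch(
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",  # placeholder role ARN
    instance_type="ml.p4d.24xlarge",
    training_recipe="llama/example_pretrain_recipe",             # hypothetical recipe name
    recipe_overrides={"trainer": {"max_steps": 100}},            # merged into the loaded recipe
)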
@@ -280,12 +291,12 @@ def __init__(
                logger.warning("Argument hyperparameters will be ignored with training recipe.")
            if distribution is not None:
                logger.warning("Argument distribution will be ignored with training_recipe.")
-            args = self._setup_for_training_recipe(training_recipe, kwargs)
+            args = self._setup_for_training_recipe(training_recipe, recipe_overrides, kwargs)
            entry_point = args["entry_point"]
            source_dir = args["source_dir"]
            hyperparameters = args["hyperparameters"]
            if image_uri is None:
-                image_uri = args["image_uri"]
+                image_uri = args["default_image_uri"]
            distribution = args["distribution"]
        elif entry_point is None:
            raise ValueError(
@@ -518,7 +529,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
        return init_params

    @classmethod
-    def _setup_for_training_recipe(cls, training_recipe, kwargs):
+    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
        """Performs training recipe specific setup and returns recipe specific args.

        Updates kwargs and returns a dictionary of args to use for estimator
@@ -528,28 +539,25 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
        Args:
            training_recipe (str): A recipe which is a local file path, a url or a
                sagemaker training recipe.
+            recipe_overrides (Dict): Dictionary specifying key values to override in the
+                training_recipe.
            kwargs (dict): Dictionary of args used for estimator initialization.
        Returns:
            dict containing arg values for estimator initialization and setup.

        """
+        if recipe_overrides is None:
+            recipe_overrides = dict()
        cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
        cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")

-        adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
-        _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
-        source_dir = os.path.join(cls.recipe_train_dir.name, "scripts")
-
-        model_type_to_script = {"llama_v3": "llama_pretrain.py"}
-
-        args = {"source_dir": source_dir}
-        local_recipe_path = os.path.join(source_dir, "recipe.yaml")
+        temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe").name
        if training_recipe.endswith(".yaml"):
            if os.path.isfile(training_recipe):
-                shutil.copy(training_recipe, local_recipe_path)
+                shutil.copy(training_recipe, temp_local_recipe)
            else:
                try:
-                    urlretrieve(training_recipe, local_recipe_path)
+                    urlretrieve(training_recipe, temp_local_recipe)
                except Exception as e:
                    raise ValueError(
                        f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
@@ -559,28 +567,27 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
            _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
            recipe = os.path.join(
                cls.recipe_launcher_dir.name,
-                "examples",
+                "recipes-collection",
                "recipes",
                "training",
                training_recipe + ".yaml",
            )
            if os.path.isfile(recipe):
-                shutil.copy(recipe, local_recipe_path)
+                shutil.copy(recipe, temp_local_recipe)
            else:
                raise ValueError(f"Recipe {training_recipe} not found.")

-        recipe = OmegaConf.load(local_recipe_path)
-
-        if "model" not in recipe:
-            raise ValueError("Supplied recipe does not contain required field model.")
-        if "model_type" not in recipe["model"]:
-            raise ValueError("Supplied recipe does not contain required field model_type.")
-        model_type = recipe["model"]["model_type"]
-        if model_type not in model_type_to_script:
-            raise ValueError(f"Model type {model_type} not supported")
-        args["model_type"] = model_type
-        args["entry_point"] = model_type_to_script[model_type]
-        args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}
+        recipe = OmegaConf.load(temp_local_recipe)
+
+        if "instance_type" not in kwargs:
+            raise ValueError("Must pass instance type to estimator when using training recipes.")
+        instance_type = kwargs["instance_type"].split(".")[1]
+        if instance_type.startswith(("p", "g")):
+            device_type = "gpu"
+        elif instance_type.startswith("trn"):
+            device_type = "trainium"
+        else:
+            device_type = "cpu"

        if "trainer" not in recipe:
            raise ValueError("Supplied recipe does not contain required field trainer.")
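As a side note on the hunk above, a minimal sketch (not part of the diff) of how the new instance-type parsing classifies a few sample instance types; the helper name and sample values are illustrative assumptions:

# Sketch only: mirrors the device_type parsing added in this hunk.
def _device_type_for(instance_type: str) -> str:
    family = instance_type.split(".")[1]  # e.g. "ml.p4d.24xlarge" -> "p4d"
    if family.startswith(("p", "g")):
        return "gpu"
    if family.startswith("trn"):
        return "trainium"
    return "cpu"

print(_device_type_for("ml.p4d.24xlarge"))   # gpu
print(_device_type_for("ml.trn1.32xlarge"))  # trainium
print(_device_type_for("ml.c5.xlarge"))      # cpu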
@@ -597,17 +604,32 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
            )
        kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

-        if "accelerator" not in recipe["trainer"]:
-            raise ValueError(
-                "Supplied recipe does not contain required field trainer -> accelerator."
-            )
-        accelerator = recipe["trainer"]["accelerator"]
-        if accelerator == "gpu":
-            # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
-            # to retrieve the image uri below before we go GA.
-            args["image_uri"] = (
-                "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
+        args = dict()
+        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
+        # to retrieve the image uri below before we go GA.
+        if device_type == "gpu":
+            adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
+            _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
+
+            model_type_to_entry = {
+                "llama_v3": ("llama", "llama_pretrain.py"),
+                "mistral": ("mistral", "mistral_pretrain.py"),
+                "mixtral": ("mixtral", "mixtral_pretrain.py"),
+            }
+
+            if "model" not in recipe:
+                raise ValueError("Supplied recipe does not contain required field model.")
+            if "model_type" not in recipe["model"]:
+                raise ValueError("Supplied recipe does not contain required field model_type.")
+            model_type = recipe["model"]["model_type"]
+            if model_type not in model_type_to_entry:
+                raise ValueError(f"Model type {model_type} not supported")
+
+            args["source_dir"] = os.path.join(
+                cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
            )
+            args["entry_point"] = model_type_to_entry[model_type][1]
+            args["default_image_uri"] = cls.SM_TRAINING_RECIPE_GPU_IMG
            smp_options = {
                "enabled": True,
                "parameters": {
@@ -618,26 +640,29 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
                "smdistributed": {"modelparallel": smp_options},
                "torch_distributed": {"enabled": True},
            }
+        elif device_type == "trainium":
+            _run_clone_command(cls.SM_NEURONX_DIST_REPO, cls.recipe_train_dir.name)
+            args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
+            args["entry_point"] = "training_orchestrator.py"
+            args["default_image_uri"] = cls.SM_NEURONX_DIST_IMG
+            args["distribution"] = {
+                "torch_distributed": {"enabled": True},
+            }
        else:
-            raise ValueError(f"Accelerator type {accelerator} not yet supported.")
-
-        try:
-            recipe["run"]["results_dir"] = "/opt/ml/model/"
-            recipe["exp_manager"]["exp_dir"] = "/opt/ml/model/"
-            recipe["exp_manager"]["explicit_log_dir"] = "/opt/ml/output/tensorboard"
-            recipe["exp_manager"]["checkpoint_dir"] = "/opt/ml/checkpoints"
-            recipe["model"]["data"]["train_dir"] = ["/opt/ml/input/data/train"]
-            recipe["model"]["data"]["val_dir"] = ["/opt/ml/input/data/val"]
-        except KeyError as e:
-            raise RuntimeError(
-                f"Error when trying to update recipe for sagemaker jobs with key {str(e)}."
+            raise ValueError(
+                f"Devices of type {device_type} are not supported with training recipes."
            )

+        recipe_overrides.setdefault("run", dict())["results_dir"] = "/opt/ml/model"
+        recipe_overrides.setdefault("exp_manager", dict())["exp_dir"] = "/opt/ml/model/"
+        recipe = OmegaConf.merge(recipe, recipe_overrides)
+
        if "container" in recipe and not recipe["container"]:
            logger.warning(
                "Ignoring container from training_recipe. Use image_uri arg for estimator."
            )

-        OmegaConf.save(config=recipe, f=local_recipe_path)
+        OmegaConf.save(config=recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
+        args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}

        return args
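As context for the override handling above, a minimal sketch (not part of the diff) of how OmegaConf.merge combines the loaded recipe with recipe_overrides; the keys and values are made up for illustration, and later arguments win on conflicting keys:

from omegaconf import OmegaConf

# Sketch only: illustrative keys and values.
recipe = OmegaConf.create({"run": {"results_dir": "/tmp/results"}, "trainer": {"max_steps": 50}})
overrides = {"run": {"results_dir": "/opt/ml/model"}, "trainer": {"max_steps": 100}}
merged = OmegaConf.merge(recipe, overrides)
print(merged.run.results_dir)    # /opt/ml/model
print(merged.trainer.max_steps)  # 100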