14
14
from __future__ import absolute_import
15
15
16
16
import logging
17
+ import os
18
+ import shutil
19
+ import tempfile
17
20
from typing import Union , Optional , Dict
21
+ from urllib .request import urlretrieve
18
22
23
+ from omegaconf import OmegaConf
19
24
from packaging .version import Version
20
25
21
26
from sagemaker .estimator import Framework , EstimatorBase
27
32
validate_distribution ,
28
33
profiler_config_deprecation_warning ,
29
34
)
35
+ from sagemaker .git_utils import _run_clone_command
30
36
from sagemaker .pytorch import defaults
31
37
from sagemaker .pytorch .model import PyTorchModel
32
38
from sagemaker .pytorch .training_compiler .config import TrainingCompilerConfig
@@ -44,16 +50,22 @@ class PyTorch(Framework):
44
50
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
45
51
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
46
52
53
+ # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
54
+ # to retrieve the image uri below before GA.
55
+ SM_ADAPTER_REPO = "[email protected] :aws/private-sagemaker-training-adapter-for-nemo-staging.git"
56
+ SM_LAUNCHER_REPO = "[email protected] :aws/private-sagemaker-training-launcher-staging.git"
57
+
47
58
def __init__ (
48
59
self ,
49
- entry_point : Union [str , PipelineVariable ],
60
+ entry_point : Optional [ Union [str , PipelineVariable ]] = None ,
50
61
framework_version : Optional [str ] = None ,
51
62
py_version : Optional [str ] = None ,
52
63
source_dir : Optional [Union [str , PipelineVariable ]] = None ,
53
64
hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
54
65
image_uri : Optional [Union [str , PipelineVariable ]] = None ,
55
66
distribution : Optional [Dict ] = None ,
56
67
compiler_config : Optional [TrainingCompilerConfig ] = None ,
68
+ training_recipe : Optional [str ] = None ,
57
69
** kwargs ,
58
70
):
59
71
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -246,6 +258,10 @@ def __init__(
246
258
compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
247
259
Configures SageMaker Training Compiler to accelerate training.
248
260
261
+ training_recipe (str): Training recipe to use. This is a local file path,
262
+ a url to fetch, or a recipe provided by Sagemaker
263
+ training.
264
+
249
265
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
250
266
constructor.
251
267
@@ -255,6 +271,26 @@ def __init__(
255
271
:class:`~sagemaker.estimator.Framework` and
256
272
:class:`~sagemaker.estimator.EstimatorBase`.
257
273
"""
274
+ if training_recipe is not None :
275
+ if entry_point is not None :
276
+ logger .warning ("Argument entry_point will be ignored with training_recipe." )
277
+ if source_dir is not None :
278
+ logger .warning ("Argument source_dir will be ignored with training_recipe." )
279
+ if hyperparameters is not None :
280
+ logger .warning ("Argument hyperparameters will be ignored with training recipe." )
281
+ if distribution is not None :
282
+ logger .warning ("Argument distribution will be ignored with training_recipe." )
283
+ args = self ._setup_for_training_recipe (training_recipe , kwargs )
284
+ entry_point = args ["entry_point" ]
285
+ source_dir = args ["source_dir" ]
286
+ hyperparameters = args ["hyperparameters" ]
287
+ if image_uri is None :
288
+ image_uri = args ["image_uri" ]
289
+ distribution = args ["distribution" ]
290
+ elif entry_point is None :
291
+ raise ValueError (
292
+ "Argument entry_point must be set when training_recipe is not provided"
293
+ )
258
294
validate_version_or_image_args (framework_version , py_version , image_uri )
259
295
if py_version == "py2" :
260
296
logger .warning (
@@ -480,3 +516,128 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
480
516
)
481
517
482
518
return init_params
519
+
520
+ @classmethod
521
+ def _setup_for_training_recipe (cls , training_recipe , kwargs ):
522
+ """Performs training recipe specific setup and returns recipe specific args.
523
+
524
+ Updates kwargs and returns a dictionary of args to use for estimator
525
+ initialization and setup when using a training recipe. Updates the paths in
526
+ the recipe for Sagemaker Jobs environment.
527
+
528
+ Args:
529
+ training_recipe (str): A recipe which is a local file path, a url or a
530
+ sagemaker training recipe.
531
+ kwargs (dict): Dictionary of args used for estimator initializaiton.
532
+ Returns:
533
+ dict containing arg values for estimator initialization and setup.
534
+
535
+ """
536
+ cls .recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
537
+ cls .recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
538
+
539
+ adapter_repo = os .environ .get ("training_adapter_git" , None ) or cls .SM_ADAPTER_REPO
540
+ _run_clone_command (adapter_repo , cls .recipe_train_dir .name )
541
+ source_dir = os .path .join (cls .recipe_train_dir .name , "scripts" )
542
+
543
+ model_type_to_script = {"llama_v3" : "llama_pretrain.py" }
544
+
545
+ args = {"source_dir" : source_dir }
546
+ local_recipe_path = os .path .join (source_dir , "recipe.yaml" )
547
+ if training_recipe .endswith (".yaml" ):
548
+ if os .path .isfile (training_recipe ):
549
+ shutil .copy (training_recipe , local_recipe_path )
550
+ else :
551
+ try :
552
+ urlretrieve (training_recipe , local_recipe_path )
553
+ except Exception as e :
554
+ raise ValueError (
555
+ f"Could not fetch the provided recipe { training_recipe } : exception { str (e )} "
556
+ )
557
+ else :
558
+ launcher_repo = os .environ .get ("training_launcher_git" , None ) or cls .SM_LAUNCHER_REPO
559
+ _run_clone_command (launcher_repo , cls .recipe_launcher_dir .name )
560
+ recipe = os .path .join (
561
+ cls .recipe_launcher_dir .name ,
562
+ "examples" ,
563
+ "recipes" ,
564
+ "training" ,
565
+ training_recipe + ".yaml" ,
566
+ )
567
+ if os .path .isfile (recipe ):
568
+ shutil .copy (recipe , local_recipe_path )
569
+ else :
570
+ raise ValueError (f"Recipe { training_recipe } not found." )
571
+
572
+ recipe = OmegaConf .load (local_recipe_path )
573
+
574
+ if "model" not in recipe :
575
+ raise ValueError ("Supplied recipe does not contain required field model." )
576
+ if "model_type" not in recipe ["model" ]:
577
+ raise ValueError ("Supplied recipe does not contain required field model_type." )
578
+ model_type = recipe ["model" ]["model_type" ]
579
+ if model_type not in model_type_to_script :
580
+ raise ValueError (f"Model type { model_type } not supported" )
581
+ args ["model_type" ] = model_type
582
+ args ["entry_point" ] = model_type_to_script [model_type ]
583
+ args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
584
+
585
+ if "trainer" not in recipe :
586
+ raise ValueError ("Supplied recipe does not contain required field trainer." )
587
+ if "instance_count" in kwargs and "num_nodes" in recipe ["trainer" ]:
588
+ logger .warning (
589
+ "Using instance_count argument to estimator to set number "
590
+ " of nodes. Ignoring trainer -> num_nodes in recipe."
591
+ )
592
+ if "instance_count" not in kwargs :
593
+ if "num_nodes" not in recipe ["trainer" ]:
594
+ raise ValueError (
595
+ "Must set either instance_count argument for estimator or"
596
+ "set trainer -> num_nodes in recipe."
597
+ )
598
+ kwargs ["instance_count" ] = recipe ["trainer" ]["num_nodes" ]
599
+
600
+ if "accelerator" not in recipe ["trainer" ]:
601
+ raise ValueError (
602
+ "Supplied recipe does not contain required field trainer -> accelerator."
603
+ )
604
+ accelerator = recipe ["trainer" ]["accelerator" ]
605
+ if accelerator == "gpu" :
606
+ # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
607
+ # to retrieve the image uri below before we go GA.
608
+ args ["image_uri" ] = (
609
+ "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
610
+ )
611
+ smp_options = {
612
+ "enabled" : True ,
613
+ "parameters" : {
614
+ "placement_strategy" : "cluster" ,
615
+ },
616
+ }
617
+ args ["distribution" ] = {
618
+ "smdistributed" : {"modelparallel" : smp_options },
619
+ "torch_distributed" : {"enabled" : True },
620
+ }
621
+ else :
622
+ raise ValueError (f"Accelerator type { accelerator } not yet supported." )
623
+
624
+ try :
625
+ recipe ["run" ]["results_dir" ] = "/opt/ml/model/"
626
+ recipe ["exp_manager" ]["exp_dir" ] = "/opt/ml/model/"
627
+ recipe ["exp_manager" ]["explicit_log_dir" ] = "/opt/ml/output/tensorboard"
628
+ recipe ["exp_manager" ]["checkpoint_dir" ] = "/opt/ml/checkpoints"
629
+ recipe ["model" ]["data" ]["train_dir" ] = ["/opt/ml/input/data/train" ]
630
+ recipe ["model" ]["data" ]["val_dir" ] = ["/opt/ml/input/data/val" ]
631
+ except KeyError as e :
632
+ raise RuntimeError (
633
+ f"Error when trying to update recipe for sagemaker jobs with key { str (e )} ."
634
+ )
635
+
636
+ if "container" in recipe and not recipe ["container" ]:
637
+ logger .warning (
638
+ "Ignoring container from training_recipe. Use image_uri arg for estimator."
639
+ )
640
+
641
+ OmegaConf .save (config = recipe , f = local_recipe_path )
642
+
643
+ return args
0 commit comments