
Commit 560052f

schinmayee authored and pintaoz-aws committed
Feature: Support Neuron training recipes. (#1526)
1 parent 43977de commit 560052f
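For orientation, a hedged sketch of how the feature added here might be invoked once released (the recipe name, role ARN, and override values are illustrative, not taken from this commit):

    from sagemaker.pytorch import PyTorch

    # Sketch only: recipe name, role, and override values are hypothetical.
    estimator = PyTorch(
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        instance_type="ml.trn1.32xlarge",  # a "trn" family instance routes to the Neuron path
        training_recipe="llama/llama3_8b_pretrain",  # hypothetical recipe name
        recipe_overrides={"trainer": {"max_steps": 100}},  # merged into the recipe YAML
    )
    estimator.fit({"train": "s3://my-bucket/train"}, wait=False)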

1 file changed: +75 −50 lines

src/sagemaker/pytorch/estimator.py

+75 −50
@@ -54,6 +54,13 @@ class PyTorch(Framework):
     # to retrieve the image uri below before GA.
     SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
     SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
+    SM_TRAINING_RECIPE_GPU_IMG = (
+        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
+    )
+    SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
+    SM_NEURONX_DIST_IMG = (
+        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
+    )

     def __init__(
         self,
@@ -66,6 +73,7 @@ def __init__(
         distribution: Optional[Dict] = None,
         compiler_config: Optional[TrainingCompilerConfig] = None,
         training_recipe: Optional[str] = None,
+        recipe_overrides: Optional[Dict] = None,
         **kwargs,
     ):
         """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -262,6 +270,9 @@ def __init__(
                 a url to fetch, or a recipe provided by SageMaker
                 training.
+            recipe_overrides (Dict): Dictionary specifying key values to override in the
+                training_recipe.
+
             **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                 constructor.
@@ -280,12 +291,12 @@ def __init__(
                 logger.warning("Argument hyperparameters will be ignored with training recipe.")
             if distribution is not None:
                 logger.warning("Argument distribution will be ignored with training_recipe.")
-            args = self._setup_for_training_recipe(training_recipe, kwargs)
+            args = self._setup_for_training_recipe(training_recipe, recipe_overrides, kwargs)
             entry_point = args["entry_point"]
             source_dir = args["source_dir"]
             hyperparameters = args["hyperparameters"]
             if image_uri is None:
-                image_uri = args["image_uri"]
+                image_uri = args["default_image_uri"]
             distribution = args["distribution"]
         elif entry_point is None:
             raise ValueError(
@@ -518,7 +529,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
         return init_params

     @classmethod
-    def _setup_for_training_recipe(cls, training_recipe, kwargs):
+    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
         """Performs training recipe specific setup and returns recipe specific args.

         Updates kwargs and returns a dictionary of args to use for estimator
@@ -528,28 +539,25 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
         Args:
             training_recipe (str): A recipe which is a local file path, a url or a
                 sagemaker training recipe.
+            recipe_overrides (Dict): Dictionary specifying key values to override in the
+                training_recipe.
             kwargs (dict): Dictionary of args used for estimator initialization.
         Returns:
             dict containing arg values for estimator initialization and setup.

         """
+        if recipe_overrides is None:
+            recipe_overrides = dict()
         cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
         cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")

-        adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
-        _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
-        source_dir = os.path.join(cls.recipe_train_dir.name, "scripts")
-
-        model_type_to_script = {"llama_v3": "llama_pretrain.py"}
-
-        args = {"source_dir": source_dir}
-        local_recipe_path = os.path.join(source_dir, "recipe.yaml")
+        temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe").name
         if training_recipe.endswith(".yaml"):
             if os.path.isfile(training_recipe):
-                shutil.copy(training_recipe, local_recipe_path)
+                shutil.copy(training_recipe, temp_local_recipe)
             else:
                 try:
-                    urlretrieve(training_recipe, local_recipe_path)
+                    urlretrieve(training_recipe, temp_local_recipe)
                 except Exception as e:
                     raise ValueError(
                         f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
@@ -559,28 +567,27 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
             _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
             recipe = os.path.join(
                 cls.recipe_launcher_dir.name,
-                "examples",
+                "recipes-collection",
                 "recipes",
                 "training",
                 training_recipe + ".yaml",
             )
             if os.path.isfile(recipe):
-                shutil.copy(recipe, local_recipe_path)
+                shutil.copy(recipe, temp_local_recipe)
             else:
                 raise ValueError(f"Recipe {training_recipe} not found.")

-        recipe = OmegaConf.load(local_recipe_path)
-
-        if "model" not in recipe:
-            raise ValueError("Supplied recipe does not contain required field model.")
-        if "model_type" not in recipe["model"]:
-            raise ValueError("Supplied recipe does not contain required field model_type.")
-        model_type = recipe["model"]["model_type"]
-        if model_type not in model_type_to_script:
-            raise ValueError(f"Model type {model_type} not supported")
-        args["model_type"] = model_type
-        args["entry_point"] = model_type_to_script[model_type]
-        args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}
+        recipe = OmegaConf.load(temp_local_recipe)
+
+        if "instance_type" not in kwargs:
+            raise ValueError("Must pass instance type to estimator when using training recipes.")
+        instance_type = kwargs["instance_type"].split(".")[1]
+        if instance_type.startswith(("p", "g")):
+            device_type = "gpu"
+        elif instance_type.startswith("trn"):
+            device_type = "trainium"
+        else:
+            device_type = "cpu"

         if "trainer" not in recipe:
             raise ValueError("Supplied recipe does not contain required field trainer.")
@@ -597,17 +604,32 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
             )
         kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

-        if "accelerator" not in recipe["trainer"]:
-            raise ValueError(
-                "Supplied recipe does not contain required field trainer -> accelerator."
-            )
-        accelerator = recipe["trainer"]["accelerator"]
-        if accelerator == "gpu":
-            # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
-            # to retrieve the image uri below before we go GA.
-            args["image_uri"] = (
-                "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
+        args = dict()
+        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
+        # to retrieve the image uri below before we go GA.
+        if device_type == "gpu":
+            adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
+            _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
+
+            model_type_to_entry = {
+                "llama_v3": ("llama", "llama_pretrain.py"),
+                "mistral": ("mistral", "mistral_pretrain.py"),
+                "mixtral": ("mixtral", "mixtral_pretrain.py"),
+            }
+
+            if "model" not in recipe:
+                raise ValueError("Supplied recipe does not contain required field model.")
+            if "model_type" not in recipe["model"]:
+                raise ValueError("Supplied recipe does not contain required field model_type.")
+            model_type = recipe["model"]["model_type"]
+            if model_type not in model_type_to_entry:
+                raise ValueError(f"Model type {model_type} not supported")
+
+            args["source_dir"] = os.path.join(
+                cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
             )
+            args["entry_point"] = model_type_to_entry[model_type][1]
+            args["default_image_uri"] = cls.SM_TRAINING_RECIPE_GPU_IMG
             smp_options = {
                 "enabled": True,
                 "parameters": {
@@ -618,26 +640,29 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
                 "smdistributed": {"modelparallel": smp_options},
                 "torch_distributed": {"enabled": True},
             }
+        elif device_type == "trainium":
+            _run_clone_command(cls.SM_NEURONX_DIST_REPO, cls.recipe_train_dir.name)
+            args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
+            args["entry_point"] = "training_orchestrator.py"
+            args["default_image_uri"] = cls.SM_NEURONX_DIST_IMG
+            args["distribution"] = {
+                "torch_distributed": {"enabled": True},
+            }
         else:
-            raise ValueError(f"Accelerator type {accelerator} not yet supported.")
-
-        try:
-            recipe["run"]["results_dir"] = "/opt/ml/model/"
-            recipe["exp_manager"]["exp_dir"] = "/opt/ml/model/"
-            recipe["exp_manager"]["explicit_log_dir"] = "/opt/ml/output/tensorboard"
-            recipe["exp_manager"]["checkpoint_dir"] = "/opt/ml/checkpoints"
-            recipe["model"]["data"]["train_dir"] = ["/opt/ml/input/data/train"]
-            recipe["model"]["data"]["val_dir"] = ["/opt/ml/input/data/val"]
-        except KeyError as e:
-            raise RuntimeError(
-                f"Error when trying to update recipe for sagemaker jobs with key {str(e)}."
+            raise ValueError(
+                f"Devices of type {device_type} are not supported with training recipes."
             )

+        recipe_overrides.setdefault("run", dict())["results_dir"] = "/opt/ml/model"
+        recipe_overrides.setdefault("exp_manager", dict())["exp_dir"] = "/opt/ml/model/"
+        recipe = OmegaConf.merge(recipe, recipe_overrides)
+
         if "container" in recipe and not recipe["container"]:
             logger.warning(
                 "Ignoring container from training_recipe. Use image_uri arg for estimator."
             )

-        OmegaConf.save(config=recipe, f=local_recipe_path)
+        OmegaConf.save(config=recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
+        args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}

         return args
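The override plumbing above leans on OmegaConf's deep merge, where later arguments win on conflicting keys. A minimal standalone sketch of those semantics (values are illustrative):

    from omegaconf import OmegaConf

    recipe = OmegaConf.create({"run": {"name": "demo"}, "trainer": {"max_steps": 1000}})
    overrides = {"run": {"results_dir": "/opt/ml/model"}, "trainer": {"max_steps": 50}}
    merged = OmegaConf.merge(recipe, overrides)

    assert merged.run.name == "demo"                  # untouched keys survive
    assert merged.run.results_dir == "/opt/ml/model"  # new keys are added
    assert merged.trainer.max_steps == 50             # later args win on conflicts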
