Skip to content

Commit 45b89cc

Browse files
schinmayeepintaoz-aws
authored and committed
Feature: Support GPU training recipes with Sagemaker Python SDK (#1516)
* v0 estimator for launching kandinsky training * code cleanup * option to over-ride git repos for kandinsky for testing purposes * update dependencies * update comment * formatting fixes * style fixes * code cleanup * Add warning messages for ignored arguments * cleanup, address comments * fix * clone launcher repo only if necessary * add a cleanup method to call after fit * fix docstring * fix warning * cleanup update * fix * code style fix * rename cleanup method for clarity * missed change * move cleanup to when object is destroyed * add unit tests * formatting fix * removing tests which don't work as recipe repos are private * removing tests which don't work as recipe repos are private * resolve comments * resolve comments
1 parent cff8216 commit 45b89cc

File tree

1 file changed

+162
-1
lines changed

1 file changed

+162
-1
lines changed

src/sagemaker/pytorch/estimator.py

+162-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
import os
18+
import shutil
19+
import tempfile
1720
from typing import Union, Optional, Dict
21+
from urllib.request import urlretrieve
1822

23+
from omegaconf import OmegaConf
1924
from packaging.version import Version
2025

2126
from sagemaker.estimator import Framework, EstimatorBase
@@ -27,6 +32,7 @@
2732
validate_distribution,
2833
profiler_config_deprecation_warning,
2934
)
35+
from sagemaker.git_utils import _run_clone_command
3036
from sagemaker.pytorch import defaults
3137
from sagemaker.pytorch.model import PyTorchModel
3238
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
@@ -44,16 +50,22 @@ class PyTorch(Framework):
4450
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
4551
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4652

53+
# [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
54+
# to retrieve the image uri below before GA.
55+
SM_ADAPTER_REPO = "[email protected]:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
56+
SM_LAUNCHER_REPO = "[email protected]:aws/private-sagemaker-training-launcher-staging.git"
57+
4758
def __init__(
4859
self,
49-
entry_point: Union[str, PipelineVariable],
60+
entry_point: Optional[Union[str, PipelineVariable]] = None,
5061
framework_version: Optional[str] = None,
5162
py_version: Optional[str] = None,
5263
source_dir: Optional[Union[str, PipelineVariable]] = None,
5364
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
5465
image_uri: Optional[Union[str, PipelineVariable]] = None,
5566
distribution: Optional[Dict] = None,
5667
compiler_config: Optional[TrainingCompilerConfig] = None,
68+
training_recipe: Optional[str] = None,
5769
**kwargs,
5870
):
5971
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -246,6 +258,10 @@ def __init__(
246258
compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
247259
Configures SageMaker Training Compiler to accelerate training.
248260
261+
training_recipe (str): Training recipe to use. This is a local file path,
262+
a url to fetch, or a recipe provided by Sagemaker
263+
training.
264+
249265
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
250266
constructor.
251267
@@ -255,6 +271,26 @@ def __init__(
255271
:class:`~sagemaker.estimator.Framework` and
256272
:class:`~sagemaker.estimator.EstimatorBase`.
257273
"""
274+
if training_recipe is not None:
275+
if entry_point is not None:
276+
logger.warning("Argument entry_point will be ignored with training_recipe.")
277+
if source_dir is not None:
278+
logger.warning("Argument source_dir will be ignored with training_recipe.")
279+
if hyperparameters is not None:
280+
logger.warning("Argument hyperparameters will be ignored with training recipe.")
281+
if distribution is not None:
282+
logger.warning("Argument distribution will be ignored with training_recipe.")
283+
args = self._setup_for_training_recipe(training_recipe, kwargs)
284+
entry_point = args["entry_point"]
285+
source_dir = args["source_dir"]
286+
hyperparameters = args["hyperparameters"]
287+
if image_uri is None:
288+
image_uri = args["image_uri"]
289+
distribution = args["distribution"]
290+
elif entry_point is None:
291+
raise ValueError(
292+
"Argument entry_point must be set when training_recipe is not provided"
293+
)
258294
validate_version_or_image_args(framework_version, py_version, image_uri)
259295
if py_version == "py2":
260296
logger.warning(
@@ -480,3 +516,128 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
480516
)
481517

482518
return init_params
519+
520+
@classmethod
def _setup_for_training_recipe(cls, training_recipe, kwargs):
    """Performs training recipe specific setup and returns recipe specific args.

    Updates kwargs and returns a dictionary of args to use for estimator
    initialization and setup when using a training recipe. Updates the paths in
    the recipe for Sagemaker Jobs environment.

    Args:
        training_recipe (str): A recipe which is a local file path, a url or a
            sagemaker training recipe.
        kwargs (dict): Dictionary of args used for estimator initialization.
            May be updated in place: "instance_count" is filled in from the
            recipe's trainer -> num_nodes when the caller did not supply it.

    Returns:
        dict containing arg values for estimator initialization and setup:
        "source_dir", "model_type", "entry_point", "hyperparameters",
        "image_uri" and "distribution".

    Raises:
        ValueError: If the recipe cannot be fetched or found, is missing a
            required field, or names an unsupported model type or accelerator.
        RuntimeError: If the recipe lacks the keys needed to redirect its
            paths to the SageMaker Jobs directory layout.
    """
    # Temporary working directories are stored on the class so they outlive
    # this call (cleanup happens later, when the estimator is done with them).
    cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
    cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")

    # The adapter repo supplies the training scripts; the env var lets tests
    # point at a fork instead of the default repo.
    adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
    _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
    source_dir = os.path.join(cls.recipe_train_dir.name, "scripts")

    # Maps a recipe's model -> model_type value to its entry-point script.
    model_type_to_script = {"llama_v3": "llama_pretrain.py"}

    args = {"source_dir": source_dir}
    local_recipe_path = os.path.join(source_dir, "recipe.yaml")
    if training_recipe.endswith(".yaml"):
        # Recipe supplied directly as a yaml file: a local path or a url.
        if os.path.isfile(training_recipe):
            shutil.copy(training_recipe, local_recipe_path)
        else:
            try:
                urlretrieve(training_recipe, local_recipe_path)
            except Exception as e:
                raise ValueError(
                    f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                ) from e
    else:
        # Named recipe: resolve it inside the launcher repo's bundled recipes.
        launcher_repo = os.environ.get("training_launcher_git", None) or cls.SM_LAUNCHER_REPO
        _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
        recipe = os.path.join(
            cls.recipe_launcher_dir.name,
            "examples",
            "recipes",
            "training",
            training_recipe + ".yaml",
        )
        if os.path.isfile(recipe):
            shutil.copy(recipe, local_recipe_path)
        else:
            raise ValueError(f"Recipe {training_recipe} not found.")

    recipe = OmegaConf.load(local_recipe_path)

    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    if "model_type" not in recipe["model"]:
        raise ValueError("Supplied recipe does not contain required field model_type.")
    model_type = recipe["model"]["model_type"]
    if model_type not in model_type_to_script:
        raise ValueError(f"Model type {model_type} not supported")
    args["model_type"] = model_type
    args["entry_point"] = model_type_to_script[model_type]
    # The entry-point script locates the staged recipe via these config args.
    args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}

    if "trainer" not in recipe:
        raise ValueError("Supplied recipe does not contain required field trainer.")
    # The estimator's instance_count argument takes precedence over the
    # recipe's trainer -> num_nodes; the recipe value is only a fallback.
    if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]:
        logger.warning(
            "Using instance_count argument to estimator to set number "
            "of nodes. Ignoring trainer -> num_nodes in recipe."
        )
    if "instance_count" not in kwargs:
        if "num_nodes" not in recipe["trainer"]:
            raise ValueError(
                "Must set either instance_count argument for estimator or "
                "set trainer -> num_nodes in recipe."
            )
        kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

    if "accelerator" not in recipe["trainer"]:
        raise ValueError(
            "Supplied recipe does not contain required field trainer -> accelerator."
        )
    accelerator = recipe["trainer"]["accelerator"]
    if accelerator == "gpu":
        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
        # to retrieve the image uri below before we go GA.
        args["image_uri"] = (
            "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
        )
        # GPU recipes run with SageMaker model parallel plus torch distributed.
        smp_options = {
            "enabled": True,
            "parameters": {
                "placement_strategy": "cluster",
            },
        }
        args["distribution"] = {
            "smdistributed": {"modelparallel": smp_options},
            "torch_distributed": {"enabled": True},
        }
    else:
        raise ValueError(f"Accelerator type {accelerator} not yet supported.")

    # Rewrite the recipe's input/output paths to the fixed directory layout
    # SageMaker training jobs mount inside the container.
    try:
        recipe["run"]["results_dir"] = "/opt/ml/model/"
        recipe["exp_manager"]["exp_dir"] = "/opt/ml/model/"
        recipe["exp_manager"]["explicit_log_dir"] = "/opt/ml/output/tensorboard"
        recipe["exp_manager"]["checkpoint_dir"] = "/opt/ml/checkpoints"
        recipe["model"]["data"]["train_dir"] = ["/opt/ml/input/data/train"]
        recipe["model"]["data"]["val_dir"] = ["/opt/ml/input/data/val"]
    except KeyError as e:
        raise RuntimeError(
            f"Error when trying to update recipe for sagemaker jobs with key {str(e)}."
        ) from e

    if "container" in recipe and not recipe["container"]:
        logger.warning(
            "Ignoring container from training_recipe. Use image_uri arg for estimator."
        )

    # Persist the rewritten recipe next to the training scripts so it is
    # uploaded with source_dir.
    OmegaConf.save(config=recipe, f=local_recipe_path)

    return args

0 commit comments

Comments
 (0)