diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index ff38bcbde8..549645cbe2 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -125,6 +125,27 @@ def _register_custom_resolvers(): OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) +def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): + """Get the model base name and script for the training recipe.""" + + model_type_to_script = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), + } + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + return model_type_to_script[model_type][0], model_type_to_script[model_type][1] + + def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, @@ -140,24 +161,16 @@ def _configure_gpu_args( ) _run_clone_command_silent(adapter_repo, recipe_train_dir.name) - model_type_to_entry = { - "llama_v3": ("llama", "llama_pretrain.py"), - "mistral": ("mistral", "mistral_pretrain.py"), - "mixtral": ("mixtral", "mixtral_pretrain.py"), - } - if "model" not in recipe: raise ValueError("Supplied recipe does not contain required field model.") if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] - if model_type not in model_type_to_entry: - raise ValueError(f"Model type {model_type} not supported") - source_code.source_dir = os.path.join( - recipe_train_dir.name, "examples", model_type_to_entry[model_type][0] - ) - source_code.entry_script = model_type_to_entry[model_type][1] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) + source_code.entry_script = script gpu_image_cfg = training_recipes_cfg.get("gpu_image") if isinstance(gpu_image_cfg, str): diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 46c57581d1..8f300d09fd 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -95,6 +95,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): "llama_v3": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), } if "model" not in recipe: @@ -102,6 +103,12 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + if model_type not in model_type_to_script: raise ValueError(f"Model type {model_type} not supported") diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 66eafab4f0..f5f7ceb083 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -26,6 +26,7 @@ _load_recipes_cfg, _configure_gpu_args, _configure_trainium_args, + _get_trainining_recipe_gpu_model_name_and_script, ) from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute @@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute( assert mock_gpu_args.call_count == 0 assert mock_trainium_args.call_count == 0 assert args is None + + @pytest.mark.parametrize( + "test_case", + [ + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama_v3", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + ], + ) + def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( + model_type, script + ) + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"] diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 6076d44e90..34d3c6784b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -23,7 +23,10 @@ from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel -from sagemaker.pytorch.estimator import _get_training_recipe_image_uri +from sagemaker.pytorch.estimator import ( + _get_training_recipe_image_uri, + _get_training_recipe_gpu_script, +) from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings @@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session): assert pytorch.distribution == expected_distribution +@pytest.mark.parametrize( + "test_case", + [ + { + "script": "llama_pretrain.py", + "recipe": { + "model": { + "model_type": "llama_v3", + }, + }, + }, + { + "script": "mistral_pretrain.py", + "recipe": { + "model": { + "model_type": "mistral", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_llamav3", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_qwenv2", + }, + }, + }, + ], +) +@patch("shutil.copyfile") +def test_get_training_recipe_gpu_script(mock_copyfile, test_case): + script = test_case["script"] + recipe = test_case["recipe"] + mock_copyfile.return_value = None + + assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script + + def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session): container_log_level = '"logging.INFO"'