
Commit ae1146c

feat: Add support for deepseek recipes (#5011)
* feat: Add support for deepseek recipes
* pylint
* add unit test
1 parent 6d2dfa0 commit ae1146c

4 files changed: +117 −13 lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

+25-12
@@ -125,6 +125,27 @@ def _register_custom_resolvers():
     OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
 
 
+def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
+    """Get the model base name and script for the training recipe."""
+
+    model_type_to_script = {
+        "llama_v3": ("llama", "llama_pretrain.py"),
+        "mistral": ("mistral", "mistral_pretrain.py"),
+        "mixtral": ("mixtral", "mixtral_pretrain.py"),
+        "deepseek": ("deepseek", "deepseek_pretrain.py"),
+    }
+
+    for key in model_type_to_script:
+        if model_type.startswith(key):
+            model_type = key
+            break
+
+    if model_type not in model_type_to_script:
+        raise ValueError(f"Model type {model_type} not supported")
+
+    return model_type_to_script[model_type][0], model_type_to_script[model_type][1]
+
+
 def _configure_gpu_args(
     training_recipes_cfg: Dict[str, Any],
     region_name: str,
@@ -140,24 +161,16 @@ def _configure_gpu_args(
     )
     _run_clone_command_silent(adapter_repo, recipe_train_dir.name)
 
-    model_type_to_entry = {
-        "llama_v3": ("llama", "llama_pretrain.py"),
-        "mistral": ("mistral", "mistral_pretrain.py"),
-        "mixtral": ("mixtral", "mixtral_pretrain.py"),
-    }
-
     if "model" not in recipe:
         raise ValueError("Supplied recipe does not contain required field model.")
     if "model_type" not in recipe["model"]:
         raise ValueError("Supplied recipe does not contain required field model_type.")
     model_type = recipe["model"]["model_type"]
-    if model_type not in model_type_to_entry:
-        raise ValueError(f"Model type {model_type} not supported")
 
-    source_code.source_dir = os.path.join(
-        recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
-    )
-    source_code.entry_script = model_type_to_entry[model_type][1]
+    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)
+
+    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
+    source_code.entry_script = script
 
     gpu_image_cfg = training_recipes_cfg.get("gpu_image")
     if isinstance(gpu_image_cfg, str):

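The new helper resolves a recipe's model_type by prefix, so any deepseek variant maps onto the single "deepseek" entry and its shared pretraining script. A minimal illustration of that behavior (a sketch for context, not part of the commit):

from sagemaker.modules.train.sm_recipes.utils import (
    _get_trainining_recipe_gpu_model_name_and_script,
)

# Prefix matching: every model_type that starts with a supported key resolves to that key's entry.
print(_get_trainining_recipe_gpu_model_name_and_script("deepseek_llamav3"))  # ("deepseek", "deepseek_pretrain.py")
print(_get_trainining_recipe_gpu_model_name_and_script("deepseek_qwenv2"))   # ("deepseek", "deepseek_pretrain.py")
print(_get_trainining_recipe_gpu_model_name_and_script("llama_v3"))          # ("llama", "llama_pretrain.py")
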
src/sagemaker/pytorch/estimator.py

+7
@@ -95,13 +95,20 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
         "llama_v3": ("llama", "llama_pretrain.py"),
         "mistral": ("mistral", "mistral_pretrain.py"),
         "mixtral": ("mixtral", "mixtral_pretrain.py"),
+        "deepseek": ("deepseek", "deepseek_pretrain.py"),
     }
 
     if "model" not in recipe:
         raise ValueError("Supplied recipe does not contain required field model.")
     if "model_type" not in recipe["model"]:
         raise ValueError("Supplied recipe does not contain required field model_type.")
     model_type = recipe["model"]["model_type"]
+
+    for key in model_type_to_script:
+        if model_type.startswith(key):
+            model_type = key
+            break
+
     if model_type not in model_type_to_script:
         raise ValueError(f"Model type {model_type} not supported")

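The estimator keeps its own copy of the model_type mapping rather than calling the new utils helper, but the added prefix-normalization step has the same effect. A short sketch of that step, using a recipe dict shaped like the ones in the new tests (illustration only, with the mapping inlined from the diff above):

recipe = {"model": {"model_type": "deepseek_qwenv2"}}

model_type_to_script = {
    "llama_v3": ("llama", "llama_pretrain.py"),
    "mistral": ("mistral", "mistral_pretrain.py"),
    "mixtral": ("mixtral", "mixtral_pretrain.py"),
    "deepseek": ("deepseek", "deepseek_pretrain.py"),
}

# Normalize the recipe's model_type to a supported key, exactly as the added loop does.
model_type = recipe["model"]["model_type"]
for key in model_type_to_script:
    if model_type.startswith(key):
        model_type = key  # "deepseek_qwenv2" -> "deepseek"
        break

assert model_type_to_script[model_type] == ("deepseek", "deepseek_pretrain.py")
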
tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

+35
@@ -26,6 +26,7 @@
     _load_recipes_cfg,
     _configure_gpu_args,
     _configure_trainium_args,
+    _get_trainining_recipe_gpu_model_name_and_script,
 )
 from sagemaker.modules.utils import _run_clone_command_silent
 from sagemaker.modules.configs import Compute
@@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute(
         assert mock_gpu_args.call_count == 0
         assert mock_trainium_args.call_count == 0
         assert args is None
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "model_type": "llama_v3",
+            "script": "llama_pretrain.py",
+            "model_base_name": "llama",
+        },
+        {
+            "model_type": "mistral",
+            "script": "mistral_pretrain.py",
+            "model_base_name": "mistral",
+        },
+        {
+            "model_type": "deepseek_llamav3",
+            "script": "deepseek_pretrain.py",
+            "model_base_name": "deepseek",
+        },
+        {
+            "model_type": "deepseek_qwenv2",
+            "script": "deepseek_pretrain.py",
+            "model_base_name": "deepseek",
+        },
+    ],
+)
+def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
+    model_type = test_case["model_type"]
+    script = test_case["script"]
+    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
+        model_type
+    )
+    assert model_base_name == test_case["model_base_name"]
+    assert script == test_case["script"]

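The parametrized cases above cover only supported model types; the unsupported-type branch of the helper is not exercised. A possible extra case, sketched with a hypothetical model type that matches no supported prefix (not included in the commit):

import pytest

from sagemaker.modules.train.sm_recipes.utils import (
    _get_trainining_recipe_gpu_model_name_and_script,
)


def test_get_trainining_recipe_gpu_model_name_and_script_unsupported():
    # "gpt_neox" stands in for any model_type outside the supported mapping.
    with pytest.raises(ValueError, match="not supported"):
        _get_trainining_recipe_gpu_model_name_and_script("gpt_neox")
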
tests/unit/test_pytorch.py

+50-1
@@ -23,7 +23,10 @@
 from sagemaker import image_uris
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
-from sagemaker.pytorch.estimator import _get_training_recipe_image_uri
+from sagemaker.pytorch.estimator import (
+    _get_training_recipe_image_uri,
+    _get_training_recipe_gpu_script,
+)
 from sagemaker.instance_group import InstanceGroup
 from sagemaker.session_settings import SessionSettings
@@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session):
     assert pytorch.distribution == expected_distribution
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "script": "llama_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "llama_v3",
+                },
+            },
+        },
+        {
+            "script": "mistral_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "mistral",
+                },
+            },
+        },
+        {
+            "script": "deepseek_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "deepseek_llamav3",
+                },
+            },
+        },
+        {
+            "script": "deepseek_pretrain.py",
+            "recipe": {
+                "model": {
+                    "model_type": "deepseek_qwenv2",
+                },
+            },
+        },
+    ],
+)
+@patch("shutil.copyfile")
+def test_get_training_recipe_gpu_script(mock_copyfile, test_case):
+    script = test_case["script"]
+    recipe = test_case["recipe"]
+    mock_copyfile.return_value = None
+
+    assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script
+
+
 def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
     container_log_level = '"logging.INFO"'

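End to end, the change means a GPU training recipe whose model.model_type is a deepseek variant is accepted by the PyTorch estimator's training-recipe path instead of raising "Model type ... not supported". A rough usage sketch in the spirit of the recipe tests above; the recipe name, image URI, and role ARN are placeholders, not values taken from this commit:

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    training_recipe="training/deepseek/<recipe-name>",  # placeholder recipe identifier
    role="arn:aws:iam::111122223333:role/ExampleRole",  # placeholder role ARN
    image_uri="<training-image-uri>",  # placeholder; normally resolved from the recipe config
    instance_count=1,
    instance_type="ml.p5.48xlarge",
)
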
0 commit comments