|
23 | 23 | from sagemaker import image_uris
|
24 | 24 | from sagemaker.pytorch import defaults
|
25 | 25 | from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
|
26 |
| -from sagemaker.pytorch.estimator import _get_training_recipe_image_uri |
| 26 | +from sagemaker.pytorch.estimator import ( |
| 27 | + _get_training_recipe_image_uri, |
| 28 | + _get_training_recipe_gpu_script, |
| 29 | +) |
27 | 30 | from sagemaker.instance_group import InstanceGroup
|
28 | 31 | from sagemaker.session_settings import SessionSettings
|
29 | 32 |
|
@@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session):
|
1049 | 1052 | assert pytorch.distribution == expected_distribution
|
1050 | 1053 |
|
1051 | 1054 |
|
| 1055 | +@pytest.mark.parametrize( |
| 1056 | + "test_case", |
| 1057 | + [ |
| 1058 | + { |
| 1059 | + "script": "llama_pretrain.py", |
| 1060 | + "recipe": { |
| 1061 | + "model": { |
| 1062 | + "model_type": "llama_v3", |
| 1063 | + }, |
| 1064 | + }, |
| 1065 | + }, |
| 1066 | + { |
| 1067 | + "script": "mistral_pretrain.py", |
| 1068 | + "recipe": { |
| 1069 | + "model": { |
| 1070 | + "model_type": "mistral", |
| 1071 | + }, |
| 1072 | + }, |
| 1073 | + }, |
| 1074 | + { |
| 1075 | + "script": "deepseek_pretrain.py", |
| 1076 | + "recipe": { |
| 1077 | + "model": { |
| 1078 | + "model_type": "deepseek_llamav3", |
| 1079 | + }, |
| 1080 | + }, |
| 1081 | + }, |
| 1082 | + { |
| 1083 | + "script": "deepseek_pretrain.py", |
| 1084 | + "recipe": { |
| 1085 | + "model": { |
| 1086 | + "model_type": "deepseek_qwenv2", |
| 1087 | + }, |
| 1088 | + }, |
| 1089 | + }, |
| 1090 | + ], |
| 1091 | +) |
| 1092 | +@patch("shutil.copyfile") |
| 1093 | +def test_get_training_recipe_gpu_script(mock_copyfile, test_case): |
| 1094 | + script = test_case["script"] |
| 1095 | + recipe = test_case["recipe"] |
| 1096 | + mock_copyfile.return_value = None |
| 1097 | + |
| 1098 | + assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script |
| 1099 | + |
| 1100 | + |
1052 | 1101 | def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
|
1053 | 1102 | container_log_level = '"logging.INFO"'
|
1054 | 1103 |
|
|
0 commit comments