Skip to content

Commit 0eb8359

Browse files
committed
add unit test
1 parent 0bd5104 commit 0eb8359

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

tests/unit/test_pytorch.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from sagemaker import image_uris
2424
from sagemaker.pytorch import defaults
2525
from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
26-
from sagemaker.pytorch.estimator import _get_training_recipe_image_uri
26+
from sagemaker.pytorch.estimator import (
27+
_get_training_recipe_image_uri,
28+
_get_training_recipe_gpu_script,
29+
)
2730
from sagemaker.instance_group import InstanceGroup
2831
from sagemaker.session_settings import SessionSettings
2932

@@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session):
10491052
assert pytorch.distribution == expected_distribution
10501053

10511054

1055+
@pytest.mark.parametrize(
1056+
"test_case",
1057+
[
1058+
{
1059+
"script": "llama_pretrain.py",
1060+
"recipe": {
1061+
"model": {
1062+
"model_type": "llama_v3",
1063+
},
1064+
},
1065+
},
1066+
{
1067+
"script": "mistral_pretrain.py",
1068+
"recipe": {
1069+
"model": {
1070+
"model_type": "mistral",
1071+
},
1072+
},
1073+
},
1074+
{
1075+
"script": "deepseek_pretrain.py",
1076+
"recipe": {
1077+
"model": {
1078+
"model_type": "deepseek_llamav3",
1079+
},
1080+
},
1081+
},
1082+
{
1083+
"script": "deepseek_pretrain.py",
1084+
"recipe": {
1085+
"model": {
1086+
"model_type": "deepseek_qwenv2",
1087+
},
1088+
},
1089+
},
1090+
],
1091+
)
1092+
@patch("shutil.copyfile")
1093+
def test_get_training_recipe_gpu_script(mock_copyfile, test_case):
1094+
script = test_case["script"]
1095+
recipe = test_case["recipe"]
1096+
mock_copyfile.return_value = None
1097+
1098+
assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script
1099+
1100+
10521101
def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
10531102
container_log_level = '"logging.INFO"'
10541103

0 commit comments

Comments
 (0)