Skip to content

Commit 765e748

Browse files
committed
use jumpstart deployment config image as default optimization image
1 parent a58654e commit 765e748

File tree

4 files changed

+138
-5
lines changed

4 files changed

+138
-5
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

+106-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
from abc import ABC, abstractmethod
1919
from datetime import datetime, timedelta
20-
from typing import Type, Any, List, Dict, Optional
20+
from typing import Type, Any, List, Dict, Optional, Tuple
2121
import logging
2222

2323
from botocore.exceptions import ClientError
@@ -829,7 +829,13 @@ def _optimize_for_jumpstart(
829829
self.pysdk_model._enable_network_isolation = False
830830

831831
if quantization_config or sharding_config or is_compilation:
832-
return create_optimization_job_args
832+
# only apply default image for vLLM use cases.
833+
# vLLM does not support compilation for now so skip on compilation
834+
return (
835+
create_optimization_job_args
836+
if is_compilation
837+
else self._set_optimization_image_default(create_optimization_job_args)
838+
)
833839
return None
834840

835841
def _is_gated_model(self, model=None) -> bool:
@@ -986,3 +992,101 @@ def _get_neuron_model_env_vars(
986992
)
987993
return job_model.env
988994
return None
995+
996+
def _set_optimization_image_default(
    self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Defaults the optimization image to the JumpStart deployment config default.

    Scans the vLLM-backed optimization configs (quantization and sharding) for
    the newest LMI image provided by the caller, falling back to the JumpStart
    deployment-config image (bumped to the minimum supported vLLM version), and
    stamps that image onto every vLLM-backed config.

    Args:
        create_optimization_job_args (Dict[str, Any]): create optimization job request

    Returns:
        Dict[str, Any]: create optimization job request with image uri default
    """
    default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"])

    # Only these config types are vLLM-backed and carry an "Image" field;
    # original code handled them as an if/elif, so only the first matching
    # key per config entry is considered.
    vllm_config_keys = ("ModelQuantizationConfig", "ModelShardingConfig")

    # find the latest vLLM image version among the provided configs
    for optimization_config in create_optimization_job_args.get("OptimizationConfigs", []):
        for config_key in vllm_config_keys:
            vllm_config = optimization_config.get(config_key)
            if vllm_config:
                provided_image = vllm_config.get("Image")
                if provided_image and self._compare_lmi_versions(default_image, provided_image):
                    default_image = provided_image
                break

    # default every vLLM-backed config to the latest vLLM version found
    for optimization_config in create_optimization_job_args.get("OptimizationConfigs", []):
        for config_key in vllm_config_keys:
            if optimization_config.get(config_key):
                optimization_config[config_key]["Image"] = default_image
                break

    # Lazy %-formatting so the message is only built when INFO is enabled.
    logger.info("Defaulting to %s image for optimization", default_image)

    return create_optimization_job_args
1032+
1033+
def _get_default_vllm_image(self, image: str) -> bool:
1034+
"""Ensures the minimum working image version for vLLM enabled optimization techniques
1035+
1036+
Args:
1037+
image (str): JumpStart provided default image
1038+
1039+
Returns:
1040+
str: minimum working image version
1041+
"""
1042+
dlc_name, _ = image.split(":")
1043+
major_version_number, _, _ = self._parse_lmi_version(image)
1044+
1045+
if int(major_version_number) < 13:
1046+
minimum_version_default = f"{dlc_name}:0.31.0-lmi13.0.0-cu124"
1047+
return minimum_version_default
1048+
return image
1049+
1050+
def _compare_lmi_versions(self, version: str, version_to_compare: str) -> bool:
1051+
"""LMI version comparator
1052+
1053+
Args:
1054+
version (str): current version
1055+
version_to_compare (str): version to compare to
1056+
1057+
Returns:
1058+
bool: if version_to_compare larger or equal to version
1059+
"""
1060+
parse_lmi_version = self._parse_lmi_version(version)
1061+
parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare)
1062+
1063+
# Check major version
1064+
if parse_lmi_version_to_compare[0] > parse_lmi_version[0]:
1065+
return True
1066+
# Check minor version
1067+
if parse_lmi_version_to_compare[0] == parse_lmi_version[0]:
1068+
if parse_lmi_version_to_compare[1] > parse_lmi_version[1]:
1069+
return True
1070+
if parse_lmi_version_to_compare[1] == parse_lmi_version[1]:
1071+
# Check patch version
1072+
if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]:
1073+
return True
1074+
return False
1075+
return False
1076+
return False
1077+
1078+
def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]:
1079+
"""Parse out LMI version
1080+
1081+
Args:
1082+
image (str): image to parse version out of
1083+
1084+
Returns:
1085+
Tuple[int, int, it]: LMI version split into major, minor, patch
1086+
"""
1087+
dlc_name, dlc_tag = image.split(":")
1088+
_, lmi_version, _ = dlc_tag.split("-")
1089+
major_version, minor_version, patch_version = lmi_version.split(".")
1090+
major_version_number = major_version[3:]
1091+
1092+
return (int(major_version_number), int(minor_version), int(patch_version))

tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py

+18
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
3232
iam_client = sagemaker_session.boto_session.client("iam")
3333
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
3434

35+
sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
36+
3537
schema_builder = SchemaBuilder("test", "test")
3638
model_builder = ModelBuilder(
3739
model="meta-textgeneration-llama-3-1-8b-instruct",
@@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
5052
accept_eula=True,
5153
)
5254

55+
assert not sagemaker_session.sagemaker_client.create_optimization_job.called
56+
5357
optimized_model.deploy()
5458

5559
mock_create_model.assert_called_once_with(
@@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
126130
accept_eula=True,
127131
)
128132

133+
assert (
134+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
135+
"OptimizationConfigs"
136+
][0]["ModelShardingConfig"]["Image"]
137+
is not None
138+
)
139+
129140
optimized_model.deploy(
130141
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
131142
)
@@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
206217
accept_eula=True,
207218
)
208219

220+
assert (
221+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
222+
"OptimizationConfigs"
223+
][0]["ModelQuantizationConfig"]["Image"]
224+
is not None
225+
)
226+
209227
optimized_model.deploy()
210228

211229
mock_create_model.assert_called_once_with(

tests/unit/sagemaker/serve/builder/test_js_builder.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
"-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04"
7676
)
7777
mock_djl_image_uri = (
78-
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
78+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
7979
)
8080

8181
mock_model_data = {
@@ -1166,6 +1166,7 @@ def test_optimize_quantize_for_jumpstart(
11661166
mock_pysdk_model.image_uri = mock_tgi_image_uri
11671167
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
11681168
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
1169+
mock_pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"}
11691170

11701171
sample_input = {
11711172
"inputs": "The diamondback terrapin or simply terrapin is a species "
@@ -1201,6 +1202,9 @@ def test_optimize_quantize_for_jumpstart(
12011202
)
12021203

12031204
self.assertIsNotNone(out_put)
1205+
self.assertEqual(
1206+
out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
1207+
)
12041208

12051209
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12061210
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@@ -1287,6 +1291,7 @@ def test_optimize_quantize_and_compile_for_jumpstart(
12871291
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
12881292
mock_pysdk_model.config_name = "config_name"
12891293
mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config}
1294+
mock_pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"}
12901295

12911296
sample_input = {
12921297
"inputs": "The diamondback terrapin or simply terrapin is a species "
@@ -1319,6 +1324,8 @@ def test_optimize_quantize_and_compile_for_jumpstart(
13191324
)
13201325

13211326
self.assertIsNotNone(out_put)
1327+
self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image"))
1328+
self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image"))
13221329

13231330
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
13241331
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@@ -1640,6 +1647,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
16401647

16411648
mock_lmi_js_model = MagicMock()
16421649
mock_lmi_js_model.image_uri = mock_djl_image_uri
1650+
mock_lmi_js_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"}
16431651
mock_lmi_js_model.env = {
16441652
"SAGEMAKER_PROGRAM": "inference.py",
16451653
"ENDPOINT_SERVER_TIMEOUT": "3600",
@@ -1718,6 +1726,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over
17181726

17191727
mock_lmi_js_model = MagicMock()
17201728
mock_lmi_js_model.image_uri = mock_djl_image_uri
1729+
mock_lmi_js_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"}
17211730
mock_lmi_js_model.env = {
17221731
"SAGEMAKER_PROGRAM": "inference.py",
17231732
"ENDPOINT_SERVER_TIMEOUT": "3600",

tests/unit/sagemaker/serve/builder/test_model_builder.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -3733,6 +3733,7 @@ def test_optimize_sharding_with_override_for_js(
37333733
pysdk_model.env = {"key": "val"}
37343734
pysdk_model._enable_network_isolation = True
37353735
pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None
3736+
pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"}
37363737

37373738
mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model
37383739
mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
@@ -3803,8 +3804,9 @@ def test_optimize_sharding_with_override_for_js(
38033804
OptimizationConfigs=[
38043805
{
38053806
"ModelShardingConfig": {
3806-
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}
3807-
}
3807+
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124",
3808+
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"},
3809+
},
38083810
}
38093811
],
38103812
OutputConfig={

0 commit comments

Comments
 (0)