Skip to content

Commit 2c7f025

Browse files
committed
use jumpstart deployment config image as default optimization image
1 parent a58654e commit 2c7f025

File tree

4 files changed

+55
-2
lines changed

4 files changed

+55
-2
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def _optimize_for_jumpstart(
829829
self.pysdk_model._enable_network_isolation = False
830830

831831
if quantization_config or sharding_config or is_compilation:
832-
return create_optimization_job_args
832+
return self._set_optimization_image_default(create_optimization_job_args)
833833
return None
834834

835835
def _is_gated_model(self, model=None) -> bool:
@@ -986,3 +986,24 @@ def _get_neuron_model_env_vars(
986986
)
987987
return job_model.env
988988
return None
989+
990+
def _set_optimization_image_default(
    self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Default each optimization config's compilation image to the model's image URI.

    Every entry under ``OptimizationConfigs`` that has no
    ``ModelCompilationConfig``, or whose ``ModelCompilationConfig`` has no
    (or an empty) ``Image``, gets ``Image`` set to
    ``self.pysdk_model.init_kwargs["image_uri"]``. Entries that already name
    an image are left untouched, as are any other keys inside an existing
    ``ModelCompilationConfig`` (e.g. ``OverrideEnvironment``).

    The request is updated in place and the same dict is returned.

    Args:
        create_optimization_job_args (Dict[str, Any]): create optimization job request

    Returns:
        Dict[str, Any]: create optimization job request with image uri default

    Raises:
        KeyError: If an image has to be defaulted but
            ``self.pysdk_model.init_kwargs`` has no ``"image_uri"`` entry.
    """
    # ``dict.get`` returns None when the key is missing (or explicitly null);
    # fall back to an empty list so the request is returned unchanged instead
    # of failing with "TypeError: 'NoneType' object is not iterable".
    optimization_configs = create_optimization_job_args.get("OptimizationConfigs") or []

    for optimization_config in optimization_configs:
        # A missing, None, or empty compilation config is replaced by a fresh dict.
        compilation_config = optimization_config.get("ModelCompilationConfig") or {}
        if compilation_config.get("Image"):
            # An explicitly provided image always wins over the default.
            continue

        # Looked up only when actually needed, so requests that already carry
        # images never require ``image_uri`` to be present in ``init_kwargs``.
        compilation_config["Image"] = self.pysdk_model.init_kwargs["image_uri"]
        optimization_config["ModelCompilationConfig"] = compilation_config

    return create_optimization_job_args

tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py

+18
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
3232
iam_client = sagemaker_session.boto_session.client("iam")
3333
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
3434

35+
sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
36+
3537
schema_builder = SchemaBuilder("test", "test")
3638
model_builder = ModelBuilder(
3739
model="meta-textgeneration-llama-3-1-8b-instruct",
@@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
5052
accept_eula=True,
5153
)
5254

55+
assert not sagemaker_session.sagemaker_client.create_optimization_job.called
56+
5357
optimized_model.deploy()
5458

5559
mock_create_model.assert_called_once_with(
@@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
126130
accept_eula=True,
127131
)
128132

133+
assert (
134+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
135+
"OptimizationConfigs"
136+
][0]["ModelCompilationConfig"]["Image"]
137+
is not None
138+
)
139+
129140
optimized_model.deploy(
130141
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
131142
)
@@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
206217
accept_eula=True,
207218
)
208219

220+
assert (
221+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
222+
"OptimizationConfigs"
223+
][0]["ModelCompilationConfig"]["Image"]
224+
is not None
225+
)
226+
209227
optimized_model.deploy()
210228

211229
mock_create_model.assert_called_once_with(

tests/unit/sagemaker/serve/builder/test_js_builder.py

+12
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,7 @@ def test_optimize_quantize_for_jumpstart(
11661166
mock_pysdk_model.image_uri = mock_tgi_image_uri
11671167
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
11681168
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
1169+
mock_pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}
11691170

11701171
sample_input = {
11711172
"inputs": "The diamondback terrapin or simply terrapin is a species "
@@ -1201,6 +1202,9 @@ def test_optimize_quantize_for_jumpstart(
12011202
)
12021203

12031204
self.assertIsNotNone(out_put)
1205+
self.assertEqual(
1206+
out_put["OptimizationConfigs"][0]["ModelCompilationConfig"]["Image"], "mock_js_image"
1207+
)
12041208

12051209
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12061210
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@@ -1287,6 +1291,7 @@ def test_optimize_quantize_and_compile_for_jumpstart(
12871291
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
12881292
mock_pysdk_model.config_name = "config_name"
12891293
mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config}
1294+
mock_pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}
12901295

12911296
sample_input = {
12921297
"inputs": "The diamondback terrapin or simply terrapin is a species "
@@ -1319,6 +1324,13 @@ def test_optimize_quantize_and_compile_for_jumpstart(
13191324
)
13201325

13211326
self.assertIsNotNone(out_put)
1327+
self.assertEqual(
1328+
out_put["OptimizationConfigs"][1]["ModelCompilationConfig"],
1329+
{
1330+
"Image": "mock_js_image",
1331+
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"},
1332+
},
1333+
)
13221334

13231335
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
13241336
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)

tests/unit/sagemaker/serve/builder/test_model_builder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3733,6 +3733,7 @@ def test_optimize_sharding_with_override_for_js(
37333733
pysdk_model.env = {"key": "val"}
37343734
pysdk_model._enable_network_isolation = True
37353735
pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None
3736+
pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}
37363737

37373738
mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model
37383739
mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
@@ -3802,9 +3803,10 @@ def test_optimize_sharding_with_override_for_js(
38023803
DeploymentInstanceType="ml.g5.24xlarge",
38033804
OptimizationConfigs=[
38043805
{
3806+
"ModelCompilationConfig": {"Image": "mock_js_image"},
38053807
"ModelShardingConfig": {
38063808
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}
3807-
}
3809+
},
38083810
}
38093811
],
38103812
OutputConfig={

0 commit comments

Comments
 (0)