Skip to content

Commit b4e983a

Browse files
author
Jonathan Makunga
committed
Test backward compatibility
1 parent dc2f5b7 commit b4e983a

File tree

2 files changed

+75
-9
lines changed

2 files changed

+75
-9
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
440440
any existing config that is applied to the model.
441441
"""
442442
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
443-
raise Exception("Cannot set deployment config to an uninitialized model")
443+
raise Exception("Cannot set deployment config to an uninitialized model.")
444444

445445
self.pysdk_model.set_deployment_config(config_name)
446446

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,80 @@ def test_get_deployment_config(
718718

719719
self.assertEqual(builder.get_deployment_config(), expected)
720720

721+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
722+
@patch(
723+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
724+
return_value=True,
725+
)
726+
@patch(
727+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
728+
return_value=MagicMock(),
729+
)
730+
@patch(
731+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
732+
return_value=({"model_type": "t5", "n_head": 71}, True),
733+
)
734+
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
735+
@patch(
736+
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
737+
)
738+
def test_set_deployment_config(
739+
self,
740+
mock_get_nb_instance,
741+
mock_get_ram_usage_mb,
742+
mock_prepare_for_tgi,
743+
mock_pre_trained_model,
744+
mock_is_jumpstart_model,
745+
mock_telemetry,
746+
):
747+
builder = ModelBuilder(
748+
model="facebook/galactica-mock-model-id",
749+
schema_builder=mock_schema_builder,
750+
)
751+
752+
mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
753+
754+
builder.build()
755+
builder.set_deployment_config("config-1")
756+
757+
mock_pre_trained_model.return_value.set_deployment_config.assert_called_with("config-1")
758+
759+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
760+
@patch(
761+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
762+
return_value=True,
763+
)
764+
@patch(
765+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
766+
return_value=MagicMock(),
767+
)
768+
@patch(
769+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
770+
return_value=({"model_type": "t5", "n_head": 71}, True),
771+
)
772+
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
773+
@patch(
774+
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
775+
)
776+
def test_set_deployment_config_ex(
777+
self,
778+
mock_get_nb_instance,
779+
mock_get_ram_usage_mb,
780+
mock_prepare_for_tgi,
781+
mock_pre_trained_model,
782+
mock_is_jumpstart_model,
783+
mock_telemetry,
784+
):
785+
mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
786+
787+
self.assertRaisesRegex(
788+
Exception,
789+
"Cannot set deployment config to an uninitialized model.",
790+
lambda: ModelBuilder(
791+
model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder
792+
).set_deployment_config("config-2"),
793+
)
794+
721795
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
722796
@patch(
723797
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
@@ -759,11 +833,3 @@ def test_display_benchmark_metrics(
759833
builder.display_benchmark_metrics()
760834

761835
mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once()
762-
763-
def test_display_benchmark_metrics_ex(self):
764-
self.assertRaises(
765-
Exception,
766-
lambda: ModelBuilder(
767-
model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder
768-
).display_benchmark_metrics(),
769-
)

0 commit comments

Comments
 (0)