Skip to content

Commit 7a97cc0

Browse files
Captainiabenieric
authored and committed
chore: require config name and instance type in set_deployment_config (aws#4625)
* require config_name and instance_type in set config
* docstring
* add supported instance types check
* add more tests
* format
* fix tests
1 parent ef6a6e1 commit 7a97cc0

File tree

3 files changed

+72
-31
lines changed

3 files changed

+72
-31
lines changed

src/sagemaker/jumpstart/factory/model.py

+21-1
Original file line number | Diff line number | Diff line change
@@ -544,7 +544,11 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
544544

545545

546546
def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
547-
"""Sets default config name to the kwargs. Returns full kwargs."""
547+
"""Sets default config name to the kwargs. Returns full kwargs.
548+
549+
Raises:
550+
ValueError: If the instance_type is not supported with the current config.
551+
"""
548552

549553
specs = verify_model_region_and_return_specs(
550554
model_id=kwargs.model_id,
@@ -565,6 +569,22 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
565569
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
566570
)
567571

572+
if not kwargs.config_name:
573+
return kwargs
574+
575+
if kwargs.config_name not in set(specs.inference_configs.configs.keys()):
576+
raise ValueError(
577+
f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}."
578+
)
579+
580+
resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config
581+
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
582+
if kwargs.instance_type not in supported_instance_types:
583+
raise ValueError(
584+
f"Instance type {kwargs.instance_type} "
585+
f"is not supported for config {kwargs.config_name}."
586+
)
587+
568588
return kwargs
569589

570590

src/sagemaker/jumpstart/model.py

+11-5
Original file line number | Diff line number | Diff line change
@@ -429,16 +429,22 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
429429
sagemaker_session=self.sagemaker_session,
430430
)
431431

432-
def set_deployment_config(self, config_name: Optional[str]) -> None:
432+
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
433433
"""Sets the deployment config to apply to the model.
434434
435435
Args:
436-
config_name (Optional[str]):
437-
The name of the deployment config. Set to None to unset
438-
any existing config that is applied to the model.
436+
config_name (str):
437+
The name of the deployment config to apply to the model.
438+
Call list_deployment_configs to see the list of config names.
439+
instance_type (str):
440+
The instance_type that the model will use after setting
441+
the config.
439442
"""
440443
self.__init__(
441-
model_id=self.model_id, model_version=self.model_version, config_name=config_name
444+
model_id=self.model_id,
445+
model_version=self.model_version,
446+
instance_type=instance_type,
447+
config_name=config_name,
442448
)
443449

444450
@property

tests/unit/sagemaker/jumpstart/model/test_model.py

+40-25
Original file line number | Diff line number | Diff line change
@@ -1614,7 +1614,25 @@ def test_model_set_deployment_config(
16141614
mock_get_model_specs.reset_mock()
16151615
mock_model_deploy.reset_mock()
16161616
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
1617-
model.set_deployment_config("neuron-inference")
1617+
model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge")
1618+
1619+
assert model.config_name == "neuron-inference"
1620+
1621+
model.deploy()
1622+
1623+
mock_model_deploy.assert_called_once_with(
1624+
initial_instance_count=1,
1625+
instance_type="ml.inf2.2xlarge",
1626+
tags=[
1627+
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1628+
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1629+
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"},
1630+
],
1631+
wait=True,
1632+
endpoint_logging=False,
1633+
)
1634+
mock_model_deploy.reset_mock()
1635+
model.set_deployment_config("neuron-inference", "ml.inf2.xlarge")
16181636

16191637
assert model.config_name == "neuron-inference"
16201638

@@ -1640,15 +1658,15 @@ def test_model_set_deployment_config(
16401658
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
16411659
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
16421660
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1643-
def test_model_unset_deployment_config(
1661+
def test_model_set_deployment_config_incompatible_instance_type_or_name(
16441662
self,
16451663
mock_model_deploy: mock.Mock,
16461664
mock_get_model_specs: mock.Mock,
16471665
mock_session: mock.Mock,
16481666
mock_get_manifest: mock.Mock,
16491667
mock_get_jumpstart_configs: mock.Mock,
16501668
):
1651-
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
1669+
mock_get_model_specs.side_effect = get_prototype_model_spec
16521670
mock_get_manifest.side_effect = (
16531671
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
16541672
)
@@ -1658,40 +1676,40 @@ def test_model_unset_deployment_config(
16581676

16591677
mock_session.return_value = sagemaker_session
16601678

1661-
model = JumpStartModel(model_id=model_id, config_name="neuron-inference")
1679+
model = JumpStartModel(model_id=model_id)
16621680

1663-
assert model.config_name == "neuron-inference"
1681+
assert model.config_name is None
16641682

16651683
model.deploy()
16661684

16671685
mock_model_deploy.assert_called_once_with(
16681686
initial_instance_count=1,
1669-
instance_type="ml.inf2.xlarge",
1687+
instance_type="ml.p2.xlarge",
16701688
tags=[
16711689
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
16721690
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1673-
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"},
16741691
],
16751692
wait=True,
16761693
endpoint_logging=False,
16771694
)
16781695

16791696
mock_get_model_specs.reset_mock()
16801697
mock_model_deploy.reset_mock()
1681-
mock_get_model_specs.side_effect = get_prototype_model_spec
1682-
model.set_deployment_config(None)
1683-
1684-
model.deploy()
1698+
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
1699+
with pytest.raises(ValueError) as error:
1700+
model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge")
1701+
assert (
1702+
"Instance type ml.inf2.32xlarge is not supported for config neuron-inference."
1703+
in str(error)
1704+
)
16851705

1686-
mock_model_deploy.assert_called_once_with(
1687-
initial_instance_count=1,
1688-
instance_type="ml.p2.xlarge",
1689-
tags=[
1690-
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1691-
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1692-
],
1693-
wait=True,
1694-
endpoint_logging=False,
1706+
with pytest.raises(ValueError) as error:
1707+
model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge")
1708+
assert (
1709+
"Cannot find Jumpstart config name neuron-inference-unknown-name. "
1710+
"List of config names that is supported by the model: "
1711+
"['neuron-inference', 'neuron-inference-budget', 'gpu-inference-budget', 'gpu-inference']"
1712+
in str(error)
16951713
)
16961714

16971715
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@@ -1813,6 +1831,7 @@ def test_model_retrieve_deployment_config(
18131831

18141832
expected = get_base_deployment_configs()[0]
18151833
config_name = expected.get("DeploymentConfigName")
1834+
instance_type = expected.get("InstanceType")
18161835
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
18171836
model_id, config_name
18181837
)
@@ -1821,17 +1840,13 @@ def test_model_retrieve_deployment_config(
18211840

18221841
model = JumpStartModel(model_id=model_id)
18231842

1824-
model.set_deployment_config(config_name)
1843+
model.set_deployment_config(config_name, instance_type)
18251844

18261845
self.assertEqual(model.deployment_config, expected)
18271846

18281847
mock_get_init_kwargs.reset_mock()
18291848
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
18301849

1831-
# Unset
1832-
model.set_deployment_config(None)
1833-
self.assertIsNone(model.deployment_config)
1834-
18351850
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
18361851
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
18371852
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour")

0 commit comments

Comments (0)