Skip to content

ModelBuilder: Add functionalities to get and set deployment config. #4614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 25, 2024
32 changes: 29 additions & 3 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import absolute_import

from functools import lru_cache
from typing import Dict, List, Optional, Union, Any
from typing import Dict, List, Optional, Any, Union
import pandas as pd
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -441,14 +441,23 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
model_id=self.model_id, model_version=self.model_version, config_name=config_name
)

@property
def deployment_config(self) -> Optional[Dict[str, Any]]:
    """Return the deployment config currently selected for this model.

    Returns:
        Optional[Dict[str, Any]]: The result of looking up ``self.config_name``
            through ``_retrieve_selected_deployment_config``.
    """
    selected_name = self.config_name
    return self._retrieve_selected_deployment_config(selected_name)

@property
def benchmark_metrics(self) -> pd.DataFrame:
"""Benchmark metrics for the deployment configs, as a table.

Returns:
pd.DataFrame: Frame built from the benchmark data for ``self.config_name``.
"""
# NOTE(review): the next two lines appear to be the removed and added sides
# of one diff line (helper renamed to ``_get_benchmarks_data``) -- confirm.
return pd.DataFrame(self._get_benchmark_data(self.config_name))
return pd.DataFrame(self._get_benchmarks_data(self.config_name))

def display_benchmark_metrics(self) -> None:
"""Display Benchmark Metrics for deployment configs."""
Expand Down Expand Up @@ -851,7 +860,7 @@ def register_deploy_wrapper(*args, **kwargs):
return model_package

@lru_cache
def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
"""Constructs deployment configs benchmark data.

Args:
Expand All @@ -864,6 +873,23 @@ def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
config_name,
)

@lru_cache
def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
"""Retrieve the deployment config to apply to the model.

Args:
config_name (str): The name of the selected deployment config.
Returns:
Union[Dict[str, Any], None]: The deployment config to apply to the model.
"""
if config_name is None:
return None

for deployment_config in self._deployment_configs:
if deployment_config.get("ConfigName") == config_name:
return deployment_config
return None

def _convert_to_deployment_config_metadata(
self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
Expand Down
38 changes: 36 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type, Any, List, Dict
from typing import Type, Any, List, Dict, Optional
import logging

from sagemaker.model import Model
Expand Down Expand Up @@ -431,8 +431,35 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def set_deployment_config(self, config_name: Optional[str]) -> None:
    """Select (or clear) the deployment config used by the underlying model.

    The JumpStart model is created on first use when the builder does not
    hold one yet, then the call is forwarded to it.

    Args:
        config_name (Optional[str]):
            Name of the deployment config to apply. Pass None to unset
            any config that is currently applied to the model.
    """
    js_model = self.pysdk_model
    if js_model is None:
        js_model = self._create_pre_trained_js_model()
        self.pysdk_model = js_model

    js_model.set_deployment_config(config_name)

def get_deployment_config(self) -> Optional[Dict[str, Any]]:
    """Return the deployment config of the underlying JumpStart model.

    The JumpStart model is created on first use when the builder does not
    hold one yet.

    Returns:
        Optional[Dict[str, Any]]: The model's ``deployment_config`` value.
    """
    js_model = self.pysdk_model
    if js_model is None:
        js_model = self._create_pre_trained_js_model()
        self.pysdk_model = js_model

    return js_model.deployment_config

def display_benchmark_metrics(self):
    """Display the benchmark metrics of the deployment configs.

    Forwards to the JumpStart model's ``display_benchmark_metrics``, creating
    the model first when the builder does not hold one yet.
    """
    js_model = self.pysdk_model
    if js_model is None:
        js_model = self._create_pre_trained_js_model()
        self.pysdk_model = js_model

    js_model.display_benchmark_metrics()

def list_deployment_configs(self) -> List[Dict[str, Any]]:
Expand All @@ -441,6 +468,9 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of deployment configs.
"""
if self.pysdk_model is None:
self.pysdk_model = self._create_pre_trained_js_model()

return self.pysdk_model.list_deployment_configs()

def _build_for_jumpstart(self):
Expand All @@ -449,7 +479,11 @@ def _build_for_jumpstart(self):
self.secret_key = None
self.jumpstart = True

pysdk_model = self._create_pre_trained_js_model()
pysdk_model = (
self.pysdk_model
if self.pysdk_model is not None
else self._create_pre_trained_js_model()
)

image_uri = pysdk_model.image_uri

Expand Down
48 changes: 48 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,54 @@ def test_model_list_deployment_configs_empty(

self.assertTrue(len(configs) == 0)

@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_retrieve_deployment_config(
    self,
    mock_model_deploy: mock.Mock,
    mock_get_model_specs: mock.Mock,
    mock_session: mock.Mock,
    mock_get_manifest: mock.Mock,
    mock_get_instance_rate_per_hour: mock.Mock,
    mock_verify_model_region_and_return_specs: mock.Mock,
    mock_get_init_kwargs: mock.Mock,
):
    """Selecting a config exposes it via ``deployment_config``; None clears it."""
    model_id = "pytorch-eqa-bert-base-cased"

    # Wire up the patched collaborators before the model is constructed.
    mock_session.return_value = sagemaker_session
    mock_model_deploy.return_value = default_predictor
    mock_get_model_specs.side_effect = get_prototype_spec_with_configs
    mock_get_manifest.side_effect = (
        lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
    )
    mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
    mock_verify_model_region_and_return_specs.side_effect = (
        lambda *args, **kwargs: get_base_spec_with_prototype_configs()
    )
    mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: {
        "name": "Instance Rate",
        "unit": "USD/Hrs",
        "value": "0.0083000000",
    }

    model = JumpStartModel(model_id=model_id)

    expected_config = get_base_deployment_configs()[0]
    model.set_deployment_config(expected_config.get("ConfigName"))
    self.assertEqual(model.deployment_config, expected_config)

    # Passing None removes the selection again.
    model.set_deployment_config(None)
    self.assertIsNone(model.deployment_config)

@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour")
Expand Down
91 changes: 87 additions & 4 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,13 +676,96 @@ def test_list_deployment_configs(
lambda: DEPLOYMENT_CONFIGS
)

model = builder.build()
builder.build()
builder.serve_settings.telemetry_opt_out = True

configs = model.list_deployment_configs()
configs = builder.list_deployment_configs()

self.assertEqual(configs, DEPLOYMENT_CONFIGS)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
    return_value=({"model_type": "t5", "n_head": 71}, True),
)
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
@patch(
    "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
)
def test_get_deployment_config(
    self,
    mock_get_nb_instance,
    mock_get_ram_usage_mb,
    mock_prepare_for_tgi,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """The builder returns the deployment config held by the JumpStart model."""
    builder = ModelBuilder(
        model="facebook/galactica-mock-model-id",
        schema_builder=mock_schema_builder,
    )

    # Configure the mocked JumpStart model the builder will work with.
    expected_config = DEPLOYMENT_CONFIGS[0]
    js_model = mock_pre_trained_model.return_value
    js_model.image_uri = mock_tgi_image_uri
    js_model.deployment_config = expected_config

    builder.build()
    builder.serve_settings.telemetry_opt_out = True

    actual_config = builder.get_deployment_config()
    self.assertEqual(actual_config, expected_config)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
    return_value=({"model_type": "t5", "n_head": 71}, True),
)
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
@patch(
    "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
)
def test_set_deployment_config(
    self,
    mock_get_nb_instance,
    mock_get_ram_usage_mb,
    mock_prepare_for_tgi,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """The builder forwards the chosen config name to the JumpStart model."""
    builder = ModelBuilder(
        model="facebook/galactica-mock-model-id",
        schema_builder=mock_schema_builder,
    )

    js_model = mock_pre_trained_model.return_value
    js_model.image_uri = mock_tgi_image_uri

    builder.build()
    builder.serve_settings.telemetry_opt_out = True

    builder.set_deployment_config("config_name")

    # Exactly one forwarded call, carrying the same name.
    js_model.set_deployment_config.assert_called_once_with("config_name")

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
Expand Down Expand Up @@ -719,7 +802,7 @@ def test_display_benchmark_metrics(
lambda *args, **kwargs: "metric data"
)

model = builder.build()
builder.build()
builder.serve_settings.telemetry_opt_out = True

model.display_benchmark_metrics()
builder.display_benchmark_metrics()
Loading