fix: Model server override logic #4733

Merged
merged 4 commits on Jun 19, 2024
84 changes: 56 additions & 28 deletions src/sagemaker/serve/builder/model_builder.py
@@ -94,11 +94,15 @@

logger = logging.getLogger(__name__)

supported_model_server = {
# Any new server type should be added here
supported_model_servers = {
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
ModelServer.TENSORFLOW_SERVING,
ModelServer.MMS,
ModelServer.TGI,
ModelServer.TEI,
}


@@ -288,31 +292,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
},
)

def _build_validations(self):
"""Placeholder docstring"""
# TODO: Beta validations - remove after the launch
if self.mode == Mode.IN_PROCESS:
raise ValueError("IN_PROCESS mode is not supported yet!")

if self.inference_spec and self.model:
raise ValueError("Cannot have both the Model and Inference spec in the builder")

if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
raise ValueError(
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_server
)

# Set TorchServe as default model server
if not self.model_server:
self.model_server = ModelServer.TORCHSERVE

if self.model_server not in supported_model_server:
raise ValueError(
"%s is not supported yet! Supported model servers: %s"
% (self.model_server, supported_model_server)
)

def _save_model_inference_spec(self):
"""Placeholder docstring"""
# check if path exists and create if not
@@ -839,6 +818,11 @@ def build( # pylint: disable=R0911

self._handle_mlflow_input()

self._build_validations()

if self.model_server:
return self._build_for_model_server()

if isinstance(self.model, str):
model_task = None
if self.model_metadata:
@@ -870,7 +854,41 @@ def build( # pylint: disable=R0911
else:
return self._build_for_transformers()

self._build_validations()
# Set TorchServe as default model server
if not self.model_server:
self.model_server = ModelServer.TORCHSERVE
return self._build_for_torchserve()

raise ValueError("%s model server is not supported" % self.model_server)

def _build_validations(self):
"""Validations needed for model server overrides, or auto-detection or fallback"""
if self.mode == Mode.IN_PROCESS:
raise ValueError("IN_PROCESS mode is not supported yet!")

if self.inference_spec and self.model:
raise ValueError("Can only set one of the following: model, inference_spec.")

if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
raise ValueError(
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_servers
)

def _build_for_model_server(self): # pylint: disable=R0911, R1710
"""Model server overrides"""
if self.model_server not in supported_model_servers:
raise ValueError(
"%s is not supported yet! Supported model servers: %s"
% (self.model_server, supported_model_servers)
)

mlflow_path = None
if self.model_metadata:
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)

if not self.model and not mlflow_path:
raise ValueError("Missing required parameter `model` or 'ml_flow' path")

if self.model_server == ModelServer.TORCHSERVE:
return self._build_for_torchserve()
@@ -881,7 +899,17 @@ def build( # pylint: disable=R0911
if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._build_for_tensorflow_serving()

raise ValueError("%s model server is not supported" % self.model_server)
if self.model_server == ModelServer.DJL_SERVING:
return self._build_for_djl()

if self.model_server == ModelServer.TEI:
return self._build_for_tei()

if self.model_server == ModelServer.TGI:
return self._build_for_tgi()

if self.model_server == ModelServer.MMS:
return self._build_for_transformers()

def save(
self,
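For reviewers, here is a standalone sketch of the routing this PR introduces. Everything below is a local stand-in for illustration (the enum, handler names, and `build_for_model_server` mirror the diff but are not imported from the real `sagemaker` package): an explicit `model_server` now short-circuits auto-detection in `build()`, the model/MLflow-path validation runs first, and each supported server maps to exactly one `_build_for_*` handler.

```python
# Standalone mirror of the override flow from this PR; nothing here is
# imported from sagemaker -- all names are stand-ins for illustration.
from enum import Enum, auto


class ModelServer(Enum):
    TORCHSERVE = auto()
    TRITON = auto()
    DJL_SERVING = auto()
    TENSORFLOW_SERVING = auto()
    MMS = auto()
    TGI = auto()
    TEI = auto()


SUPPORTED_MODEL_SERVERS = set(ModelServer)  # every member is supported today

# The same routing as the if-chain in _build_for_model_server, as a dict.
HANDLERS = {
    ModelServer.TORCHSERVE: "_build_for_torchserve",
    ModelServer.TRITON: "_build_for_triton",
    ModelServer.TENSORFLOW_SERVING: "_build_for_tensorflow_serving",
    ModelServer.DJL_SERVING: "_build_for_djl",
    ModelServer.TEI: "_build_for_tei",
    ModelServer.TGI: "_build_for_tgi",
    ModelServer.MMS: "_build_for_transformers",
}


def build_for_model_server(model_server, model=None, mlflow_path=None):
    """Mirrors ModelBuilder._build_for_model_server from the diff above."""
    if model_server not in SUPPORTED_MODEL_SERVERS:
        raise ValueError(f"{model_server} is not supported yet!")
    if not model and not mlflow_path:
        raise ValueError("Missing required parameter `model` or 'ml_flow' path")
    return HANDLERS[model_server]


print(build_for_model_server(ModelServer.DJL_SERVING, model="gpt_llm_burt"))
# prints: _build_for_djl
```

Note the ordering consequence visible in the `build()` hunk above: when `model_server` is set, `build()` returns from `_build_for_model_server` before the JumpStart/HuggingFace string-model paths run, so an explicit override always wins over auto-detection.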
125 changes: 121 additions & 4 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -50,11 +50,14 @@
mock_secret_key = "mock_secret_key"
mock_instance_type = "mock instance type"

supported_model_server = {
supported_model_servers = {
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
ModelServer.TENSORFLOW_SERVING,
ModelServer.MMS,
ModelServer.TGI,
ModelServer.TEI,
}

mock_session = MagicMock()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
self.assertRaisesRegex(
Exception,
"Cannot have both the Model and Inference spec in the builder",
"Can only set one of the following: model, inference_spec.",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
self.assertRaisesRegex(
Exception,
"%s is not supported yet! Supported model servers: %s"
% (builder.model_server, supported_model_server),
% (builder.model_server, supported_model_servers),
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
self.assertRaisesRegex(
Exception,
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_server,
+ "Supported model servers: %s" % supported_model_servers,
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
Expand All @@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_djl.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
builder = ModelBuilder(
model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None
)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
def test_model_server_override_torchserve_with_model(
self, mock_build_for_ts, mock_serve_settings
):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
def test_model_server_override_transformers_with_model(
self, mock_build_for_ts, mock_serve_settings
):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("os.makedirs", Mock())
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
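The eight override tests above all share one patch-build-assert shape. Purely as an illustration (a hypothetical pytest-style condensation with assumed import paths; the actual suite is unittest-based and this is not part of the PR), the shared pattern is:

```python
# Hypothetical, not part of the PR: a parametrized version of the
# patch-build-assert pattern used by the override tests above.
import pytest
from unittest.mock import MagicMock, patch

from sagemaker.serve.builder.model_builder import ModelBuilder, ModelServer

CASES = [
    (ModelServer.DJL_SERVING, "_build_for_djl"),
    (ModelServer.TORCHSERVE, "_build_for_torchserve"),
    (ModelServer.TRITON, "_build_for_triton"),
    (ModelServer.TENSORFLOW_SERVING, "_build_for_tensorflow_serving"),
    (ModelServer.TEI, "_build_for_tei"),
    (ModelServer.TGI, "_build_for_tgi"),
    (ModelServer.MMS, "_build_for_transformers"),
]


@pytest.mark.parametrize("server, handler", CASES)
def test_override_routes_to_handler(server, handler):
    # Patch settings and the expected handler, then assert build() routes there.
    with patch("sagemaker.serve.builder.model_builder._ServeSettings"), patch(
        f"sagemaker.serve.builder.model_builder.ModelBuilder.{handler}"
    ) as mock_handler:
        builder = ModelBuilder(model_server=server, model="gpt_llm_burt")
        builder.build(sagemaker_session=MagicMock())
        mock_handler.assert_called_once()
```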