Skip to content

Commit e5b38aa

Browse files
committed
fix: Model server override logic
1 parent 382fde1 commit e5b38aa

File tree

2 files changed

+175
-32
lines changed

2 files changed

+175
-32
lines changed

src/sagemaker/serve/builder/model_builder.py

+54-28
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@
8888

8989
logger = logging.getLogger(__name__)
9090

91-
supported_model_server = {
91+
# Model servers that may be passed explicitly as `model_server`.
# Any new server type should be added here: `_build_for_model_server` raises a
# ValueError for any value that is not a member of this set, and the set's
# string form is embedded in the validation error messages.
supported_model_servers = {
    ModelServer.TORCHSERVE,
    ModelServer.TRITON,
    ModelServer.DJL_SERVING,
    ModelServer.TENSORFLOW_SERVING,
    ModelServer.MMS,
    ModelServer.TGI,
    ModelServer.TEI,
}
97101

98102

@@ -281,31 +285,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
281285
},
282286
)
283287

284-
def _build_validations(self):
    """Validate the builder configuration and settle on a model server.

    Besides validating, this assigns ``ModelServer.TORCHSERVE`` to
    ``self.model_server`` when no model server was set.

    Raises:
        ValueError: If ``mode`` is ``Mode.IN_PROCESS``; if both
            ``inference_spec`` and ``model`` are set; if ``image_uri`` is set,
            ``is_1p_image_uri`` returns a falsy value for it and no
            ``model_server`` is given; or if the resulting ``model_server`` is
            not in ``supported_model_server``.
    """
    # TODO: Beta validations - remove after the launch
    if self.mode == Mode.IN_PROCESS:
        raise ValueError("IN_PROCESS mode is not supported yet!")

    if self.inference_spec and self.model:
        raise ValueError("Cannot have both the Model and Inference spec in the builder")

    # `%` binds tighter than `+`, so only the second literal is formatted
    # before the two halves of the message are concatenated.
    if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
        raise ValueError(
            "Model_server must be set when non-first-party image_uri is set. "
            + "Supported model servers: %s" % supported_model_server
        )

    # Set TorchServe as default model server
    if not self.model_server:
        self.model_server = ModelServer.TORCHSERVE

    # Runs after the default is applied, so the default itself is checked too.
    if self.model_server not in supported_model_server:
        raise ValueError(
            "%s is not supported yet! Supported model servers: %s"
            % (self.model_server, supported_model_server)
        )
308-
309288
def _save_model_inference_spec(self):
310289
"""Placeholder docstring"""
311290
# check if path exists and create if not
@@ -748,6 +727,11 @@ def build( # pylint: disable=R0911
748727
self._initialize_for_mlflow()
749728
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
750729

730+
self._build_validations()
731+
732+
if self.model_server:
733+
return self._build_for_model_server()
734+
751735
if isinstance(self.model, str):
752736
model_task = None
753737
if self.model_metadata:
@@ -779,7 +763,39 @@ def build( # pylint: disable=R0911
779763
else:
780764
return self._build_for_transformers()
781765

782-
self._build_validations()
766+
# Set TorchServe as default model server
767+
if not self.model_server:
768+
self.model_server = ModelServer.TORCHSERVE
769+
return self._build_for_torchserve()
770+
771+
raise ValueError("%s model server is not supported" % self.model_server)
772+
773+
def _build_validations(self):
    """Reject builder configurations that cannot be built.

    ``build()`` calls this before it dispatches to a model-server override,
    so the checks apply to the override path as well as to the
    auto-detection and fallback paths.

    Raises:
        ValueError: If ``mode`` is ``Mode.IN_PROCESS``; if both
            ``inference_spec`` and ``model`` are set; or if ``image_uri`` is
            set, ``is_1p_image_uri`` returns a falsy value for it, and
            ``model_server`` is ``None``.
    """
    if self.mode == Mode.IN_PROCESS:
        raise ValueError("IN_PROCESS mode is not supported yet!")

    if self.inference_spec and self.model:
        raise ValueError("Can only set one of the following: model, inference_spec.")

    # A non-first-party image must be accompanied by an explicit model server.
    needs_explicit_server = (
        self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None
    )
    if needs_explicit_server:
        servers_hint = "Supported model servers: %s" % supported_model_servers
        raise ValueError(
            "Model_server must be set when non-first-party image_uri is set. " + servers_hint
        )
786+
787+
def _build_for_model_server(self): # pylint: disable=R0911, R1710
788+
"""Model server overrides"""
789+
if self.model_server not in supported_model_servers:
790+
raise ValueError(
791+
"%s is not supported yet! Supported model servers: %s"
792+
% (self.model_server, supported_model_servers)
793+
)
794+
795+
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
796+
797+
if not self.model and not mlflow_path:
798+
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
783799

784800
if self.model_server == ModelServer.TORCHSERVE:
785801
return self._build_for_torchserve()
@@ -790,7 +806,17 @@ def build( # pylint: disable=R0911
790806
if self.model_server == ModelServer.TENSORFLOW_SERVING:
791807
return self._build_for_tensorflow_serving()
792808

793-
raise ValueError("%s model server is not supported" % self.model_server)
809+
if self.model_server == ModelServer.DJL_SERVING:
810+
return self._build_for_djl()
811+
812+
if self.model_server == ModelServer.TEI:
813+
return self._build_for_tei()
814+
815+
if self.model_server == ModelServer.TGI:
816+
return self._build_for_tgi()
817+
818+
if self.model_server == ModelServer.MMS:
819+
return self._build_for_transformers()
794820

795821
def save(
796822
self,

tests/unit/sagemaker/serve/builder/test_model_builder.py

+121-4
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@
4949
mock_secret_key = "mock_secret_key"
5050
mock_instance_type = "mock instance type"
5151

52-
supported_model_server = {
52+
# Test-local copy of the model servers accepted by ModelBuilder. The validation
# tests format it into the error messages they expect, so its members must stay
# in sync with `supported_model_servers` in model_builder.py.
supported_model_servers = {
    ModelServer.TORCHSERVE,
    ModelServer.TRITON,
    ModelServer.DJL_SERVING,
    ModelServer.TENSORFLOW_SERVING,
    ModelServer.MMS,
    ModelServer.TGI,
    ModelServer.TEI,
}
5861

5962
mock_session = MagicMock()
@@ -77,7 +80,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
7780
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
7881
self.assertRaisesRegex(
7982
Exception,
80-
"Cannot have both the Model and Inference spec in the builder",
83+
"Can only set one of the following: model, inference_spec.",
8184
builder.build,
8285
Mode.SAGEMAKER_ENDPOINT,
8386
mock_role_arn,
@@ -90,7 +93,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
9093
self.assertRaisesRegex(
9194
Exception,
9295
"%s is not supported yet! Supported model servers: %s"
93-
% (builder.model_server, supported_model_server),
96+
% (builder.model_server, supported_model_servers),
9497
builder.build,
9598
Mode.SAGEMAKER_ENDPOINT,
9699
mock_role_arn,
@@ -103,7 +106,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
103106
self.assertRaisesRegex(
104107
Exception,
105108
"Model_server must be set when non-first-party image_uri is set. "
106-
+ "Supported model servers: %s" % supported_model_server,
109+
+ "Supported model servers: %s" % supported_model_servers,
107110
builder.build,
108111
Mode.SAGEMAKER_ENDPOINT,
109112
mock_role_arn,
@@ -124,6 +127,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
124127
mock_session,
125128
)
126129

130+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
    """An explicit DJL_SERVING model server routes build() to ``_build_for_djl``."""
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    djl_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.DJL_SERVING)
    djl_builder.build(sagemaker_session=mock_session)

    mock_build_for_djl.assert_called_once()
141+
142+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
    """build() must fail when DJL_SERVING is requested with neither a model nor an MLflow path."""
    djl_builder = ModelBuilder(
        model_server=ModelServer.DJL_SERVING,
        model=None,
        inference_spec=None,
    )

    with self.assertRaisesRegex(
        Exception,
        "Missing required parameter `model` or 'ml_flow' path",
    ):
        djl_builder.build(Mode.SAGEMAKER_ENDPOINT, mock_role_arn, mock_session)
155+
156+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
def test_model_server_override_torchserve_with_model(
    self, mock_build_for_ts, mock_serve_settings
):
    """An explicit TORCHSERVE model server routes build() to ``_build_for_torchserve``."""
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    ts_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.TORCHSERVE)
    ts_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
169+
170+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
    """build() must fail when TORCHSERVE is requested with neither a model nor an MLflow path."""
    ts_builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)

    with self.assertRaisesRegex(
        Exception,
        "Missing required parameter `model` or 'ml_flow' path",
    ):
        ts_builder.build(Mode.SAGEMAKER_ENDPOINT, mock_role_arn, mock_session)
181+
182+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
    """An explicit TRITON model server routes build() to ``_build_for_triton``."""
    # NOTE: despite its name, `mock_build_for_ts` is the patched `_build_for_triton`.
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    triton_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.TRITON)
    triton_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
193+
194+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
    """An explicit TENSORFLOW_SERVING server routes build() to ``_build_for_tensorflow_serving``."""
    # NOTE: despite its name, `mock_build_for_ts` is the patched
    # `_build_for_tensorflow_serving`.
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    tf_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.TENSORFLOW_SERVING)
    tf_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
205+
206+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
    """An explicit TEI model server routes build() to ``_build_for_tei``."""
    # NOTE: despite its name, `mock_build_for_ts` is the patched `_build_for_tei`.
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    tei_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.TEI)
    tei_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
217+
218+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
    """An explicit TGI model server routes build() to ``_build_for_tgi``."""
    # NOTE: despite its name, `mock_build_for_ts` is the patched `_build_for_tgi`.
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    tgi_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.TGI)
    tgi_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
229+
230+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
def test_model_server_override_transformers_with_model(
    self, mock_build_for_ts, mock_serve_settings
):
    """An explicit MMS model server routes build() to ``_build_for_transformers``."""
    # NOTE: despite its name, `mock_build_for_ts` is the patched
    # `_build_for_transformers`.
    serve_settings = mock_serve_settings.return_value
    serve_settings.role_arn = mock_role_arn
    serve_settings.s3_model_data_url = mock_s3_model_data_url

    mms_builder = ModelBuilder(model="gpt_llm_burt", model_server=ModelServer.MMS)
    mms_builder.build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
243+
127244
@patch("os.makedirs", Mock())
128245
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
129246
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")

0 commit comments

Comments
 (0)