
Commit 4496072

samruds and benieric authored
fix: Model server override logic (#4733)
* fix: Model server override logic
* Fix formatting

---------

Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 0984e8d commit 4496072
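
What the fix enables, from a caller's perspective: an explicit model_server override now routes build() straight to the matching backend instead of only being validated after auto-detection. A minimal sketch, assuming the ModelServer import path below and a placeholder model ID (both illustrative, not taken from this commit) and AWS credentials configured for the Session:

from sagemaker.session import Session
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.utils.types import ModelServer  # import path assumed

# Pin the serving stack up front; build() dispatches to _build_for_djl().
builder = ModelBuilder(
    model_server=ModelServer.DJL_SERVING,  # any member of supported_model_servers
    model="my-model-id",                   # placeholder model ID or model object
)
model = builder.build(sagemaker_session=Session())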

File tree

2 files changed: +177 −32

src/sagemaker/serve/builder/model_builder.py

+56 −28
@@ -94,11 +94,15 @@

 logger = logging.getLogger(__name__)

-supported_model_server = {
+# Any new server type should be added here
+supported_model_servers = {
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
     ModelServer.TENSORFLOW_SERVING,
+    ModelServer.MMS,
+    ModelServer.TGI,
+    ModelServer.TEI,
 }

@@ -288,31 +292,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
         },
     )

-    def _build_validations(self):
-        """Placeholder docstring"""
-        # TODO: Beta validations - remove after the launch
-        if self.mode == Mode.IN_PROCESS:
-            raise ValueError("IN_PROCESS mode is not supported yet!")
-
-        if self.inference_spec and self.model:
-            raise ValueError("Cannot have both the Model and Inference spec in the builder")
-
-        if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
-            raise ValueError(
-                "Model_server must be set when non-first-party image_uri is set. "
-                + "Supported model servers: %s" % supported_model_server
-            )
-
-        # Set TorchServe as default model server
-        if not self.model_server:
-            self.model_server = ModelServer.TORCHSERVE
-
-        if self.model_server not in supported_model_server:
-            raise ValueError(
-                "%s is not supported yet! Supported model servers: %s"
-                % (self.model_server, supported_model_server)
-            )
-
     def _save_model_inference_spec(self):
         """Placeholder docstring"""
         # check if path exists and create if not
@@ -839,6 +818,11 @@ def build(  # pylint: disable=R0911

         self._handle_mlflow_input()

+        self._build_validations()
+
+        if self.model_server:
+            return self._build_for_model_server()
+
         if isinstance(self.model, str):
             model_task = None
             if self.model_metadata:
@@ -870,7 +854,41 @@ def build(  # pylint: disable=R0911
             else:
                 return self._build_for_transformers()

-        self._build_validations()
+        # Set TorchServe as default model server
+        if not self.model_server:
+            self.model_server = ModelServer.TORCHSERVE
+            return self._build_for_torchserve()
+
+        raise ValueError("%s model server is not supported" % self.model_server)
+
+    def _build_validations(self):
+        """Validations needed for model server overrides, or auto-detection or fallback"""
+        if self.mode == Mode.IN_PROCESS:
+            raise ValueError("IN_PROCESS mode is not supported yet!")
+
+        if self.inference_spec and self.model:
+            raise ValueError("Can only set one of the following: model, inference_spec.")
+
+        if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
+            raise ValueError(
+                "Model_server must be set when non-first-party image_uri is set. "
+                + "Supported model servers: %s" % supported_model_servers
+            )
+
+    def _build_for_model_server(self):  # pylint: disable=R0911, R1710
+        """Model server overrides"""
+        if self.model_server not in supported_model_servers:
+            raise ValueError(
+                "%s is not supported yet! Supported model servers: %s"
+                % (self.model_server, supported_model_servers)
+            )
+
+        mlflow_path = None
+        if self.model_metadata:
+            mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
+
+        if not self.model and not mlflow_path:
+            raise ValueError("Missing required parameter `model` or 'ml_flow' path")

         if self.model_server == ModelServer.TORCHSERVE:
             return self._build_for_torchserve()
@@ -881,7 +899,17 @@ def build(  # pylint: disable=R0911
         if self.model_server == ModelServer.TENSORFLOW_SERVING:
             return self._build_for_tensorflow_serving()

-        raise ValueError("%s model server is not supported" % self.model_server)
+        if self.model_server == ModelServer.DJL_SERVING:
+            return self._build_for_djl()
+
+        if self.model_server == ModelServer.TEI:
+            return self._build_for_tei()
+
+        if self.model_server == ModelServer.TGI:
+            return self._build_for_tgi()
+
+        if self.model_server == ModelServer.MMS:
+            return self._build_for_transformers()

     def save(
         self,
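
Taken together, the hunks above reorder build(): validations run first, an explicit model_server override short-circuits into the new _build_for_model_server() dispatcher, and TorchServe remains the default only when auto-detection found nothing. A condensed restatement of that control flow (an illustrative sketch, not the verbatim source):

# Condensed flow of ModelBuilder.build() after this commit (illustrative).
def build(self, *args, **kwargs):
    self._handle_mlflow_input()
    self._build_validations()        # mode, model-vs-inference_spec, image_uri checks

    if self.model_server:            # explicit override wins: dispatch immediately
        return self._build_for_model_server()

    if isinstance(self.model, str):  # model-ID string: auto-detection path
        ...                          # may return a JumpStart/TGI/transformers build

    if not self.model_server:        # fallback when nothing else matched
        self.model_server = ModelServer.TORCHSERVE
        return self._build_for_torchserve()

    raise ValueError("%s model server is not supported" % self.model_server)

One design note: _build_for_model_server() checks membership in supported_model_servers before walking the if-chain, so an unsupported value fails fast with the explicit "not supported yet" message instead of silently falling through (hence the R1710 pylint disable on the dispatcher).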

tests/unit/sagemaker/serve/builder/test_model_builder.py

+121 −4
@@ -50,11 +50,14 @@
 mock_secret_key = "mock_secret_key"
 mock_instance_type = "mock instance type"

-supported_model_server = {
+supported_model_servers = {
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
     ModelServer.TENSORFLOW_SERVING,
+    ModelServer.MMS,
+    ModelServer.TGI,
+    ModelServer.TEI,
 }

 mock_session = MagicMock()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
         builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
         self.assertRaisesRegex(
             Exception,
-            "Cannot have both the Model and Inference spec in the builder",
+            "Can only set one of the following: model, inference_spec.",
             builder.build,
             Mode.SAGEMAKER_ENDPOINT,
             mock_role_arn,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
         self.assertRaisesRegex(
             Exception,
             "%s is not supported yet! Supported model servers: %s"
-            % (builder.model_server, supported_model_server),
+            % (builder.model_server, supported_model_servers),
             builder.build,
             Mode.SAGEMAKER_ENDPOINT,
             mock_role_arn,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
         self.assertRaisesRegex(
             Exception,
             "Model_server must be set when non-first-party image_uri is set. "
-            + "Supported model servers: %s" % supported_model_server,
+            + "Supported model servers: %s" % supported_model_servers,
             builder.build,
             Mode.SAGEMAKER_ENDPOINT,
             mock_role_arn,
@@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
             mock_session,
         )

+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
+    def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_djl.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
+        builder = ModelBuilder(
+            model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None
+        )
+        self.assertRaisesRegex(
+            Exception,
+            "Missing required parameter `model` or 'ml_flow' path",
+            builder.build,
+            Mode.SAGEMAKER_ENDPOINT,
+            mock_role_arn,
+            mock_session,
+        )
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
+    def test_model_server_override_torchserve_with_model(
+        self, mock_build_for_ts, mock_serve_settings
+    ):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
+        builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
+        self.assertRaisesRegex(
+            Exception,
+            "Missing required parameter `model` or 'ml_flow' path",
+            builder.build,
+            Mode.SAGEMAKER_ENDPOINT,
+            mock_role_arn,
+            mock_session,
+        )
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
+    def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
+    def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
+    def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
+    def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
+    def test_model_server_override_transformers_with_model(
+        self, mock_build_for_ts, mock_serve_settings
+    ):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt")
+        builder.build(sagemaker_session=mock_session)
+
+        mock_build_for_ts.assert_called_once()
+
     @patch("os.makedirs", Mock())
     @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
     @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
