50
50
mock_secret_key = "mock_secret_key"
mock_instance_type = "mock instance type"

# Interpolated into the expected "Supported model servers: %s" error messages
# that the validation tests in this file assert against.
supported_model_servers = {
    ModelServer.TORCHSERVE,
    ModelServer.TRITON,
    ModelServer.DJL_SERVING,
    ModelServer.TENSORFLOW_SERVING,
    ModelServer.MMS,
    ModelServer.TGI,
    ModelServer.TEI,
}

mock_session = MagicMock()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
78
81
builder = ModelBuilder (inference_spec = "some value" , model = Mock (spec = object ))
79
82
self .assertRaisesRegex (
80
83
Exception ,
81
- "Cannot have both the Model and Inference spec in the builder " ,
84
+ "Can only set one of the following: model, inference_spec. " ,
82
85
builder .build ,
83
86
Mode .SAGEMAKER_ENDPOINT ,
84
87
mock_role_arn ,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
91
94
self .assertRaisesRegex (
92
95
Exception ,
93
96
"%s is not supported yet! Supported model servers: %s"
94
- % (builder .model_server , supported_model_server ),
97
+ % (builder .model_server , supported_model_servers ),
95
98
builder .build ,
96
99
Mode .SAGEMAKER_ENDPOINT ,
97
100
mock_role_arn ,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104
107
self .assertRaisesRegex (
105
108
Exception ,
106
109
"Model_server must be set when non-first-party image_uri is set. "
107
- + "Supported model servers: %s" % supported_model_server ,
110
+ + "Supported model servers: %s" % supported_model_servers ,
108
111
builder .build ,
109
112
Mode .SAGEMAKER_ENDPOINT ,
110
113
mock_role_arn ,
@@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
125
128
mock_session ,
126
129
)
127
130
131
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
    """Building with DJL_SERVING and a model calls `_build_for_djl` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.DJL_SERVING,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    mock_build_for_djl.assert_called_once()
142
+
143
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
    """Building with DJL_SERVING but `model=None` and `inference_spec=None` must raise."""
    djl_builder = ModelBuilder(
        model_server=ModelServer.DJL_SERVING,
        model=None,
        inference_spec=None,
    )
    expected_message = "Missing required parameter `model` or 'ml_flow' path"

    # Only the build() call is expected to raise; construction stays outside.
    with self.assertRaisesRegex(Exception, expected_message):
        djl_builder.build(Mode.SAGEMAKER_ENDPOINT, mock_role_arn, mock_session)
156
+
157
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
def test_model_server_override_torchserve_with_model(self, mock_build_for_ts, mock_serve_settings):
    """Building with TORCHSERVE and a model calls `_build_for_torchserve` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.TORCHSERVE,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    mock_build_for_ts.assert_called_once()
170
+
171
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
    """Building with only `model_server=TORCHSERVE` set must raise the missing-model error."""
    ts_builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
    expected_message = "Missing required parameter `model` or 'ml_flow' path"

    # Only the build() call is expected to raise; construction stays outside.
    with self.assertRaisesRegex(Exception, expected_message):
        ts_builder.build(Mode.SAGEMAKER_ENDPOINT, mock_role_arn, mock_session)
182
+
183
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
    """Building with TRITON and a model calls `_build_for_triton` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.TRITON,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    # NOTE: despite its name, this parameter is the `_build_for_triton` mock.
    mock_build_for_ts.assert_called_once()
194
+
195
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
    """Building with TENSORFLOW_SERVING and a model calls `_build_for_tensorflow_serving` once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.TENSORFLOW_SERVING,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    # NOTE: despite its name, this parameter is the `_build_for_tensorflow_serving` mock.
    mock_build_for_ts.assert_called_once()
206
+
207
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
    """Building with TEI and a model calls `_build_for_tei` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.TEI,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    # NOTE: despite its name, this parameter is the `_build_for_tei` mock.
    mock_build_for_ts.assert_called_once()
218
+
219
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
    """Building with TGI and a model calls `_build_for_tgi` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.TGI,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    # NOTE: despite its name, this parameter is the `_build_for_tgi` mock.
    mock_build_for_ts.assert_called_once()
230
+
231
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
def test_model_server_override_transformers_with_model(
    self, mock_build_for_ts, mock_serve_settings
):
    """Building with MMS and a model calls `_build_for_transformers` exactly once."""
    # The patched _ServeSettings() instance supplies the role and S3 location.
    mock_serve_settings.return_value.configure_mock(
        role_arn=mock_role_arn,
        s3_model_data_url=mock_s3_model_data_url,
    )

    ModelBuilder(
        model_server=ModelServer.MMS,
        model="gpt_llm_burt",
    ).build(sagemaker_session=mock_session)

    # NOTE: despite its name, this parameter is the `_build_for_transformers` mock.
    mock_build_for_ts.assert_called_once()
244
+
128
245
@patch ("os.makedirs" , Mock ())
129
246
@patch ("sagemaker.serve.builder.model_builder._detect_framework_and_version" )
130
247
@patch ("sagemaker.serve.builder.model_builder.prepare_for_torchserve" )
0 commit comments