49
49
mock_secret_key = "mock_secret_key"
50
50
mock_instance_type = "mock instance type"
51
51
52
- supported_model_server = {
52
+ supported_model_servers = {
53
53
ModelServer .TORCHSERVE ,
54
54
ModelServer .TRITON ,
55
55
ModelServer .DJL_SERVING ,
56
56
ModelServer .TENSORFLOW_SERVING ,
57
+ ModelServer .MMS ,
58
+ ModelServer .TGI ,
59
+ ModelServer .TEI ,
57
60
}
58
61
59
62
mock_session = MagicMock ()
@@ -77,7 +80,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
77
80
builder = ModelBuilder (inference_spec = "some value" , model = Mock (spec = object ))
78
81
self .assertRaisesRegex (
79
82
Exception ,
80
- "Cannot have both the Model and Inference spec in the builder " ,
83
+ "Can only set one of the following: model, inference_spec. " ,
81
84
builder .build ,
82
85
Mode .SAGEMAKER_ENDPOINT ,
83
86
mock_role_arn ,
@@ -90,7 +93,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
90
93
self .assertRaisesRegex (
91
94
Exception ,
92
95
"%s is not supported yet! Supported model servers: %s"
93
- % (builder .model_server , supported_model_server ),
96
+ % (builder .model_server , supported_model_servers ),
94
97
builder .build ,
95
98
Mode .SAGEMAKER_ENDPOINT ,
96
99
mock_role_arn ,
@@ -103,7 +106,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
103
106
self .assertRaisesRegex (
104
107
Exception ,
105
108
"Model_server must be set when non-first-party image_uri is set. "
106
- + "Supported model servers: %s" % supported_model_server ,
109
+ + "Supported model servers: %s" % supported_model_servers ,
107
110
builder .build ,
108
111
Mode .SAGEMAKER_ENDPOINT ,
109
112
mock_role_arn ,
@@ -124,6 +127,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
124
127
mock_session ,
125
128
)
126
129
130
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
131
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl" )
132
+ def test_model_server_override_djl_with_model (self , mock_build_for_djl , mock_serve_settings ):
133
+ mock_setting_object = mock_serve_settings .return_value
134
+ mock_setting_object .role_arn = mock_role_arn
135
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
136
+
137
+ builder = ModelBuilder (model_server = ModelServer .DJL_SERVING , model = "gpt_llm_burt" )
138
+ builder .build (sagemaker_session = mock_session )
139
+
140
+ mock_build_for_djl .assert_called_once ()
141
+
142
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
143
+ def test_model_server_override_djl_without_model_or_mlflow (self , mock_serve_settings ):
144
+ builder = ModelBuilder (
145
+ model_server = ModelServer .DJL_SERVING , model = None , inference_spec = None
146
+ )
147
+ self .assertRaisesRegex (
148
+ Exception ,
149
+ "Missing required parameter `model` or 'ml_flow' path" ,
150
+ builder .build ,
151
+ Mode .SAGEMAKER_ENDPOINT ,
152
+ mock_role_arn ,
153
+ mock_session ,
154
+ )
155
+
156
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
157
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve" )
158
+ def test_model_server_override_torchserve_with_model (
159
+ self , mock_build_for_ts , mock_serve_settings
160
+ ):
161
+ mock_setting_object = mock_serve_settings .return_value
162
+ mock_setting_object .role_arn = mock_role_arn
163
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
164
+
165
+ builder = ModelBuilder (model_server = ModelServer .TORCHSERVE , model = "gpt_llm_burt" )
166
+ builder .build (sagemaker_session = mock_session )
167
+
168
+ mock_build_for_ts .assert_called_once ()
169
+
170
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
171
+ def test_model_server_override_torchserve_without_model_or_mlflow (self , mock_serve_settings ):
172
+ builder = ModelBuilder (model_server = ModelServer .TORCHSERVE )
173
+ self .assertRaisesRegex (
174
+ Exception ,
175
+ "Missing required parameter `model` or 'ml_flow' path" ,
176
+ builder .build ,
177
+ Mode .SAGEMAKER_ENDPOINT ,
178
+ mock_role_arn ,
179
+ mock_session ,
180
+ )
181
+
182
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
183
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton" )
184
+ def test_model_server_override_triton_with_model (self , mock_build_for_ts , mock_serve_settings ):
185
+ mock_setting_object = mock_serve_settings .return_value
186
+ mock_setting_object .role_arn = mock_role_arn
187
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
188
+
189
+ builder = ModelBuilder (model_server = ModelServer .TRITON , model = "gpt_llm_burt" )
190
+ builder .build (sagemaker_session = mock_session )
191
+
192
+ mock_build_for_ts .assert_called_once ()
193
+
194
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
195
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving" )
196
+ def test_model_server_override_tensor_with_model (self , mock_build_for_ts , mock_serve_settings ):
197
+ mock_setting_object = mock_serve_settings .return_value
198
+ mock_setting_object .role_arn = mock_role_arn
199
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
200
+
201
+ builder = ModelBuilder (model_server = ModelServer .TENSORFLOW_SERVING , model = "gpt_llm_burt" )
202
+ builder .build (sagemaker_session = mock_session )
203
+
204
+ mock_build_for_ts .assert_called_once ()
205
+
206
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
207
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei" )
208
+ def test_model_server_override_tei_with_model (self , mock_build_for_ts , mock_serve_settings ):
209
+ mock_setting_object = mock_serve_settings .return_value
210
+ mock_setting_object .role_arn = mock_role_arn
211
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
212
+
213
+ builder = ModelBuilder (model_server = ModelServer .TEI , model = "gpt_llm_burt" )
214
+ builder .build (sagemaker_session = mock_session )
215
+
216
+ mock_build_for_ts .assert_called_once ()
217
+
218
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
219
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi" )
220
+ def test_model_server_override_tgi_with_model (self , mock_build_for_ts , mock_serve_settings ):
221
+ mock_setting_object = mock_serve_settings .return_value
222
+ mock_setting_object .role_arn = mock_role_arn
223
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
224
+
225
+ builder = ModelBuilder (model_server = ModelServer .TGI , model = "gpt_llm_burt" )
226
+ builder .build (sagemaker_session = mock_session )
227
+
228
+ mock_build_for_ts .assert_called_once ()
229
+
230
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
231
+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers" )
232
+ def test_model_server_override_transformers_with_model (
233
+ self , mock_build_for_ts , mock_serve_settings
234
+ ):
235
+ mock_setting_object = mock_serve_settings .return_value
236
+ mock_setting_object .role_arn = mock_role_arn
237
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
238
+
239
+ builder = ModelBuilder (model_server = ModelServer .MMS , model = "gpt_llm_burt" )
240
+ builder .build (sagemaker_session = mock_session )
241
+
242
+ mock_build_for_ts .assert_called_once ()
243
+
127
244
@patch ("os.makedirs" , Mock ())
128
245
@patch ("sagemaker.serve.builder.model_builder._detect_framework_and_version" )
129
246
@patch ("sagemaker.serve.builder.model_builder.prepare_for_torchserve" )
0 commit comments