Skip to content

Commit ba850c2

Browse files
committed
Fix test builds
1 parent 617dc82 commit ba850c2

File tree

6 files changed

+133
-13
lines changed

6 files changed

+133
-13
lines changed

src/sagemaker/serve/builder/model_builder.py

-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
ModelServer.TORCHSERVE,
6363
ModelServer.TRITON,
6464
ModelServer.DJL_SERVING,
65-
ModelServer.MMS,
6665
}
6766

6867

tests/integ/sagemaker/serve/test_serve_js_happy.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16+
from unittest.mock import patch, Mock
1617
from sagemaker.serve.builder.model_builder import ModelBuilder
1718
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1819
from tests.integ.sagemaker.serve.constants import (
@@ -26,13 +27,13 @@
2627

2728
logger = logging.getLogger(__name__)
2829

29-
3030
SAMPLE_PROMPT = {"inputs": "Hello, I'm a language model,", "parameters": {}}
3131
SAMPLE_RESPONSE = [
3232
{"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
3333
]
3434
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
3535
ROLE_NAME = "SageMakerRole"
36+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
3637

3738

3839
@pytest.fixture
@@ -46,15 +47,23 @@ def happy_model_builder(sagemaker_session):
4647
)
4748

4849

50+
@patch("sagemaker.huggingface.llm_utils.urllib")
51+
@patch("sagemaker.huggingface.llm_utils.json")
4952
@pytest.mark.skipif(
5053
PYTHON_VERSION_IS_NOT_310,
5154
reason="The goal of these test are to test the serving components of our feature",
5255
)
5356
@pytest.mark.slow_test
54-
def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
57+
def test_happy_tgi_sagemaker_endpoint(
58+
mock_urllib, mock_json, happy_model_builder, gpu_instance_type
59+
):
5560
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
5661
caught_ex = None
5762

63+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
64+
mock_hf_model_metadata_url = Mock()
65+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
66+
5867
model = happy_model_builder.build()
5968

6069
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):

tests/integ/sagemaker/serve/test_serve_pt_happy.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io
2020
import numpy as np
2121

22+
from unittest.mock import patch, Mock
2223
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
2324
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
2425
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -37,6 +38,7 @@
3738
logger = logging.getLogger(__name__)
3839

3940
ROLE_NAME = "SageMakerRole"
41+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4042

4143

4244
@pytest.fixture
@@ -181,6 +183,8 @@ def model_builder(request):
181183
# ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
182184

183185

186+
@patch("sagemaker.huggingface.llm_utils.urllib")
187+
@patch("sagemaker.huggingface.llm_utils.json")
184188
@pytest.mark.skipif(
185189
PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
186190
reason="The goal of these test are to test the serving components of our feature",
@@ -190,9 +194,17 @@ def model_builder(request):
190194
)
191195
@pytest.mark.slow_test
192196
def test_happy_pytorch_sagemaker_endpoint(
193-
sagemaker_session, model_builder, cpu_instance_type, test_image
197+
mock_urllib,
198+
mock_json,
199+
sagemaker_session,
200+
model_builder,
201+
cpu_instance_type,
202+
test_image,
194203
):
195204
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
205+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
206+
mock_hf_model_metadata_url = Mock()
207+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
196208
caught_ex = None
197209

198210
iam_client = sagemaker_session.boto_session.client("iam")

tests/integ/sagemaker/serve/test_serve_transformers.py

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
1716
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1817
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
1918

tests/unit/sagemaker/serve/builder/test_model_builder.py

+108-5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
mock_s3_model_data_url = "sample s3 data url"
4343
mock_secret_key = "mock_secret_key"
4444
mock_instance_type = "mock instance type"
45+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4546

4647
supported_model_server = {
4748
ModelServer.TORCHSERVE,
@@ -54,7 +55,15 @@
5455

5556
class TestModelBuilder(unittest.TestCase):
5657
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
57-
def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
58+
@patch("sagemaker.huggingface.llm_utils.urllib")
59+
@patch("sagemaker.huggingface.llm_utils.json")
60+
def test_validation_in_progress_mode_not_supported(
61+
self, mock_serveSettings, mock_urllib, mock_json
62+
):
63+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
64+
mock_hf_model_metadata_url = Mock()
65+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
66+
5867
builder = ModelBuilder()
5968
self.assertRaisesRegex(
6069
Exception,
@@ -66,7 +75,15 @@ def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
6675
)
6776

6877
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
69-
def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSettings):
78+
@patch("sagemaker.huggingface.llm_utils.urllib")
79+
@patch("sagemaker.huggingface.llm_utils.json")
80+
def test_validation_cannot_set_both_model_and_inference_spec(
81+
self, mock_serveSettings, mock_urllib, mock_json
82+
):
83+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
84+
mock_hf_model_metadata_url = Mock()
85+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
86+
7087
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
7188
self.assertRaisesRegex(
7289
Exception,
@@ -78,7 +95,15 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
7895
)
7996

8097
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
81-
def test_validation_unsupported_model_server_type(self, mock_serveSettings):
98+
@patch("sagemaker.huggingface.llm_utils.urllib")
99+
@patch("sagemaker.huggingface.llm_utils.json")
100+
def test_validation_unsupported_model_server_type(
101+
self, mock_serveSettings, mock_urllib, mock_json
102+
):
103+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
104+
mock_hf_model_metadata_url = Mock()
105+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
106+
82107
builder = ModelBuilder(model_server="invalid_model_server")
83108
self.assertRaisesRegex(
84109
Exception,
@@ -91,7 +116,15 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
91116
)
92117

93118
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
94-
def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings):
119+
@patch("sagemaker.huggingface.llm_utils.urllib")
120+
@patch("sagemaker.huggingface.llm_utils.json")
121+
def test_validation_model_server_not_set_with_image_uri(
122+
self, mock_serveSettings, mock_urllib, mock_json
123+
):
124+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
125+
mock_hf_model_metadata_url = Mock()
126+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
127+
95128
builder = ModelBuilder(image_uri="image_uri")
96129
self.assertRaisesRegex(
97130
Exception,
@@ -104,9 +137,15 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104137
)
105138

106139
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
140+
@patch("sagemaker.huggingface.llm_utils.urllib")
141+
@patch("sagemaker.huggingface.llm_utils.json")
107142
def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set(
108-
self, mock_serveSettings
143+
self, mock_serveSettings, mock_urllib, mock_json
109144
):
145+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
146+
mock_hf_model_metadata_url = Mock()
147+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
148+
110149
builder = ModelBuilder(inference_spec=None, model=None)
111150
self.assertRaisesRegex(
112151
Exception,
@@ -126,8 +165,12 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
126165
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
127166
@patch("sagemaker.serve.builder.model_builder.Model")
128167
@patch("os.path.exists")
168+
@patch("sagemaker.huggingface.llm_utils.urllib")
169+
@patch("sagemaker.huggingface.llm_utils.json")
129170
def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
130171
self,
172+
mock_urllib,
173+
mock_json,
131174
mock_path_exists,
132175
mock_sdk_model,
133176
mock_sageMakerEndpointMode,
@@ -146,6 +189,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
146189
else None
147190
)
148191

192+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
193+
mock_hf_model_metadata_url = Mock()
194+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
195+
149196
mock_detect_fw_version.return_value = framework, version
150197

151198
mock_prepare_for_torchserve.side_effect = (
@@ -226,8 +273,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
226273
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
227274
@patch("sagemaker.serve.builder.model_builder.Model")
228275
@patch("os.path.exists")
276+
@patch("sagemaker.huggingface.llm_utils.urllib")
277+
@patch("sagemaker.huggingface.llm_utils.json")
229278
def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
230279
self,
280+
mock_urllib,
281+
mock_json,
231282
mock_path_exists,
232283
mock_sdk_model,
233284
mock_sageMakerEndpointMode,
@@ -246,6 +297,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
246297
else None
247298
)
248299

300+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
301+
mock_hf_model_metadata_url = Mock()
302+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
303+
249304
mock_detect_fw_version.return_value = framework, version
250305

251306
mock_prepare_for_torchserve.side_effect = (
@@ -326,8 +381,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
326381
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
327382
@patch("sagemaker.serve.builder.model_builder.Model")
328383
@patch("os.path.exists")
384+
@patch("sagemaker.huggingface.llm_utils.urllib")
385+
@patch("sagemaker.huggingface.llm_utils.json")
329386
def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
330387
self,
388+
mock_urllib,
389+
mock_json,
331390
mock_path_exists,
332391
mock_sdk_model,
333392
mock_sageMakerEndpointMode,
@@ -343,6 +402,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
343402
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
344403
)
345404

405+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
406+
mock_hf_model_metadata_url = Mock()
407+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
408+
346409
mock_detect_fw_version.return_value = framework, version
347410

348411
mock_detect_container.side_effect = (
@@ -427,8 +490,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
427490
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
428491
@patch("sagemaker.serve.builder.model_builder.Model")
429492
@patch("os.path.exists")
493+
@patch("sagemaker.huggingface.llm_utils.urllib")
494+
@patch("sagemaker.huggingface.llm_utils.json")
430495
def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
431496
self,
497+
mock_urllib,
498+
mock_json,
432499
mock_path_exists,
433500
mock_sdk_model,
434501
mock_sageMakerEndpointMode,
@@ -447,6 +514,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
447514
else None
448515
)
449516

517+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
518+
mock_hf_model_metadata_url = Mock()
519+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
520+
450521
mock_detect_fw_version.return_value = framework, version
451522

452523
mock_prepare_for_torchserve.side_effect = (
@@ -530,8 +601,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
530601
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
531602
@patch("sagemaker.serve.builder.model_builder.Model")
532603
@patch("os.path.exists")
604+
@patch("sagemaker.huggingface.llm_utils.urllib")
605+
@patch("sagemaker.huggingface.llm_utils.json")
533606
def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
534607
self,
608+
mock_urllib,
609+
mock_json,
535610
mock_path_exists,
536611
mock_sdk_model,
537612
mock_sageMakerEndpointMode,
@@ -551,6 +626,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
551626
else None
552627
)
553628

629+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
630+
mock_hf_model_metadata_url = Mock()
631+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
632+
554633
mock_detect_fw_version.return_value = "xgboost", version
555634

556635
mock_prepare_for_torchserve.side_effect = (
@@ -635,8 +714,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
635714
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
636715
@patch("sagemaker.serve.builder.model_builder.Model")
637716
@patch("os.path.exists")
717+
@patch("sagemaker.huggingface.llm_utils.urllib")
718+
@patch("sagemaker.huggingface.llm_utils.json")
638719
def test_build_happy_path_with_local_container_mode(
639720
self,
721+
mock_urllib,
722+
mock_json,
640723
mock_path_exists,
641724
mock_sdk_model,
642725
mock_localContainerMode,
@@ -651,6 +734,10 @@ def test_build_happy_path_with_local_container_mode(
651734
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
652735
)
653736

737+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
738+
mock_hf_model_metadata_url = Mock()
739+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
740+
654741
mock_detect_container.side_effect = (
655742
lambda model, region, instance_type: mock_image_uri
656743
if model == mock_native_model
@@ -729,8 +816,12 @@ def test_build_happy_path_with_local_container_mode(
729816
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
730817
@patch("sagemaker.serve.builder.model_builder.Model")
731818
@patch("os.path.exists")
819+
@patch("sagemaker.huggingface.llm_utils.urllib")
820+
@patch("sagemaker.huggingface.llm_utils.json")
732821
def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode(
733822
self,
823+
mock_urllib,
824+
mock_json,
734825
mock_path_exists,
735826
mock_sdk_model,
736827
mock_localContainerMode,
@@ -747,6 +838,10 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
747838
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
748839
)
749840

841+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
842+
mock_hf_model_metadata_url = Mock()
843+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
844+
750845
mock_detect_fw_version.return_value = framework, version
751846

752847
mock_detect_container.side_effect = (
@@ -870,8 +965,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
870965
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
871966
@patch("sagemaker.serve.builder.model_builder.Model")
872967
@patch("os.path.exists")
968+
@patch("sagemaker.huggingface.llm_utils.urllib")
969+
@patch("sagemaker.huggingface.llm_utils.json")
873970
def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_container(
874971
self,
972+
mock_urllib,
973+
mock_json,
875974
mock_path_exists,
876975
mock_sdk_model,
877976
mock_localContainerMode,
@@ -885,6 +984,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
885984
# setup mocks
886985
mock_detect_fw_version.return_value = framework, version
887986

987+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
988+
mock_hf_model_metadata_url = Mock()
989+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
990+
888991
mock_detect_container.side_effect = (
889992
lambda model, region, instance_type: mock_image_uri
890993
if model == mock_fw_model

0 commit comments

Comments
 (0)