Skip to content

Commit 924983d

Browse files
committed
Fix test builds
1 parent 617dc82 commit 924983d

File tree

6 files changed

+135
-15
lines changed

6 files changed

+135
-15
lines changed

src/sagemaker/serve/builder/model_builder.py

-1
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,6 @@
6262
ModelServer.TORCHSERVE,
6363
ModelServer.TRITON,
6464
ModelServer.DJL_SERVING,
65-
ModelServer.MMS,
6665
}
6766

6867

tests/integ/sagemaker/serve/test_serve_js_happy.py

+9-2
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16+
from unittest.mock import patch, Mock
1617
from sagemaker.serve.builder.model_builder import ModelBuilder
1718
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1819
from tests.integ.sagemaker.serve.constants import (
@@ -26,13 +27,13 @@
2627

2728
logger = logging.getLogger(__name__)
2829

29-
3030
SAMPLE_PROMPT = {"inputs": "Hello, I'm a language model,", "parameters": {}}
3131
SAMPLE_RESPONSE = [
3232
{"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
3333
]
3434
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
3535
ROLE_NAME = "SageMakerRole"
36+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
3637

3738

3839
@pytest.fixture
@@ -46,15 +47,21 @@ def happy_model_builder(sagemaker_session):
4647
)
4748

4849

50+
@patch("sagemaker.huggingface.llm_utils.urllib")
51+
@patch("sagemaker.huggingface.llm_utils.json")
4952
@pytest.mark.skipif(
5053
PYTHON_VERSION_IS_NOT_310,
5154
reason="The goal of these test are to test the serving components of our feature",
5255
)
5356
@pytest.mark.slow_test
54-
def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
57+
def test_happy_tgi_sagemaker_endpoint(mock_urllib, mock_json, happy_model_builder, gpu_instance_type):
5558
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
5659
caught_ex = None
5760

61+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
62+
mock_hf_model_metadata_url = Mock()
63+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
64+
5865
model = happy_model_builder.build()
5966

6067
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):

tests/integ/sagemaker/serve/test_serve_pt_happy.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io
2020
import numpy as np
2121

22+
from unittest.mock import patch, Mock
2223
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
2324
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
2425
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -37,6 +38,7 @@
3738
logger = logging.getLogger(__name__)
3839

3940
ROLE_NAME = "SageMakerRole"
41+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4042

4143

4244
@pytest.fixture
@@ -181,6 +183,8 @@ def model_builder(request):
181183
# ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
182184

183185

186+
@patch("sagemaker.huggingface.llm_utils.urllib")
187+
@patch("sagemaker.huggingface.llm_utils.json")
184188
@pytest.mark.skipif(
185189
PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
186190
reason="The goal of these test are to test the serving components of our feature",
@@ -190,9 +194,12 @@ def model_builder(request):
190194
)
191195
@pytest.mark.slow_test
192196
def test_happy_pytorch_sagemaker_endpoint(
193-
sagemaker_session, model_builder, cpu_instance_type, test_image
197+
mock_urllib, mock_json, sagemaker_session, model_builder, cpu_instance_type, test_image,
194198
):
195199
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
200+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
201+
mock_hf_model_metadata_url = Mock()
202+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
196203
caught_ex = None
197204

198205
iam_client = sagemaker_session.boto_session.client("iam")

tests/integ/sagemaker/serve/test_serve_transformers.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
16+
from unittest.mock import patch, Mock
1717
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1818
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
1919

@@ -29,6 +29,8 @@
2929

3030
logger = logging.getLogger(__name__)
3131

32+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
33+
3234
sample_input = {
3335
"inputs": "The man worked as a [MASK].",
3436
}
@@ -85,16 +87,20 @@ def model_builder_model_schema_builder():
8587
def model_builder(request):
8688
return request.getfixturevalue(request.param)
8789

88-
90+
@patch("sagemaker.huggingface.llm_utils.urllib")
91+
@patch("sagemaker.huggingface.llm_utils.json")
8992
@pytest.mark.skipif(
9093
PYTHON_VERSION_IS_NOT_310,
9194
reason="Testing feature",
9295
)
9396
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
9497
def test_pytorch_transformers_sagemaker_endpoint(
95-
sagemaker_session, model_builder, gpu_instance_type, input
98+
mock_urllib, mock_json, sagemaker_session, model_builder, gpu_instance_type, input
9699
):
97100
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
101+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
102+
mock_hf_model_metadata_url = Mock()
103+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
98104
caught_ex = None
99105

100106
iam_client = sagemaker_session.boto_session.client("iam")

0 commit comments

Comments
 (0)