Skip to content

Commit 4b3f617

Browse files
authored
Merge branch 'master' into master
2 parents 38eed2f + 32dd631 commit 4b3f617

File tree

10 files changed

+88
-3
lines changed

10 files changed

+88
-3
lines changed

src/sagemaker/serve/builder/djl_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self):
8383
self.mode = None
8484
self.model_server = None
8585
self.image_uri = None
86+
self._is_custom_image_uri = False
8687
self.image_config = None
8788
self.vpc_config = None
8889
self._original_deploy = None

src/sagemaker/serve/builder/jumpstart_builder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __init__(self):
5757
self.mode = None
5858
self.model_server = None
5959
self.image_uri = None
60+
self._is_custom_image_uri = False
61+
self.vpc_config = None
6062
self._original_deploy = None
6163
self.secret_key = None
6264
self.js_model_config = None
@@ -94,7 +96,7 @@ def _is_jumpstart_model_id(self) -> bool:
9496

9597
def _create_pre_trained_js_model(self) -> Type[Model]:
9698
"""Placeholder docstring"""
97-
pysdk_model = JumpStartModel(self.model)
99+
pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config)
98100
pysdk_model.sagemaker_session = self.sagemaker_session
99101

100102
self._original_deploy = pysdk_model.deploy

src/sagemaker/serve/builder/model_builder.py

+2
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ def build(
607607

608608
self.serve_settings = self._get_serve_setting()
609609

610+
self._is_custom_image_uri = self.image_uri is None
611+
610612
if isinstance(self.model, str):
611613
if self._is_jumpstart_model_id():
612614
return self._build_for_jumpstart()

src/sagemaker/serve/builder/tgi_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self):
7676
self.mode = None
7777
self.model_server = None
7878
self.image_uri = None
79+
self._is_custom_image_uri = False
7980
self.image_config = None
8081
self.vpc_config = None
8182
self._original_deploy = None

src/sagemaker/serve/builder/transformers_builder.py

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self):
5656
self.mode = None
5757
self.model_server = None
5858
self.image_uri = None
59+
self._is_custom_image_uri = False
60+
self.vpc_config = None
5961
self._original_deploy = None
6062
self.hf_model_config = None
6163
self._default_data_type = None
@@ -111,6 +113,7 @@ def _create_transformers_model(self) -> Type[Model]:
111113
py_version=self.py_version,
112114
transformers_version=base_hf_version,
113115
pytorch_version=self.pytorch_version,
116+
vpc_config=self.vpc_config,
114117
)
115118
elif "keras" in hf_model_md.get("tags") or "tensorflow" in hf_model_md.get("tags"):
116119
self.tensorflow_version = self._get_supported_version(
@@ -126,6 +129,7 @@ def _create_transformers_model(self) -> Type[Model]:
126129
py_version=self.py_version,
127130
transformers_version=base_hf_version,
128131
tensorflow_version=self.tensorflow_version,
132+
vpc_config=self.vpc_config,
129133
)
130134

131135
if self.mode == Mode.LOCAL_CONTAINER:

src/sagemaker/serve/utils/telemetry_logger.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sagemaker import Session, exceptions
2121
from sagemaker.serve.mode.function_pointers import Mode
2222
from sagemaker.serve.utils.exceptions import ModelBuilderException
23-
from sagemaker.serve.utils.types import ModelServer
23+
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
24+
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
2425
from sagemaker.user_agent import SDK_VERSION
2526

2627
logger = logging.getLogger(__name__)
@@ -62,11 +63,13 @@ def wrapper(self, *args, **kwargs):
6263
caught_ex = None
6364

6465
image_uri_tail = self.image_uri.split("/")[1]
66+
image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri)
6567
extra = (
6668
f"{func_name}"
6769
f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}"
6870
f"&x-imageTag={image_uri_tail}"
6971
f"&x-sdkVersion={SDK_VERSION}"
72+
f"&x-defaultImageUsage={image_uri_option}"
7073
)
7174

7275
if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI:
@@ -201,3 +204,22 @@ def _get_region_or_default(session):
201204
return session.boto_session.region_name
202205
except Exception: # pylint: disable=W0703
203206
return "us-west-2"
207+
208+
209+
def _get_image_uri_option(image_uri: str, is_custom_image: bool) -> int:
210+
"""Detect whether default values are used for ModelBuilder
211+
212+
Args:
213+
image_uri (str): Image uri used by ModelBuilder.
214+
is_custom_image: (bool): Boolean indicating whether customer provides with custom image.
215+
Returns:
216+
bool: Integer code of image option types.
217+
"""
218+
219+
if not is_custom_image:
220+
return ImageUriOption.DEFAULT_IMAGE.value
221+
222+
if is_1p_image_uri(image_uri):
223+
return ImageUriOption.CUSTOM_1P_IMAGE.value
224+
225+
return ImageUriOption.CUSTOM_IMAGE.value

src/sagemaker/serve/utils/types.py

+12
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,15 @@ def __str__(self) -> str:
4343
INFERENTIA_1 = 3
4444
INFERENTIA_2 = 4
4545
GRAVITON = 5
46+
47+
48+
class ImageUriOption(Enum):
49+
"""Enum type for image uri options"""
50+
51+
def __str__(self) -> str:
52+
"""Convert enum to string"""
53+
return str(self.name)
54+
55+
CUSTOM_IMAGE = 1
56+
CUSTOM_1P_IMAGE = 2
57+
DEFAULT_IMAGE = 3

src/sagemaker/workflow/entities.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __str__(self):
8989
"""Override built-in String function for PipelineVariable"""
9090
raise TypeError(
9191
"Pipeline variables do not support __str__ operation. "
92-
"Please use `.to_string()` to convert it to string type in execution time"
92+
"Please use `.to_string()` to convert it to string type in execution time "
9393
"or use `.expr` to translate it to Json for display purpose in Python SDK."
9494
)
9595

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import unittest
1717
from sagemaker.serve.builder.model_builder import ModelBuilder
1818
from sagemaker.serve.mode.function_pointers import Mode
19+
from tests.unit.sagemaker.serve.constants import MOCK_VPC_CONFIG
1920

2021
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
2122

@@ -74,6 +75,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
7475
model=mock_model_id,
7576
schema_builder=mock_schema_builder,
7677
mode=Mode.LOCAL_CONTAINER,
78+
vpc_config=MOCK_VPC_CONFIG,
7779
)
7880

7981
builder._prepare_for_mode = MagicMock()
@@ -85,6 +87,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
8587
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
8688
predictor = model.deploy(model_data_download_timeout=1800)
8789

90+
assert model.vpc_config == MOCK_VPC_CONFIG
8891
assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
8992
assert isinstance(predictor, TransformersLocalModePredictor)
9093

tests/unit/sagemaker/serve/utils/test_telemetry_logger.py

+38
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_construct_url,
2121
)
2222
from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
23+
from sagemaker.serve.utils.types import ImageUriOption
2324
from sagemaker.user_agent import SDK_VERSION
2425

2526
MOCK_SESSION = Mock()
@@ -71,6 +72,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry):
7172
mock_model_builder = ModelBuilderMock()
7273
mock_model_builder.serve_settings.telemetry_opt_out = False
7374
mock_model_builder.image_uri = MOCK_DJL_CONTAINER
75+
mock_model_builder._is_custom_image_uri = False
7476
mock_model_builder.model = MOCK_HUGGINGFACE_ID
7577
mock_model_builder.mode = Mode.LOCAL_CONTAINER
7678
mock_model_builder.model_server = ModelServer.DJL_SERVING
@@ -85,6 +87,37 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry):
8587
"&x-modelServer=4"
8688
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
8789
f"&x-sdkVersion={SDK_VERSION}"
90+
f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
91+
f"&x-modelName={MOCK_HUGGINGFACE_ID}"
92+
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
93+
f"&x-latency={latency}"
94+
)
95+
96+
mock_send_telemetry.assert_called_once_with(
97+
"1", 2, MOCK_SESSION, None, None, expected_extra_str
98+
)
99+
100+
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
101+
def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_send_telemetry):
102+
mock_model_builder = ModelBuilderMock()
103+
mock_model_builder.serve_settings.telemetry_opt_out = False
104+
mock_model_builder.image_uri = MOCK_DJL_CONTAINER
105+
mock_model_builder._is_custom_image_uri = True
106+
mock_model_builder.model = MOCK_HUGGINGFACE_ID
107+
mock_model_builder.mode = Mode.LOCAL_CONTAINER
108+
mock_model_builder.model_server = ModelServer.DJL_SERVING
109+
mock_model_builder.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
110+
111+
mock_model_builder.mock_deploy()
112+
113+
args = mock_send_telemetry.call_args.args
114+
latency = str(args[5]).split("latency=")[1]
115+
expected_extra_str = (
116+
f"{MOCK_FUNC_NAME}"
117+
"&x-modelServer=4"
118+
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
119+
f"&x-sdkVersion={SDK_VERSION}"
120+
f"&x-defaultImageUsage={ImageUriOption.CUSTOM_1P_IMAGE.value}"
88121
f"&x-modelName={MOCK_HUGGINGFACE_ID}"
89122
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
90123
f"&x-latency={latency}"
@@ -99,6 +132,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry):
99132
mock_model_builder = ModelBuilderMock()
100133
mock_model_builder.serve_settings.telemetry_opt_out = False
101134
mock_model_builder.image_uri = MOCK_TGI_CONTAINER
135+
mock_model_builder._is_custom_image_uri = False
102136
mock_model_builder.model = MOCK_HUGGINGFACE_ID
103137
mock_model_builder.mode = Mode.LOCAL_CONTAINER
104138
mock_model_builder.model_server = ModelServer.TGI
@@ -113,6 +147,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry):
113147
"&x-modelServer=6"
114148
"&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
115149
f"&x-sdkVersion={SDK_VERSION}"
150+
f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
116151
f"&x-modelName={MOCK_HUGGINGFACE_ID}"
117152
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
118153
f"&x-latency={latency}"
@@ -127,6 +162,7 @@ def test_capture_telemetry_decorator_no_call_when_disabled(self, mock_send_telem
127162
mock_model_builder = ModelBuilderMock()
128163
mock_model_builder.serve_settings.telemetry_opt_out = True
129164
mock_model_builder.image_uri = MOCK_DJL_CONTAINER
165+
mock_model_builder._is_custom_image_uri = False
130166
mock_model_builder.model = MOCK_HUGGINGFACE_ID
131167
mock_model_builder.model_server = ModelServer.DJL_SERVING
132168

@@ -139,6 +175,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te
139175
mock_model_builder = ModelBuilderMock()
140176
mock_model_builder.serve_settings.telemetry_opt_out = False
141177
mock_model_builder.image_uri = MOCK_DJL_CONTAINER
178+
mock_model_builder._is_custom_image_uri = False
142179
mock_model_builder.model = MOCK_HUGGINGFACE_ID
143180
mock_model_builder.mode = Mode.LOCAL_CONTAINER
144181
mock_model_builder.model_server = ModelServer.DJL_SERVING
@@ -158,6 +195,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te
158195
"&x-modelServer=4"
159196
"&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
160197
f"&x-sdkVersion={SDK_VERSION}"
198+
f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
161199
f"&x-modelName={MOCK_HUGGINGFACE_ID}"
162200
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
163201
f"&x-latency={latency}"

0 commit comments

Comments
 (0)