Skip to content

Commit ce87772

Browse files
jiapinwakrishna1995
authored andcommitted
Initial commit
1 parent 655589b commit ce87772

File tree

7 files changed

+65
-19
lines changed

7 files changed

+65
-19
lines changed

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self):
8383
self.mode = None
8484
self.model_server = None
8585
self.image_uri = None
86+
self._is_custom_image_uri = False
8687
self.image_config = None
8788
self.vpc_config = None
8889
self._original_deploy = None

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __init__(self):
5757
self.mode = None
5858
self.model_server = None
5959
self.image_uri = None
60+
self._is_custom_image_uri = False
61+
self.vpc_config = None
6062
self._original_deploy = None
6163
self.secret_key = None
6264
self.js_model_config = None
@@ -94,7 +96,7 @@ def _is_jumpstart_model_id(self) -> bool:
9496

9597
def _create_pre_trained_js_model(self) -> Type[Model]:
9698
"""Placeholder docstring"""
97-
pysdk_model = JumpStartModel(self.model)
99+
pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config)
98100
pysdk_model.sagemaker_session = self.sagemaker_session
99101

100102
self._original_deploy = pysdk_model.deploy

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def _auto_detect_container(self):
294294
"Skipping auto detection as the image uri is provided %s",
295295
self.image_uri,
296296
)
297+
self._is_custom_image_uri = True
297298
return
298299

299300
if self.model:
@@ -605,6 +606,8 @@ def build(
605606

606607
self.serve_settings = self._get_serve_setting()
607608

609+
self._is_custom_image_uri = self.image_uri is None
610+
608611
if isinstance(self.model, str):
609612
if self._is_jumpstart_model_id():
610613
return self._build_for_jumpstart()

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self):
7676
self.mode = None
7777
self.model_server = None
7878
self.image_uri = None
79+
self._is_custom_image_uri = False
7980
self.image_config = None
8081
self.vpc_config = None
8182
self._original_deploy = None

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self):
5656
self.mode = None
5757
self.model_server = None
5858
self.image_uri = None
59+
self._is_custom_image_uri = False
60+
self.vpc_config = None
5961
self._original_deploy = None
6062
self.hf_model_config = None
6163
self._default_data_type = None
@@ -111,6 +113,7 @@ def _create_transformers_model(self) -> Type[Model]:
111113
py_version=self.py_version,
112114
transformers_version=base_hf_version,
113115
pytorch_version=self.pytorch_version,
116+
vpc_config=self.vpc_config,
114117
)
115118
elif "keras" in hf_model_md.get("tags") or "tensorflow" in hf_model_md.get("tags"):
116119
self.tensorflow_version = self._get_supported_version(
@@ -126,6 +129,7 @@ def _create_transformers_model(self) -> Type[Model]:
126129
py_version=self.py_version,
127130
transformers_version=base_hf_version,
128131
tensorflow_version=self.tensorflow_version,
132+
vpc_config=self.vpc_config,
129133
)
130134

131135
if self.mode == Mode.LOCAL_CONTAINER:

src/sagemaker/serve/utils/telemetry_logger.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Optional
1617
from time import perf_counter
1718

1819
import requests
1920

2021
from sagemaker import Session, exceptions
2122
from sagemaker.serve.mode.function_pointers import Mode
2223
from sagemaker.serve.utils.exceptions import ModelBuilderException
23-
from sagemaker.serve.utils.types import ModelServer
24+
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
25+
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
2426
from sagemaker.user_agent import SDK_VERSION
2527

2628
logger = logging.getLogger(__name__)
@@ -75,6 +77,8 @@ def wrapper(self, *args, **kwargs):
7577
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
7678
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"
7779

80+
extra += f"&x-defaultImageUsage={_get_image_uri_option(self.image_uri, self._is_custom_image_uri)}"
81+
7882
start_timer = perf_counter()
7983
try:
8084
response = func(self, *args, **kwargs)
@@ -91,10 +95,10 @@ def wrapper(self, *args, **kwargs):
9195
extra,
9296
)
9397
except (
94-
ModelBuilderException,
95-
exceptions.CapacityError,
96-
exceptions.UnexpectedStatusException,
97-
exceptions.AsyncInferenceError,
98+
ModelBuilderException,
99+
exceptions.CapacityError,
100+
exceptions.UnexpectedStatusException,
101+
exceptions.AsyncInferenceError,
98102
) as e:
99103
stop_timer = perf_counter()
100104
elapsed = stop_timer - start_timer
@@ -122,12 +126,12 @@ def wrapper(self, *args, **kwargs):
122126

123127

124128
def _send_telemetry(
125-
status: str,
126-
mode: int,
127-
session: Session,
128-
failure_reason: str = None,
129-
failure_type: str = None,
130-
extra_info: str = None,
129+
status: str,
130+
mode: int,
131+
session: Session,
132+
failure_reason: str = None,
133+
failure_type: str = None,
134+
extra_info: str = None,
131135
) -> None:
132136
"""Make GET request to an empty object in S3 bucket"""
133137
try:
@@ -149,13 +153,13 @@ def _send_telemetry(
149153

150154

151155
def _construct_url(
152-
accountId: str,
153-
mode: str,
154-
status: str,
155-
failure_reason: str,
156-
failure_type: str,
157-
extra_info: str,
158-
region: str,
156+
accountId: str,
157+
mode: str,
158+
status: str,
159+
failure_reason: str,
160+
failure_type: str,
161+
extra_info: str,
162+
region: str,
159163
) -> str:
160164
"""Placeholder docstring"""
161165

@@ -201,3 +205,22 @@ def _get_region_or_default(session):
201205
return session.boto_session.region_name
202206
except Exception: # pylint: disable=W0703
203207
return "us-west-2"
208+
209+
210+
def _get_image_uri_option(image_uri: str, is_custom_image: bool) -> int:
211+
"""Detect whether default values are used for ModelBuilder
212+
213+
Args:
214+
image_uri (str): Image uri used by ModelBuilder.
215+
is_custom_image: (bool): Boolean indicating whether customer provides with custom image.
216+
Returns:
217+
bool: Integer code of image option types.
218+
"""
219+
220+
if not is_custom_image:
221+
return ImageUriOption.DEFAULT_IMAGE.value
222+
223+
if is_1p_image_uri(image_uri):
224+
return ImageUriOption.CUSTOM_1P_IMAGE.value
225+
226+
return ImageUriOption.CUSTOM_IMAGE.value

src/sagemaker/serve/utils/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,15 @@ def __str__(self) -> str:
4343
INFERENTIA_1 = 3
4444
INFERENTIA_2 = 4
4545
GRAVITON = 5
46+
47+
48+
class ImageUriOption(Enum):
49+
"Enum type for image uri options"
50+
51+
def __str__(self) -> str:
52+
"""Convert enum to string"""
53+
return str(self.name)
54+
55+
CUSTOM_IMAGE = 1
56+
CUSTOM_1P_IMAGE = 2
57+
DEFAULT_IMAGE = 3

0 commit comments

Comments
 (0)