Skip to content

Commit b688ee6

Browse files
evakravisagemaker-bot
authored andcommitted
fix: minor jumpstart dev ex improvements (aws#4279)
1 parent aa99b5b commit b688ee6

27 files changed

+1043
-294
lines changed

src/sagemaker/chainer/model.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ def register(
265265
)
266266

267267
def prepare_container_def(
268-
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
268+
self,
269+
instance_type=None,
270+
accelerator_type=None,
271+
serverless_inference_config=None,
272+
accept_eula=None,
269273
):
270274
"""Return a container definition with framework configuration set in model environment.
271275
@@ -278,6 +282,11 @@ def prepare_container_def(
278282
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
279283
Specifies configuration related to serverless endpoint. Instance type is
280284
not provided in serverless inference. So this is used to find image URIs.
285+
accept_eula (bool): For models that require a Model Access Config, specify True or
286+
False to indicate whether model terms of use have been accepted.
287+
The `accept_eula` value must be explicitly defined as `True` in order to
288+
accept the end-user license agreement (EULA) that some
289+
models require. (Default: None).
281290
282291
Returns:
283292
dict[str, str]: A container definition object usable with the
@@ -307,7 +316,12 @@ def prepare_container_def(
307316
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
308317
self.model_server_workers
309318
)
310-
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
319+
return sagemaker.container_def(
320+
deploy_image,
321+
self.model_data,
322+
deploy_env,
323+
accept_eula=accept_eula,
324+
)
311325

312326
def serving_image_uri(
313327
self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None

src/sagemaker/djl_inference/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ def prepare_container_def(
733733
instance_type=None,
734734
accelerator_type=None,
735735
serverless_inference_config=None,
736+
accept_eula=None,
736737
): # pylint: disable=unused-argument
737738
"""A container definition with framework configuration set in model environment variables.
738739

src/sagemaker/huggingface/model.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def prepare_container_def(
465465
accelerator_type=None,
466466
serverless_inference_config=None,
467467
inference_tool=None,
468+
accept_eula=None,
468469
):
469470
"""A container definition with framework configuration set in model environment variables.
470471
@@ -479,6 +480,11 @@ def prepare_container_def(
479480
not provided in serverless inference. So this is used to find image URIs.
480481
inference_tool (str): the tool that will be used to aid in the inference.
481482
Valid values: "neuron, neuronx, None" (default: None).
483+
accept_eula (bool): For models that require a Model Access Config, specify True or
484+
False to indicate whether model terms of use have been accepted.
485+
The `accept_eula` value must be explicitly defined as `True` in order to
486+
accept the end-user license agreement (EULA) that some
487+
models require. (Default: None).
482488
483489
Returns:
484490
dict[str, str]: A container definition object usable with the
@@ -510,7 +516,10 @@ def prepare_container_def(
510516
self.model_server_workers
511517
)
512518
return sagemaker.container_def(
513-
deploy_image, self.repacked_model_data or self.model_data, deploy_env
519+
deploy_image,
520+
self.repacked_model_data or self.model_data,
521+
deploy_env,
522+
accept_eula=accept_eula,
514523
)
515524

516525
def serving_image_uri(

src/sagemaker/jumpstart/constants.py

+16
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Dict, Set, Type
1718
import boto3
1819
from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer
@@ -33,6 +34,8 @@
3334
from sagemaker.session import Session
3435

3536

37+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING"
38+
3639
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set(
3740
[
3841
JumpStartLaunchedRegionInfo(
@@ -209,6 +212,19 @@
209212

210213
JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")
211214

215+
# disable logging if env var is set
216+
JUMPSTART_LOGGER.addHandler(
217+
type(
218+
"",
219+
(logging.StreamHandler,),
220+
{
221+
"emit": lambda self, *args, **kwargs: logging.StreamHandler.emit(self, *args, **kwargs)
222+
if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING)
223+
else None
224+
},
225+
)()
226+
)
227+
212228
try:
213229
DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session(
214230
boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)

src/sagemaker/jumpstart/exceptions.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
"Note that models may have different input/output signatures after a major version upgrade."
4040
)
4141

42+
_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = (
43+
"We recommend that you specify a more recent "
44+
"model version or choose a different model. To access the latest models "
45+
"and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."
46+
)
47+
4248

4349
def get_wildcard_model_version_msg(
4450
model_id: str, wildcard_model_version: str, full_model_version: str
@@ -115,16 +121,16 @@ def __init__(
115121
self.message = (
116122
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
117123
"has at least 1 vulnerable dependency in the inference script. "
118-
"Please try targeting a higher version of the model or using a "
119-
"different model. List of vulnerabilities: "
124+
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
125+
"List of vulnerabilities: "
120126
f"{', '.join(vulnerabilities)}" # type: ignore
121127
)
122128
elif scope == JumpStartScriptScope.TRAINING:
123129
self.message = (
124130
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
125131
"has at least 1 vulnerable dependency in the training script. "
126-
"Please try targeting a higher version of the model or using a "
127-
"different model. List of vulnerabilities: "
132+
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
133+
"List of vulnerabilities: "
128134
f"{', '.join(vulnerabilities)}" # type: ignore
129135
)
130136
else:
@@ -159,8 +165,7 @@ def __init__(
159165
raise RuntimeError("Must specify `model_id` and `version` arguments.")
160166
self.message = (
161167
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
162-
"Please try targeting a higher version of the model or using a "
163-
"different model."
168+
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION}"
164169
)
165170

166171
super().__init__(self.message)

src/sagemaker/jumpstart/filters.py

-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ class SpecialSupportedFilterKeys(str, Enum):
4545

4646
TASK = "task"
4747
FRAMEWORK = "framework"
48-
SUPPORTED_MODEL = "supported_model"
4948

5049

5150
FILTER_OPERATOR_STRING_MAPPINGS = {
@@ -74,7 +73,6 @@ class SpecialSupportedFilterKeys(str, Enum):
7473
[
7574
SpecialSupportedFilterKeys.TASK,
7675
SpecialSupportedFilterKeys.FRAMEWORK,
77-
SpecialSupportedFilterKeys.SUPPORTED_MODEL,
7876
]
7977
)
8078

0 commit comments

Comments
 (0)