Skip to content

fix: minor jumpstart dev ex improvements #4279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 5, 2023
Merged
18 changes: 16 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,11 @@ def register(
)

def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self,
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
):
"""Return a container definition with framework configuration set in model environment.

Expand All @@ -278,6 +282,11 @@ def prepare_container_def(
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict[str, str]: A container definition object usable with the
Expand Down Expand Up @@ -307,7 +316,12 @@ def prepare_container_def(
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
self.model_server_workers
)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
return sagemaker.container_def(
deploy_image,
self.model_data,
deploy_env,
accept_eula=accept_eula,
)

def serving_image_uri(
self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/djl_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def prepare_container_def(
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
): # pylint: disable=unused-argument
"""A container definition with framework configuration set in model environment variables.

Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def prepare_container_def(
accelerator_type=None,
serverless_inference_config=None,
inference_tool=None,
accept_eula=None,
):
"""A container definition with framework configuration set in model environment variables.

Expand All @@ -479,6 +480,11 @@ def prepare_container_def(
not provided in serverless inference. So this is used to find image URIs.
inference_tool (str): the tool that will be used to aid in the inference.
Valid values: "neuron, neuronx, None" (default: None).
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict[str, str]: A container definition object usable with the
Expand Down Expand Up @@ -510,7 +516,10 @@ def prepare_container_def(
self.model_server_workers
)
return sagemaker.container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
deploy_image,
self.repacked_model_data or self.model_data,
deploy_env,
accept_eula=accept_eula,
)

def serving_image_uri(
Expand Down
16 changes: 16 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module stores constants related to SageMaker JumpStart."""
from __future__ import absolute_import
import logging
import os
from typing import Dict, Set, Type
import boto3
from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer
Expand All @@ -33,6 +34,8 @@
from sagemaker.session import Session


# Name of the environment variable that is read each time the JumpStart logger's
# stream handler emits a record; when it holds any non-empty value, that handler
# emits nothing.
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING"

JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set(
[
JumpStartLaunchedRegionInfo(
Expand Down Expand Up @@ -209,6 +212,19 @@

JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")

# Attach a stream handler to the JumpStart logger whose output can be silenced via
# the environment. The handler class is built inline with `type(...)` (an anonymous
# `logging.StreamHandler` subclass with an empty class name) and is instantiated
# immediately with no arguments.
#
# The overridden `emit` looks up ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING in
# `os.environ` on every call, so changing the variable at runtime affects subsequent
# records. Any non-empty value (including "0" or "false") suppresses output; when the
# variable is unset or empty, the call is forwarded to `logging.StreamHandler.emit`.
#
# NOTE(review): only this handler is silenced. Records presumably still propagate to
# handlers on ancestor loggers unless propagation is disabled elsewhere — confirm.
JUMPSTART_LOGGER.addHandler(
type(
"",
(logging.StreamHandler,),
{
"emit": lambda self, *args, **kwargs: logging.StreamHandler.emit(self, *args, **kwargs)
if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING)
else None
},
)()
)

try:
DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session(
boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
Expand Down
17 changes: 11 additions & 6 deletions src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
"Note that models may have different input/output signatures after a major version upgrade."
)

# Shared recommendation text interpolated into the error messages for model versions
# with vulnerable dependencies (inference and training scopes) and for deprecated
# model versions, so that all three messages give the same advice.
_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = (
"We recommend that you specify a more recent "
"model version or choose a different model. To access the latest models "
"and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."
)


def get_wildcard_model_version_msg(
model_id: str, wildcard_model_version: str, full_model_version: str
Expand Down Expand Up @@ -115,16 +121,16 @@ def __init__(
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the inference script. "
"Please try targeting a higher version of the model or using a "
"different model. List of vulnerabilities: "
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
"List of vulnerabilities: "
f"{', '.join(vulnerabilities)}" # type: ignore
)
elif scope == JumpStartScriptScope.TRAINING:
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the training script. "
"Please try targeting a higher version of the model or using a "
"different model. List of vulnerabilities: "
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
"List of vulnerabilities: "
f"{', '.join(vulnerabilities)}" # type: ignore
)
else:
Expand Down Expand Up @@ -159,8 +165,7 @@ def __init__(
raise RuntimeError("Must specify `model_id` and `version` arguments.")
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
"Please try targeting a higher version of the model or using a "
"different model."
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION}"
)

super().__init__(self.message)
2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class SpecialSupportedFilterKeys(str, Enum):

TASK = "task"
FRAMEWORK = "framework"
SUPPORTED_MODEL = "supported_model"


FILTER_OPERATOR_STRING_MAPPINGS = {
Expand Down Expand Up @@ -74,7 +73,6 @@ class SpecialSupportedFilterKeys(str, Enum):
[
SpecialSupportedFilterKeys.TASK,
SpecialSupportedFilterKeys.FRAMEWORK,
SpecialSupportedFilterKeys.SUPPORTED_MODEL,
]
)

Expand Down
Loading