
fix: minor jumpstart dev ex improvements #4279


Merged · 12 commits · Dec 5, 2023
18 changes: 16 additions & 2 deletions src/sagemaker/chainer/model.py
@@ -265,7 +265,11 @@ def register(
)

def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self,
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
):
"""Return a container definition with framework configuration set in model environment.

@@ -278,6 +282,11 @@ def prepare_container_def(
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict[str, str]: A container definition object usable with the
@@ -307,7 +316,12 @@ def prepare_container_def(
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
self.model_server_workers
)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
return sagemaker.container_def(
deploy_image,
self.model_data,
deploy_env,
accept_eula=accept_eula,
)

def serving_image_uri(
self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
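Note: a quick way to see the effect of the forwarded flag is to call `sagemaker.container_def()` the same way the updated `prepare_container_def()` does. This is a sketch only; the image URI, S3 path, and environment below are placeholders, not values from this PR.

```python
import sagemaker

# Mirror the call made by the updated prepare_container_def():
# image, model data, and env are passed positionally, accept_eula as a keyword.
container = sagemaker.container_def(
    "<inference-image-uri>",                    # placeholder image URI
    "s3://my-bucket/gated-model/model.tar.gz",  # placeholder gated artifact
    {"SAGEMAKER_PROGRAM": "inference.py"},      # placeholder environment
    accept_eula=True,
)

# For gated artifacts, the returned definition should carry the model access
# configuration alongside the model data; printing it makes the effect of
# accept_eula easy to inspect during review.
print(container)
```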
1 change: 1 addition & 0 deletions src/sagemaker/djl_inference/model.py
@@ -733,6 +733,7 @@ def prepare_container_def(
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
): # pylint: disable=unused-argument
"""A container definition with framework configuration set in model environment variables.

11 changes: 10 additions & 1 deletion src/sagemaker/huggingface/model.py
@@ -465,6 +465,7 @@ def prepare_container_def(
accelerator_type=None,
serverless_inference_config=None,
inference_tool=None,
accept_eula=None,
):
"""A container definition with framework configuration set in model environment variables.

@@ -479,6 +480,11 @@
not provided in serverless inference. So this is used to find image URIs.
inference_tool (str): the tool that will be used to aid in the inference.
Valid values: "neuron, neuronx, None" (default: None).
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict[str, str]: A container definition object usable with the
@@ -510,7 +516,10 @@
self.model_server_workers
)
return sagemaker.container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
deploy_image,
self.repacked_model_data or self.model_data,
deploy_env,
accept_eula=accept_eula,
)

def serving_image_uri(
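Note: for the end-to-end path, a hedged usage sketch of how a caller could surface the flag once this change lands. It assumes valid AWS credentials, an execution role, and a gated artifact the caller is licensed to use; every value below is a placeholder.

```python
from sagemaker.huggingface import HuggingFaceModel

# All values are placeholders; a real call needs AWS credentials, an execution
# role, and a model artifact the caller is entitled to access.
model = HuggingFaceModel(
    model_data="s3://my-bucket/gated-llm/model.tar.gz",
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    transformers_version="4.28",
    pytorch_version="2.0",
    py_version="py310",
)

# Model.create() now accepts accept_eula and threads it through
# _create_sagemaker_model() into prepare_container_def().
model.create(
    instance_type="ml.g5.2xlarge",
    accept_eula=True,
)
```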
17 changes: 11 additions & 6 deletions src/sagemaker/jumpstart/exceptions.py
@@ -39,6 +39,12 @@
"Note that models may have different input/output signatures after a major version upgrade."
)

_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = (
"We recommend that you specify a more recent "
"model version or choose a different model. To access the latest models "
"and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."
)


def get_wildcard_model_version_msg(
model_id: str, wildcard_model_version: str, full_model_version: str
@@ -115,16 +121,16 @@ def __init__(
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the inference script. "
"Please try targeting a higher version of the model or using a "
"different model. List of vulnerabilities: "
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
"List of vulnerabilities: "
f"{', '.join(vulnerabilities)}" # type: ignore
)
elif scope == JumpStartScriptScope.TRAINING:
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the training script. "
"Please try targeting a higher version of the model or using a "
"different model. List of vulnerabilities: "
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
"List of vulnerabilities: "
f"{', '.join(vulnerabilities)}" # type: ignore
)
else:
@@ -159,8 +165,7 @@ def __init__(
raise RuntimeError("Must specify `model_id` and `version` arguments.")
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
"Please try targeting a higher version of the model or using a "
"different model."
f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION}"
)

super().__init__(self.message)
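Note: a stand-alone reconstruction of how the shared recommendation string composes into the vulnerable-model message; the model ID, version, and vulnerability list are made up for illustration.

```python
# Reconstruction of the message composition in this diff, using made-up values;
# the real logic lives in JumpStartVulnerableModelError.__init__.
_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = (
    "We recommend that you specify a more recent "
    "model version or choose a different model. To access the latest models "
    "and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."
)

model_id, version = "example-model-id", "1.0.0"  # hypothetical values
vulnerabilities = ["CVE-0000-00000"]             # hypothetical values

message = (
    f"Version '{version}' of JumpStart model '{model_id}' "
    "has at least 1 vulnerable dependency in the inference script. "
    f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
    "List of vulnerabilities: "
    f"{', '.join(vulnerabilities)}"
)
print(message)
```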
74 changes: 0 additions & 74 deletions src/sagemaker/jumpstart/payload_utils.py
@@ -115,80 +115,6 @@ def _construct_payload(
return payload_to_use


def _extract_generated_text_from_response(

Contributor: nit: Could you specify why you are removing this in the PR description?

Member Author: sorry, will do

response: dict,
model_id: str,
model_version: str,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
accept_type: Optional[str] = None,
) -> str:
"""Returns generated text extracted from full response payload.

Args:
response (dict): Dictionary-valued response from which to extract
generated text.
model_id (str): JumpStart model ID of the JumpStart model from which to extract
generated text.
model_version (str): Version of the JumpStart model for which to extract generated
text.
region (Optional[str]): Region for which to extract generated
text. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated. (Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
accept_type (Optional[str]): The accept type to optionally specify for the response.
(Default: None).

Returns:
str: extracted generated text from the endpoint response payload.

Raises:
ValueError: If the model is invalid, the model does not support generated text extraction,
or if the response is malformed.
"""

if not isinstance(response, dict):
raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")

payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
model_id=model_id,
model_version=model_version,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
if payloads is None or len(payloads) == 0:
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")

for payload in payloads.values():
if accept_type is None or payload.accept == accept_type:
generated_text_response_key: Optional[str] = payload.generated_text_response_key
if generated_text_response_key is None:
raise ValueError(
f"Model ID '{model_id}' does not support generated text extraction."
)

generated_text_response_key_split = generated_text_response_key.split(".")
try:
return _extract_field_from_json(response, generated_text_response_key_split)
except KeyError:
raise ValueError(f"Response is malformed: {response}")

raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")


class PayloadSerializer:
"""Utility class for serializing payloads associated with JumpStart models.

2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/types.py
@@ -341,7 +341,6 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
"content_type",
"accept",
"body",
"generated_text_response_key",
"prompt_key",
]

@@ -373,7 +372,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
self.content_type = json_obj["content_type"]
self.body = json_obj["body"]
accept = json_obj.get("accept")
self.generated_text_response_key = json_obj.get("generated_text_response_key")

Contributor: nit: same as above

self.prompt_key = json_obj.get("prompt_key")
if accept:
self.accept = accept
6 changes: 4 additions & 2 deletions src/sagemaker/jumpstart/utils.py
@@ -109,7 +109,8 @@ def get_jumpstart_gated_content_bucket(
accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)

if gated_bucket_to_return != old_gated_content_bucket:
accessors.JumpStartModelsAccessor.reset_cache()
if old_gated_content_bucket is not None:

Contributor: nit: This seems to be copied from get_jumpstart_content_bucket(). Can we pass these bucket names in and merge the two functions?

Member Author (@evakravi, Nov 29, 2023): For now, let's keep it as is. I agree there is code duplication that should be removed. By the way, this change ensures the wildcard logging happens only once: when a gated model was initialized, the cache got reset, causing the wildcard warning to be emitted twice.

Contributor: Ah I see, that makes a lot of sense. Nice catch!

accessors.JumpStartModelsAccessor.reset_cache()
for info_log in info_logs:
constants.JUMPSTART_LOGGER.info(info_log)

@@ -153,7 +154,8 @@ def get_jumpstart_content_bucket(
accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return)

if bucket_to_return != old_content_bucket:
accessors.JumpStartModelsAccessor.reset_cache()
if old_content_bucket is not None:
accessors.JumpStartModelsAccessor.reset_cache()
for info_log in info_logs:
constants.JUMPSTART_LOGGER.info(info_log)
return bucket_to_return
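Note: to make the thread above concrete, a stand-alone sketch of the guard being added; the accessor here is a stub, not the real `JumpStartModelsAccessor`.

```python
# Stub that mimics the behavior the new guard protects: resetting the cache on
# first initialization re-triggered the wildcard-version warning, so the reset
# should only happen when an already-set bucket actually changes.
class _StubAccessor:
    _bucket = None
    resets = 0

    @classmethod
    def set_bucket(cls, bucket):
        old_bucket = cls._bucket
        cls._bucket = bucket
        if bucket != old_bucket:
            if old_bucket is not None:  # the guard added in this PR
                cls.resets += 1         # stand-in for reset_cache()


_StubAccessor.set_bucket("jumpstart-gated-bucket")  # first init: no reset
_StubAccessor.set_bucket("jumpstart-gated-bucket")  # unchanged: no reset
_StubAccessor.set_bucket("another-gated-bucket")    # changed: reset once
print(_StubAccessor.resets)  # 1
```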
31 changes: 29 additions & 2 deletions src/sagemaker/model.py
@@ -559,6 +559,7 @@ def create(
accelerator_type: Optional[str] = None,
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
accept_eula: Optional[bool] = None,
):
"""Create a SageMaker Model Entity

@@ -582,6 +583,11 @@
For more information about tags, see
`boto3 documentation <https://boto3.amazonaws.com/v1/documentation/\
api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags>`_
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
None or pipeline step arguments in case the Model instance is built with
@@ -593,6 +599,7 @@
accelerator_type=accelerator_type,
tags=tags,
serverless_inference_config=serverless_inference_config,
accept_eula=accept_eula,
)

def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
@@ -613,6 +620,7 @@ def prepare_container_def(
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
): # pylint: disable=unused-argument
"""Return a dict created by ``sagemaker.container_def()``.

@@ -630,6 +638,11 @@
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict: A container definition object usable with the CreateModel API.
@@ -647,7 +660,9 @@
self.repacked_model_data or self.model_data,
deploy_env,
image_config=self.image_config,
accept_eula=getattr(self, "accept_eula", None),
accept_eula=accept_eula
if accept_eula is not None
else getattr(self, "accept_eula", None),
)

def is_repack(self) -> bool:
@@ -789,6 +804,7 @@ def _create_sagemaker_model(
accelerator_type=None,
tags=None,
serverless_inference_config=None,
accept_eula=None,
):
"""Create a SageMaker Model Entity

@@ -808,6 +824,11 @@
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
"""
if self.model_package_arn is not None or self.algorithm_arn is not None:
model_package = ModelPackage(
@@ -838,6 +859,7 @@
instance_type,
accelerator_type=accelerator_type,
serverless_inference_config=serverless_inference_config,
accept_eula=accept_eula,
)

if not isinstance(self.sagemaker_session, PipelineSession):
@@ -1459,7 +1481,12 @@ def deploy(
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
)

if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
if (
getattr(self, "model_id", None) in {"", None}
and instance_type
and instance_type.startswith("ml.inf")

Contributor: nit: Consider a const

Contributor: Can you create a separate utility for this and unit test it, please?

and not self._is_compiled_model
):
LOGGER.warning(
"Your model is not compiled. Please compile your model before using Inferentia."
)
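Note: a compact sketch of the precedence the base `Model` now applies when resolving `accept_eula`; the class below is a stub standing in for the SDK class, not the real implementation.

```python
# Stub illustrating the resolution order in Model.prepare_container_def():
# an explicit accept_eula argument wins, otherwise fall back to an
# accept_eula attribute if the instance happens to define one.
class _StubModel:
    def __init__(self, accept_eula=None):
        if accept_eula is not None:
            self.accept_eula = accept_eula

    def resolve_accept_eula(self, accept_eula=None):
        return (
            accept_eula
            if accept_eula is not None
            else getattr(self, "accept_eula", None)
        )


print(_StubModel().resolve_accept_eula())                       # None
print(_StubModel(accept_eula=True).resolve_accept_eula())       # True (attribute fallback)
print(_StubModel(accept_eula=True).resolve_accept_eula(False))  # False (explicit argument wins)
```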
7 changes: 6 additions & 1 deletion src/sagemaker/multidatamodel.py
@@ -121,7 +121,11 @@ def __init__(
)

def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self,
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,

Contributor: Another plus of adhering to getattr is that you wouldn't need to make every one of these modifications here.

):
"""Return a container definition set.

@@ -149,6 +153,7 @@ def prepare_container_def(
env=environment,
model_data_url=self.model_data_prefix,
container_mode=self.container_mode,
accept_eula=accept_eula,
)

def deploy(
16 changes: 14 additions & 2 deletions src/sagemaker/mxnet/model.py
@@ -267,7 +267,11 @@ def register(
)

def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self,
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
):
"""Return a container definition with framework configuration.

@@ -282,6 +286,11 @@
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to find image URIs.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
dict[str, str]: A container definition object usable with the
@@ -312,7 +321,10 @@
self.model_server_workers
)
return sagemaker.container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
deploy_image,
self.repacked_model_data or self.model_data,
deploy_env,
accept_eula=accept_eula,
)

def serving_image_uri(