fix: minor jumpstart dev ex improvements #4279
@@ -115,80 +115,6 @@ def _construct_payload(
        return payload_to_use


-def _extract_generated_text_from_response(
-    response: dict,
-    model_id: str,
-    model_version: str,
-    region: Optional[str] = None,
-    tolerate_vulnerable_model: bool = False,
-    tolerate_deprecated_model: bool = False,
-    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-    accept_type: Optional[str] = None,
-) -> str:
-    """Returns generated text extracted from full response payload.
-
-    Args:
-        response (dict): Dictionary-valued response from which to extract
-            generated text.
-        model_id (str): JumpStart model ID of the JumpStart model from which to extract
-            generated text.
-        model_version (str): Version of the JumpStart model for which to extract generated
-            text.
-        region (Optional[str]): Region for which to extract generated
-            text. (Default: None).
-        tolerate_vulnerable_model (bool): True if vulnerable versions of model
-            specifications should be tolerated (exception not raised). If False, raises an
-            exception if the script used by this version of the model has dependencies with known
-            security vulnerabilities. (Default: False).
-        tolerate_deprecated_model (bool): True if deprecated versions of model
-            specifications should be tolerated (exception not raised). If False, raises
-            an exception if the version of the model is deprecated. (Default: False).
-        sagemaker_session (sagemaker.session.Session): A SageMaker Session
-            object, used for SageMaker interactions. If not
-            specified, one is created using the default AWS configuration
-            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
-        accept_type (Optional[str]): The accept type to optionally specify for the response.
-            (Default: None).
-
-    Returns:
-        str: extracted generated text from the endpoint response payload.
-
-    Raises:
-        ValueError: If the model is invalid, the model does not support generated text extraction,
-            or if the response is malformed.
-    """
-
-    if not isinstance(response, dict):
-        raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")
-
-    payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
-        model_id=model_id,
-        model_version=model_version,
-        region=region,
-        tolerate_vulnerable_model=tolerate_vulnerable_model,
-        tolerate_deprecated_model=tolerate_deprecated_model,
-        sagemaker_session=sagemaker_session,
-    )
-    if payloads is None or len(payloads) == 0:
-        raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
-
-    for payload in payloads.values():
-        if accept_type is None or payload.accept == accept_type:
-            generated_text_response_key: Optional[str] = payload.generated_text_response_key
-            if generated_text_response_key is None:
-                raise ValueError(
-                    f"Model ID '{model_id}' does not support generated text extraction."
-                )
-
-            generated_text_response_key_split = generated_text_response_key.split(".")
-            try:
-                return _extract_field_from_json(response, generated_text_response_key_split)
-            except KeyError:
-                raise ValueError(f"Response is malformed: {response}")
-
-    raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
-
-
 class PayloadSerializer:
     """Utility class for serializing payloads associated with JumpStart models.

Review thread on the removed function:
- Reviewer: nit: Could you specify why you are removing this in the PR description?
- Author: sorry, will do
@@ -339,7 +339,6 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
        "content_type",
        "accept",
        "body",
-       "generated_text_response_key",
        "prompt_key",
    ]

(JGuinegagne marked this conversation as resolved.)

@@ -371,7 +370,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        self.content_type = json_obj["content_type"]
        self.body = json_obj["body"]
        accept = json_obj.get("accept")
-       self.generated_text_response_key = json_obj.get("generated_text_response_key")
        self.prompt_key = json_obj.get("prompt_key")
        if accept:
            self.accept = accept

Review comment on this removal:
- Reviewer: nit: same as above
@@ -109,7 +109,8 @@ def get_jumpstart_gated_content_bucket(
    accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)

    if gated_bucket_to_return != old_gated_content_bucket:
-       accessors.JumpStartModelsAccessor.reset_cache()
+       if old_gated_content_bucket is not None:
+           accessors.JumpStartModelsAccessor.reset_cache()
        for info_log in info_logs:
            constants.JUMPSTART_LOGGER.info(info_log)

Review thread on the added guard:
- Reviewer: nit: This seems to be copied from get_jumpstart_content_bucket(). Can we pass these bucket names in and merge the two functions?
- Author: For now, let's keep it as is. I agree there is code duplication that should be removed. BTW, this change ensures the wildcard logging happens a single time: when a gated model was initialized, the cache got reset, causing the wildcard warning to be emitted twice.
- Reviewer: Ah I see, that makes a lot of sense. Nice catch!

@@ -153,7 +154,8 @@ def get_jumpstart_content_bucket(
    accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return)

    if bucket_to_return != old_content_bucket:
-       accessors.JumpStartModelsAccessor.reset_cache()
+       if old_content_bucket is not None:
+           accessors.JumpStartModelsAccessor.reset_cache()
        for info_log in info_logs:
            constants.JUMPSTART_LOGGER.info(info_log)
    return bucket_to_return
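The guard above only resets the JumpStart cache when a previously-set bucket actually changed. As a rough, standalone sketch of that rule (and of the merged helper the reviewer suggests) — not the SDK's actual code; `maybe_reset_cache`, its arguments, and the bucket names are made up for illustration:

```python
from typing import Callable, Optional


def maybe_reset_cache(
    old_bucket: Optional[str],
    new_bucket: str,
    reset_cache: Callable[[], None],
) -> None:
    """Reset the JumpStart cache only if a previously-set bucket changed.

    Skipping the reset on the first assignment (old_bucket is None) avoids
    re-emitting one-time logs such as the wildcard warning mentioned above.
    """
    if new_bucket != old_bucket and old_bucket is not None:
        reset_cache()


# First assignment: no reset. Later change of an already-set bucket: reset.
calls = []
maybe_reset_cache(None, "jumpstart-cache-prod-us-west-2", lambda: calls.append("reset"))
maybe_reset_cache("jumpstart-cache-prod-us-west-2", "my-override-bucket", lambda: calls.append("reset"))
assert calls == ["reset"]
```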
@@ -547,6 +547,7 @@ def create(
        accelerator_type: Optional[str] = None,
        serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+       accept_eula: Optional[bool] = None,
    ):
        """Create a SageMaker Model Entity

@@ -570,6 +571,11 @@
            For more information about tags, see
            `boto3 documentation <https://boto3.amazonaws.com/v1/documentation/\
            api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags>`_
+           accept_eula (bool): For models that require a Model Access Config, specify True or
+               False to indicate whether model terms of use have been accepted.
+               The `accept_eula` value must be explicitly defined as `True` in order to
+               accept the end-user license agreement (EULA) that some
+               models require. (Default: None).

        Returns:
            None or pipeline step arguments in case the Model instance is built with

@@ -581,6 +587,7 @@
            accelerator_type=accelerator_type,
            tags=tags,
            serverless_inference_config=serverless_inference_config,
+           accept_eula=accept_eula,
        )

    def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
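For context, here is a hedged usage sketch of the new argument on `Model.create()`. The image URI, model data location, role, and instance type are placeholders and running it requires AWS credentials; the only part taken from this diff is the `accept_eula` keyword:

```python
from sagemaker.model import Model

# Placeholder values; substitute real resources before running.
model = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://<bucket>/<gated-model-artifact>.tar.gz",
    role="<execution-role-arn>",
)

# The EULA is only accepted when accept_eula is explicitly True;
# leaving it as None (the default) or passing False does not accept it.
model.create(
    instance_type="ml.g5.2xlarge",
    accept_eula=True,
)
```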
@@ -601,6 +608,7 @@ def prepare_container_def(
        instance_type=None,
        accelerator_type=None,
        serverless_inference_config=None,
+       accept_eula=None,
    ):  # pylint: disable=unused-argument
        """Return a dict created by ``sagemaker.container_def()``.

@@ -618,6 +626,11 @@
            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
                Specifies configuration related to serverless endpoint. Instance type is
                not provided in serverless inference. So this is used to find image URIs.
+           accept_eula (bool): For models that require a Model Access Config, specify True or
+               False to indicate whether model terms of use have been accepted.
+               The `accept_eula` value must be explicitly defined as `True` in order to
+               accept the end-user license agreement (EULA) that some
+               models require. (Default: None).

        Returns:
            dict: A container definition object usable with the CreateModel API.

@@ -635,7 +648,7 @@
            self.repacked_model_data or self.model_data,
            deploy_env,
            image_config=self.image_config,
-           accept_eula=getattr(self, "accept_eula", None),
+           accept_eula=accept_eula or getattr(self, "accept_eula", None),
        )

    def is_repack(self) -> bool:

Review thread on the accept_eula fallback:
- question: Why not adhere the changes to …?
- imo we should pass it in and modify …
- It is possible to call …
- question: what happens if …? Could you add a unit test for this case please?
- Yeah, totally agree @evakravi. Doesn't have to be this PR, but can we put a backlog item to refactor …?
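The question about an explicit `accept_eula=False` comes down to Python truthiness: with the `or` fallback above, False is falsy and silently defers to the instance attribute. A small standalone illustration — the variable names are made up, not SDK code:

```python
explicit_accept_eula = False   # value passed to prepare_container_def()
instance_accept_eula = True    # value previously set on the model instance

# The fallback used in the diff: an explicit False loses to a truthy attribute.
effective = explicit_accept_eula or instance_accept_eula
print(effective)  # True

# A None-aware fallback would preserve an explicit False.
effective = (
    explicit_accept_eula
    if explicit_accept_eula is not None
    else instance_accept_eula
)
print(effective)  # False
```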
@@ -777,6 +790,7 @@ def _create_sagemaker_model(
        accelerator_type=None,
        tags=None,
        serverless_inference_config=None,
+       accept_eula=None,
    ):
        """Create a SageMaker Model Entity

@@ -796,6 +810,11 @@
            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
                Specifies configuration related to serverless endpoint. Instance type is
                not provided in serverless inference. So this is used to find image URIs.
+           accept_eula (bool): For models that require a Model Access Config, specify True or
+               False to indicate whether model terms of use have been accepted.
+               The `accept_eula` value must be explicitly defined as `True` in order to
+               accept the end-user license agreement (EULA) that some
+               models require. (Default: None).
        """

        if self.model_package_arn is not None or self.algorithm_arn is not None:

@@ -827,6 +846,7 @@
            instance_type,
            accelerator_type=accelerator_type,
            serverless_inference_config=serverless_inference_config,
+           accept_eula=accept_eula,
        )

        if not isinstance(self.sagemaker_session, PipelineSession):
@@ -1434,7 +1454,12 @@ def deploy(
                "serverless_inference_config needs to be a ServerlessInferenceConfig object"
            )

-       if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
+       if (
+           getattr(self, "model_id", None) in {"", None}
+           and instance_type
+           and instance_type.startswith("ml.inf")
+           and not self._is_compiled_model
+       ):
            LOGGER.warning(
                "Your model is not compiled. Please compile your model before using Inferentia."
            )

Review thread on the new condition:
- nit: Consider a const …
- can you create a separate utility for this + unit test it please?
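A possible shape for the utility and unit test the reviewer asks for — the function name, the constant, and the model IDs in the checks are assumptions for illustration, not part of this PR:

```python
from typing import Optional

INFERENTIA_INSTANCE_PREFIX = "ml.inf"  # hypothetical const replacing the "ml.inf" literal


def should_warn_uncompiled_on_inferentia(
    model_id: Optional[str],
    instance_type: Optional[str],
    is_compiled_model: bool,
) -> bool:
    """Warn only for non-JumpStart models deployed uncompiled to Inferentia."""
    return (
        model_id in {"", None}
        and bool(instance_type)
        and instance_type.startswith(INFERENTIA_INSTANCE_PREFIX)
        and not is_compiled_model
    )


# Unit-test style checks mirroring the condition added to deploy():
assert should_warn_uncompiled_on_inferentia(None, "ml.inf1.xlarge", False)
assert not should_warn_uncompiled_on_inferentia("some-jumpstart-model-id", "ml.inf2.xlarge", False)
assert not should_warn_uncompiled_on_inferentia(None, "ml.g5.2xlarge", False)
assert not should_warn_uncompiled_on_inferentia(None, "ml.inf1.xlarge", True)
```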
@@ -121,7 +121,11 @@ def __init__(
        )

    def prepare_container_def(
-       self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+       self,
+       instance_type=None,
+       accelerator_type=None,
+       serverless_inference_config=None,
+       accept_eula=None,
    ):
        """Return a container definition set.

Review comment on the new parameter:
- Another plus to adhering to …

@@ -149,6 +153,7 @@ def prepare_container_def(
            env=environment,
            model_data_url=self.model_data_prefix,
            container_mode=self.container_mode,
+           accept_eula=accept_eula,
        )

    def deploy(
Review thread on the deprecation warning wording:
- could you check with @judyheflin please? I think they generally don't like using "Please".
- Recommendation:
  "Version '{version}' of JumpStart model '{model_id}' is deprecated. We recommend that you specify a more recent model version or choose a different model. To access the latest models and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."