-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: minor jumpstart dev ex improvements #4279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
2a8d685
cc99a57
b9a8ef1
6673e99
86822d2
ec0c76d
3528e48
b1789fb
0bcf25e
20b06f2
15b2bd5
bca5955
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -341,7 +341,6 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): | |
"content_type", | ||
"accept", | ||
"body", | ||
"generated_text_response_key", | ||
JGuinegagne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"prompt_key", | ||
] | ||
|
||
|
@@ -373,7 +372,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: | |
self.content_type = json_obj["content_type"] | ||
self.body = json_obj["body"] | ||
accept = json_obj.get("accept") | ||
self.generated_text_response_key = json_obj.get("generated_text_response_key") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: same as above |
||
self.prompt_key = json_obj.get("prompt_key") | ||
if accept: | ||
self.accept = accept | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,7 +109,8 @@ def get_jumpstart_gated_content_bucket( | |
accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return) | ||
|
||
if gated_bucket_to_return != old_gated_content_bucket: | ||
accessors.JumpStartModelsAccessor.reset_cache() | ||
if old_gated_content_bucket is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: This seems to be copied from get_jumpstart_content_bucket(). Can we pass these bucket names in and merge these two functions? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. for now, let's keep as is. i agree, there is code duplication that should be removed. btw, this change is for ensuring the wildcard logging happens a single time. basically, what is happening is that when a gated model is initialized, the cache got reset, causing the wildcard warning to get emitted twice. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah I see, that makes a lot of sense. Nice catch! |
||
accessors.JumpStartModelsAccessor.reset_cache() | ||
for info_log in info_logs: | ||
constants.JUMPSTART_LOGGER.info(info_log) | ||
|
||
|
@@ -153,7 +154,8 @@ def get_jumpstart_content_bucket( | |
accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return) | ||
|
||
if bucket_to_return != old_content_bucket: | ||
accessors.JumpStartModelsAccessor.reset_cache() | ||
if old_content_bucket is not None: | ||
accessors.JumpStartModelsAccessor.reset_cache() | ||
for info_log in info_logs: | ||
constants.JUMPSTART_LOGGER.info(info_log) | ||
return bucket_to_return | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -559,6 +559,7 @@ def create( | |
accelerator_type: Optional[str] = None, | ||
serverless_inference_config: Optional[ServerlessInferenceConfig] = None, | ||
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, | ||
accept_eula: Optional[bool] = None, | ||
): | ||
"""Create a SageMaker Model Entity | ||
|
||
|
@@ -582,6 +583,11 @@ def create( | |
For more information about tags, see | ||
`boto3 documentation <https://boto3.amazonaws.com/v1/documentation/\ | ||
api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags>`_ | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
|
||
Returns: | ||
None or pipeline step arguments in case the Model instance is built with | ||
|
@@ -593,6 +599,7 @@ def create( | |
accelerator_type=accelerator_type, | ||
tags=tags, | ||
serverless_inference_config=serverless_inference_config, | ||
accept_eula=accept_eula, | ||
) | ||
|
||
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): | ||
|
@@ -613,6 +620,7 @@ def prepare_container_def( | |
instance_type=None, | ||
accelerator_type=None, | ||
serverless_inference_config=None, | ||
accept_eula=None, | ||
): # pylint: disable=unused-argument | ||
"""Return a dict created by ``sagemaker.container_def()``. | ||
|
||
|
@@ -630,6 +638,11 @@ def prepare_container_def( | |
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): | ||
Specifies configuration related to serverless endpoint. Instance type is | ||
not provided in serverless inference. So this is used to find image URIs. | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
|
||
Returns: | ||
dict: A container definition object usable with the CreateModel API. | ||
|
@@ -647,7 +660,9 @@ def prepare_container_def( | |
self.repacked_model_data or self.model_data, | ||
deploy_env, | ||
image_config=self.image_config, | ||
accept_eula=getattr(self, "accept_eula", None), | ||
accept_eula=accept_eula | ||
if accept_eula is not None | ||
else getattr(self, "accept_eula", None), | ||
) | ||
|
||
def is_repack(self) -> bool: | ||
|
@@ -789,6 +804,7 @@ def _create_sagemaker_model( | |
accelerator_type=None, | ||
tags=None, | ||
serverless_inference_config=None, | ||
accept_eula=None, | ||
): | ||
"""Create a SageMaker Model Entity | ||
|
||
|
@@ -808,6 +824,11 @@ def _create_sagemaker_model( | |
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): | ||
Specifies configuration related to serverless endpoint. Instance type is | ||
not provided in serverless inference. So this is used to find image URIs. | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
""" | ||
if self.model_package_arn is not None or self.algorithm_arn is not None: | ||
model_package = ModelPackage( | ||
|
@@ -838,6 +859,7 @@ def _create_sagemaker_model( | |
instance_type, | ||
accelerator_type=accelerator_type, | ||
serverless_inference_config=serverless_inference_config, | ||
accept_eula=accept_eula, | ||
) | ||
|
||
if not isinstance(self.sagemaker_session, PipelineSession): | ||
|
@@ -1459,7 +1481,12 @@ def deploy( | |
"serverless_inference_config needs to be a ServerlessInferenceConfig object" | ||
) | ||
|
||
if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model: | ||
if ( | ||
getattr(self, "model_id", None) in {"", None} | ||
and instance_type | ||
and instance_type.startswith("ml.inf") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Consider a const There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can you create a separate utility for this + unit test it please? |
||
and not self._is_compiled_model | ||
): | ||
LOGGER.warning( | ||
"Your model is not compiled. Please compile your model before using Inferentia." | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,7 +121,11 @@ def __init__( | |
) | ||
|
||
def prepare_container_def( | ||
self, instance_type=None, accelerator_type=None, serverless_inference_config=None | ||
self, | ||
instance_type=None, | ||
accelerator_type=None, | ||
serverless_inference_config=None, | ||
accept_eula=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Another plus to adhering to |
||
): | ||
"""Return a container definition set. | ||
|
||
|
@@ -149,6 +153,7 @@ def prepare_container_def( | |
env=environment, | ||
model_data_url=self.model_data_prefix, | ||
container_mode=self.container_mode, | ||
accept_eula=accept_eula, | ||
) | ||
|
||
def deploy( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Could you specify why you are removing this in the PR description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, will do