-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870
Changes from 2 commits
3456684
d42cad6
8c379c1
fc4096b
55e93d9
fe4092a
60e1184
da1ad6b
71714b6
b0b6edb
22d27ee
84969d9
84e88c0
b6e5f8c
7709a87
c19fc9d
9f0533f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,7 @@ | |
) | ||
from sagemaker.inputs import TrainingInput, FileSystemInput | ||
from sagemaker.instance_group import InstanceGroup | ||
from sagemaker.utils import instance_supports_kms | ||
from sagemaker.job import _Job | ||
from sagemaker.jumpstart.utils import ( | ||
add_jumpstart_tags, | ||
|
@@ -95,6 +96,7 @@ | |
) | ||
from sagemaker.workflow import is_pipeline_variable | ||
from sagemaker.workflow.entities import PipelineVariable | ||
from sagemaker.workflow.parameters import ParameterString | ||
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -599,10 +601,30 @@ def __init__( | |
self.output_kms_key = resolve_value_from_config( | ||
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session | ||
) | ||
self.volume_kms_key = resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
if instance_type is None or isinstance(instance_type, str): | ||
instance_type_for_volume_kms = instance_type | ||
elif isinstance(instance_type, ParameterString): | ||
instance_type_for_volume_kms = instance_type.default_value | ||
else: | ||
raise ValueError(f"Bad value for instance type: '{instance_type}'") | ||
Comment on lines
+608
to
+609
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this be backwards incompatible? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm pretty sure. The unit tests make it seem like it's compatible. There shouldn't be a different python type for the instance type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Today if a user passes in a different python type, they wont get an error right? (My understanding is that the "Optional[Union[str, PipelineVariable]]" is not strictly enforced by Python). Let me know if im mistaken If Im not wrong about the above, this would be changing behavior of the Estimator class (and I havent seen type assertions like these in other places anyway). Seems safer to not change this behavior There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I can skip throwing an exception. However, if a variable that isn't a string or There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It turns out this is not backward compatible, see the issue #3993 Have we made the fix to skip PipelineVariable type? |
||
|
||
self.volume_kms_key = ( | ||
resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
if ( | ||
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms) | ||
) | ||
or instance_groups is not None | ||
and all( | ||
[ | ||
instance_supports_kms(instance_group.instance_type) | ||
for instance_group in instance_groups | ||
] | ||
) | ||
else volume_kms_key | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I find ternary hard to read in this context, could you do: if <condition>:
self.volume_kms_key = resolve_value_from_config(...)
else:
self.volume_kms_key = volume_kms_key More importantly, from a security/reasonable expection POV: should we let this code throw if any of the instance in the group does support KMS, but some do not? If we silently remove it as you do here, the result is that there will be instances with EBS volumes out there that aren't encrypted with the default KMS key provided by admins. I don't think that's desirable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a valid concern with respect to instance_groups, lets involve the instance_groups feature team to weigh in on this. |
||
|
||
# VPC configurations | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
import botocore.config | ||
from botocore.exceptions import ClientError | ||
import six | ||
from sagemaker.utils import instance_supports_kms | ||
|
||
import sagemaker.logs | ||
from sagemaker import vpc_utils | ||
|
@@ -820,6 +821,12 @@ def train( # noqa: C901 | |
inferred_resource_config = update_nested_dictionary_with_values_from_config( | ||
resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self | ||
) | ||
if ( | ||
"InstanceType" in inferred_resource_config | ||
and not instance_supports_kms(inferred_resource_config["InstanceType"]) | ||
and "VolumeKmsKeyId" in inferred_resource_config | ||
): | ||
del inferred_resource_config["VolumeKmsKeyId"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the user explicitly passed in a volume kms id, this would also skip it (if the instance type does not support kms). Are we sure we want to change the user input in that way? Shouldnt the SDK respect their wishes, even if it will cause the API to fail? Especially for something security related like kms keys There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These instances don't have any notion of KMS, so there's no security problem, AFAIK. If any instance though supports KMS, we always apply the KMS key. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The admin sets the KMS config with the expectation it is used whenever KMS is supported. But if someone uses a instance type like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me add some more context here. The only thing you can encrypt with a custom KMS key is an EBS volume. The issue here is that you cannot attach an EBS volume to these types of instances. See documentation. As a result, passing any KMS key with such instance is nonsensical, the only admissible value is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
AFAIK, intelligent defaults will take the user supplied value (None) over the ones set in Inteliigent defaults. @rubanh correct me if Im wrong. But with the flexibility we want to provide with PySDK, the user need not figure out that the instance doesn't support kms volume and explicitly pass None for such a case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So actually @JGuinegagne is right about this case @mufaddal-rohawala , if the user supplies None, the Defaults Config understands that as "the user didnt supply a value" (because it doesnt know if the None was a explicitly or implicitly provided) And then to followup on my original comment, thanks for linking to documentation. So lets break this down into separate questions.
|
||
train_request = self._get_train_request( | ||
input_mode=input_mode, | ||
input_config=input_config, | ||
|
@@ -3756,8 +3763,12 @@ def create_endpoint_config( | |
) | ||
if tags is not None: | ||
request["Tags"] = tags | ||
kms_key = resolve_value_from_config( | ||
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self | ||
kms_key = ( | ||
resolve_value_from_config( | ||
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self | ||
) | ||
if instance_supports_kms(instance_type) | ||
else kms_key | ||
) | ||
if kms_key is not None: | ||
request["KmsKeyId"] = kms_key | ||
|
@@ -3850,7 +3861,16 @@ def create_endpoint_config_from_existing( | |
|
||
if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: | ||
request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") | ||
if KMS_KEY_ID not in request: | ||
|
||
supports_kms = all( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment here please, I would vote for letting it fail in that case. |
||
[ | ||
instance_supports_kms(production_variant["InstanceType"]) | ||
for production_variant in production_variants | ||
if "InstanceType" in production_variant | ||
] | ||
) | ||
|
||
if KMS_KEY_ID not in request and supports_kms: | ||
kms_key_from_config = resolve_value_from_config( | ||
config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self | ||
) | ||
|
@@ -4471,15 +4491,28 @@ def endpoint_from_production_variants( | |
Returns: | ||
str: The name of the created ``Endpoint``. | ||
""" | ||
|
||
supports_kms = all( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment here please |
||
[ | ||
instance_supports_kms(production_variant["InstanceType"]) | ||
for production_variant in production_variants | ||
if "InstanceType" in production_variant | ||
] | ||
) | ||
|
||
update_list_of_dicts_with_values_from_config( | ||
production_variants, | ||
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, | ||
required_key_paths=["CoreDumpConfig.DestinationS3Uri"], | ||
sagemaker_session=self, | ||
) | ||
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} | ||
kms_key = resolve_value_from_config( | ||
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self | ||
kms_key = ( | ||
resolve_value_from_config( | ||
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self | ||
) | ||
if supports_kms | ||
else kms_key | ||
) | ||
tags = _append_project_tags(tags) | ||
tags = self._append_sagemaker_config_tags( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,11 +25,11 @@ | |
import tarfile | ||
import tempfile | ||
import time | ||
from typing import Any, List, Optional | ||
import json | ||
import abc | ||
import uuid | ||
from datetime import datetime | ||
from typing import List, Optional | ||
|
||
from importlib import import_module | ||
import botocore | ||
|
@@ -41,7 +41,6 @@ | |
from sagemaker.session_settings import SessionSettings | ||
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string | ||
|
||
|
||
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" | ||
MAX_BUCKET_PATHS_COUNT = 5 | ||
S3_PREFIX = "s3://" | ||
|
@@ -1066,6 +1065,15 @@ def resolve_value_from_config( | |
Returns: | ||
The value that should be used by the caller | ||
""" | ||
|
||
if ( | ||
sagemaker_session is not None | ||
and sagemaker_session.settings is not None | ||
and sagemaker_session.settings.ignore_intelligent_defaults | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change to Also, have we explored whether passing an empty config to the Session when initializing it is enough to fulfill the usecase here without adding a whole new setting for it? I would lean against adding a new setting unless its necessary. Also, if we do still want to make a setting like this, maybe instead of changing the utils it would be better to skip the config initialization inside Session and LocalSession instead when the setting is off and set an empty config? That way if the user checks session.sagemaker_config, its empty, which might better match and set expectations around values not being plugged in? Or are these Session settings something that we expect users will keep editing during a lifecycle of a session and not just before session initialization? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should leave the config settings as is. This is just a flag that tells SM to ignore the config when it comes time to using it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So Im wondering why we dont want to do something like the following (ignore at initialization rather than when its time to use it). Seems like this could limit potential for confusion (because the lifecycle of the config matches the lifecycle of the Session today). But curious if there's usecases you have in mind that make this approach unideal.
|
||
): | ||
logger.info("Ignoring intelligent defaults. Returning direct input.") | ||
rubanh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return direct_input | ||
rubanh marked this conversation as resolved.
Show resolved
Hide resolved
rubanh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
config_value = ( | ||
get_sagemaker_config_value(sagemaker_session, config_path) if config_path else None | ||
) | ||
|
@@ -1408,3 +1416,47 @@ def update_nested_dictionary_with_values_from_config( | |
"combined value that will be used = {}\n".format(inferred_config_dict), | ||
) | ||
return inferred_config_dict | ||
|
||
|
||
def stringify_object(obj: Any) -> str: | ||
"""Returns string representation of object, returning only non-None fields.""" | ||
non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None} | ||
return f"{type(obj).__name__}: {str(non_none_atts)}" | ||
|
||
|
||
def volume_size_supported(instance_type: str) -> bool: | ||
"""Returns True if SageMaker allows volume_size to be used for the instance type. | ||
|
||
Raises: | ||
ValueError: If the instance type is improperly formatted. | ||
""" | ||
|
||
try: | ||
|
||
# local mode does not support volume size | ||
if instance_type.startswith("local"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will need to add another check |
||
return False | ||
|
||
parts: List[str] = instance_type.split(".") | ||
|
||
if len(parts) == 3 and parts[0] == "ml": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can use some regex magic to check for instancetype input - https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/image_uris.py#LL266C10-L266C10 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline that the other utility has different specifications and would be large blast radius to merge utilities in this PR. |
||
parts = parts[1:] | ||
|
||
if len(parts) != 2: | ||
raise ValueError(f"Failed to parse instance type '{instance_type}'") | ||
|
||
# Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5 | ||
# does not support attaching an EBS volume. | ||
family = parts[0] | ||
return "d" not in family and not family.startswith("g5") | ||
except Exception as e: | ||
raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") | ||
|
||
|
||
def instance_supports_kms(instance_type: str) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are we sure this will stay up to date? If an instance type starts supporting kms keys one day, it seems like we would be breaking the expectations of users and admins Could we add some context here also about where the "source of truth" is for this info? If it will need to be updated manually, it should at least be easy for someone to check whether this is up to date There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The rules for supporting KMS are equal to whether the instance supports volume size. Unfortunately, there's no other source of truth, so this will have to be updated manually. However, this function is already used in other parts of the code (JumpStart). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll defer to @mufaddal-rohawala 's judgement on this, but I would say that this function needs to be explicit and clear about (1) why is kms support dependent on volume size and (2) how to check whether the answers given by this function are up to date. Those things would help with maintainability, especially if this is something that needs to be updated manually |
||
"""Returns True if SageMaker allows KMS keys to be attached to the instance. | ||
|
||
Raises: | ||
ValueError: If the instance type is improperly formatted. | ||
""" | ||
return volume_size_supported(instance_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this looks to me like a good improvement, but please double check with SDK team if that's considered backward incompatible please. The alternative to be to write a different serialization function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree this would be a good improvement, I dont see any negative implications of this.