fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,7 @@ | |
) | ||
from sagemaker.inputs import TrainingInput, FileSystemInput | ||
from sagemaker.instance_group import InstanceGroup | ||
from sagemaker.utils import instance_supports_kms | ||
from sagemaker.job import _Job | ||
from sagemaker.jumpstart.utils import ( | ||
add_jumpstart_tags, | ||
|
@@ -95,6 +96,7 @@ | |
) | ||
from sagemaker.workflow import is_pipeline_variable | ||
from sagemaker.workflow.entities import PipelineVariable | ||
from sagemaker.workflow.parameters import ParameterString | ||
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -599,10 +601,33 @@ def __init__( | |
self.output_kms_key = resolve_value_from_config( | ||
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session | ||
) | ||
self.volume_kms_key = resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
if instance_type is None or isinstance(instance_type, str): | ||
instance_type_for_volume_kms = instance_type | ||
elif isinstance(instance_type, ParameterString): | ||
instance_type_for_volume_kms = instance_type.default_value | ||
else: | ||
raise ValueError(f"Bad value for instance type: '{instance_type}'") | ||
Comment on lines +608 to +609

Would this be backwards incompatible?

I'm pretty sure. The unit tests make it seem like it's compatible. There shouldn't be a different Python type for the instance type.

Today, if a user passes in a different Python type, they won't get an error, right? (My understanding is that the "Optional[Union[str, PipelineVariable]]" annotation is not strictly enforced by Python.) Let me know if I'm mistaken. If I'm not wrong about the above, this would change the behavior of the Estimator class (and I haven't seen type assertions like these in other places anyway). It seems safer not to change this behavior.

Ok, I can skip throwing an exception. However, if a variable that isn't a string or

It turns out this is not backward compatible; see issue #3993. Have we made the fix to skip the PipelineVariable type?
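The thread above weighs raising `ValueError` against leaving unexpected types alone. A minimal sketch of the non-raising variant follows. It assumes a stand-in `ParameterString` class (the real one is imported from `sagemaker.workflow.parameters` in the diff) and a hypothetical helper name, `resolve_instance_type_for_kms`. It is an illustration of the idea discussed, not necessarily the fix that was eventually merged.

```python
from typing import Any, Optional


class ParameterString:
    """Stand-in for sagemaker.workflow.parameters.ParameterString (assumption)."""

    def __init__(self, name: str, default_value: Optional[str] = None):
        self.name = name
        self.default_value = default_value


def resolve_instance_type_for_kms(instance_type: Any) -> Optional[str]:
    """Hypothetical helper: return a plain string to inspect, or None.

    Unexpected types (for example other pipeline variables) yield None
    instead of raising, so callers that worked before keep working.
    """
    if instance_type is None or isinstance(instance_type, str):
        return instance_type
    if isinstance(instance_type, ParameterString):
        return instance_type.default_value
    # Unknown type: do not raise, just skip the KMS capability check.
    return None
```

With this shape, an unrecognized value simply means the instance type cannot be inspected, and the caller decides what to do with `None`.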
||
|
||
# KMS can only be attached to supported instances | ||
use_volume_kms_config = ( | ||
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)) | ||
or instance_groups is not None | ||
and any( | ||
[ | ||
instance_supports_kms(instance_group.instance_type) | ||
for instance_group in instance_groups | ||
] | ||
) | ||
) | ||
|
||
self.volume_kms_key = ( | ||
resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
if use_volume_kms_config | ||
else volume_kms_key | ||
Comment on lines +623 to +630

What are the alternatives to skipping the volume KMS key default configs that the admin set? If an admin set volume KMS keys in the config, is it reasonable for them to expect that the key will be plugged in for only some instance types? Or would they prefer that their users only use instance types that support KMS keys? I think it's worth looping in Prateek to discuss whether this is the right approach.

If there is a volume KMS key, it will be applied whenever an instance supports it. If an instance does not support it, it is not applied.

I think it's fine to have this conditional logic, as we don't want to limit users to only those instances which support a volume KMS key.

And I suppose the whole purpose of the Config is to improve user experience. If the customer wants to enforce that KMS keys always be used, the right place to strictly enforce that would be IAM policies, SCPs, etc. There is no guarantee that populating the Config means those values will be used no matter what.
||
) | ||
nit: I find the ternary hard to read in this context. Could you do:

    if <condition>:
        self.volume_kms_key = resolve_value_from_config(...)
    else:
        self.volume_kms_key = volume_kms_key

More importantly, from a security and reasonable-expectation point of view: should we let this code throw if some of the instances in the group support KMS but others do not? If we silently remove the key as you do here, the result is that there will be instances with EBS volumes out there that aren't encrypted with the default KMS key provided by admins. I don't think that's desirable.

I think this is a valid concern with respect to instance_groups. Let's involve the instance_groups feature team to weigh in on this.
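The two options raised in this thread (apply the default key if any instance group supports it, or throw when groups are mixed) can be sketched side by side. The sketch assumes a stand-in `InstanceGroup` dataclass, a simplified stand-in `supports_kms` check, and a hypothetical `strict` flag. None of these names come from the SDK, and the strict behavior is the reviewer's proposal rather than what the diff implements.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class InstanceGroup:
    """Stand-in for sagemaker.instance_group.InstanceGroup (assumption)."""

    instance_type: str


def supports_kms(instance_type: str) -> bool:
    """Simplified stand-in for instance_supports_kms from the diff."""
    family = instance_type.split(".")[-2]
    return "d" not in family and not family.startswith("g5")


def use_default_volume_kms(
    instance_type: Optional[str],
    instance_groups: Optional[List[InstanceGroup]],
    strict: bool = False,
) -> bool:
    """Decide whether the admin-configured volume KMS key should be applied."""
    if instance_type and supports_kms(instance_type):
        return True
    flags = [supports_kms(group.instance_type) for group in instance_groups or []]
    # Hypothetical strict mode: refuse mixed groups instead of silently
    # leaving some volumes without the default key.
    if strict and any(flags) and not all(flags):
        raise ValueError("Instance groups mix KMS-capable and non-capable types")
    return any(flags)
```

With `strict=False` the function mirrors the `any(...)` condition in the diff; with `strict=True` a mixed group surfaces as an error the user has to resolve.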
||
|
||
# VPC configurations | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,11 +25,11 @@ | |
import tarfile | ||
import tempfile | ||
import time | ||
from typing import Any, List, Optional | ||
import json | ||
import abc | ||
import uuid | ||
from datetime import datetime | ||
from typing import List, Optional | ||
|
||
from importlib import import_module | ||
import botocore | ||
|
@@ -45,7 +45,6 @@ | |
from sagemaker.session_settings import SessionSettings | ||
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string | ||
|
||
|
||
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" | ||
MAX_BUCKET_PATHS_COUNT = 5 | ||
S3_PREFIX = "s3://" | ||
|
@@ -1070,6 +1069,7 @@ def resolve_value_from_config( | |
Returns: | ||
The value that should be used by the caller | ||
""" | ||
|
||
config_value = ( | ||
get_sagemaker_config_value(sagemaker_session, config_path) if config_path else None | ||
) | ||
|
@@ -1097,6 +1097,7 @@ def get_sagemaker_config_value(sagemaker_session, key): | |
""" | ||
if not sagemaker_session: | ||
return None | ||
|
||
if sagemaker_session.sagemaker_config: | ||
validate_sagemaker_config(sagemaker_session.sagemaker_config) | ||
config_value = get_config_value(key, sagemaker_session.sagemaker_config) | ||
|
@@ -1362,3 +1363,47 @@ def update_nested_dictionary_with_values_from_config( | |
) | ||
|
||
return inferred_config_dict | ||
|
||
|
||
def stringify_object(obj: Any) -> str: | ||
"""Returns string representation of object, returning only non-None fields.""" | ||
non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None} | ||
return f"{type(obj).__name__}: {str(non_none_atts)}" | ||
|
||
|
||
def volume_size_supported(instance_type: str) -> bool: | ||
"""Returns True if SageMaker allows volume_size to be used for the instance type. | ||
|
||
Raises: | ||
ValueError: If the instance type is improperly formatted. | ||
""" | ||
|
||
try: | ||
|
||
# local mode does not support volume size | ||
if instance_type.startswith("local"): | ||
Will need to add another check.
||
return False | ||
|
||
parts: List[str] = instance_type.split(".") | ||
|
||
if len(parts) == 3 and parts[0] == "ml": | ||
nit: could use some regex magic to check the instance type input - https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/image_uris.py#LL266C10-L266C10

Discussed offline: the other utility has different specifications, and merging the utilities in this PR would have a large blast radius.
||
parts = parts[1:] | ||
|
||
if len(parts) != 2: | ||
raise ValueError(f"Failed to parse instance type '{instance_type}'") | ||
|
||
# Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5 | ||
# does not support attaching an EBS volume. | ||
family = parts[0] | ||
return "d" not in family and not family.startswith("g5") | ||
except Exception as e: | ||
raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") | ||
|
||
|
||
def instance_supports_kms(instance_type: str) -> bool: | ||
How are we sure this will stay up to date? If an instance type starts supporting KMS keys one day, it seems like we would be breaking the expectations of users and admins. Could we also add some context here about where the "source of truth" for this info is? If it needs to be updated manually, it should at least be easy for someone to check whether it is up to date.

The rules for supporting KMS are the same as the rules for whether the instance supports volume size. Unfortunately, there's no other source of truth, so this will have to be updated manually. However, this function is already used in other parts of the code (JumpStart).

I'll defer to @mufaddal-rohawala's judgement on this, but I would say that this function needs to be explicit and clear about (1) why KMS support depends on volume size and (2) how to check whether the answers given by this function are up to date. Those things would help with maintainability, especially if this is something that needs to be updated manually.
||
"""Returns True if SageMaker allows KMS keys to be attached to the instance. | ||
|
||
Raises: | ||
ValueError: If the instance type is improperly formatted. | ||
""" | ||
return volume_size_supported(instance_type) |
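To make the rule discussed above concrete, here is a small standalone run of the two helpers. The function bodies are copied from the diff (with blank lines and comments trimmed) so the example runs without the SDK; the sample instance type names are chosen for illustration only.

```python
from typing import List


def volume_size_supported(instance_type: str) -> bool:
    """Copy of the helper from the diff, reproduced so the example runs standalone."""
    try:
        # local mode does not support volume size
        if instance_type.startswith("local"):
            return False
        parts: List[str] = instance_type.split(".")
        if len(parts) == 3 and parts[0] == "ml":
            parts = parts[1:]
        if len(parts) != 2:
            raise ValueError(f"Failed to parse instance type '{instance_type}'")
        # Families containing "d" (c5d, p4d, ...) and g5 have no attachable EBS volume.
        family = parts[0]
        return "d" not in family and not family.startswith("g5")
    except Exception as e:
        raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}")


def instance_supports_kms(instance_type: str) -> bool:
    """KMS support follows the same rule as volume size support."""
    return volume_size_supported(instance_type)


for name in ["ml.m5.xlarge", "ml.c5d.xlarge", "ml.g5.2xlarge", "local_gpu"]:
    print(name, instance_supports_kms(name))
# ml.m5.xlarge True
# ml.c5d.xlarge False
# ml.g5.2xlarge False
# local_gpu False
```

Note that any non-string input, including `None`, ends up as a `ValueError` because the broad `except` wraps the underlying `AttributeError`.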
This looks to me like a good improvement, but please double check with the SDK team whether it is considered backward incompatible. The alternative would be to write a different serialization function.

Agreed, this would be a good improvement. I don't see any negative implications.
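If the comments above concern the new string representation (the PR title mentions a Predictor str method), the behavior of `stringify_object` from the diff can be illustrated as follows. `DemoPredictor` is a hypothetical class invented for this sketch and is not the SDK's Predictor; how the SDK actually wires `__str__` is not shown in this excerpt.

```python
from typing import Any


def stringify_object(obj: Any) -> str:
    """Copy of the helper from the diff: show only the non-None attributes."""
    non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None}
    return f"{type(obj).__name__}: {str(non_none_atts)}"


class DemoPredictor:
    """Hypothetical class used only to demonstrate the output format."""

    def __init__(self, endpoint_name: str, serializer: Any = None):
        self.endpoint_name = endpoint_name
        self.serializer = serializer

    def __str__(self) -> str:
        return stringify_object(self)


print(DemoPredictor("my-endpoint"))
# DemoPredictor: {'endpoint_name': 'my-endpoint'}
```

The backward-compatibility question is visible here: any caller that relied on the default `<... object at 0x...>` output of `str()` would see a different string.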