Skip to content

Commit 3456684

Browse files
committed
fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method
1 parent 0f37695 commit 3456684

17 files changed

+484
-159
lines changed

src/sagemaker/base_predictor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
NumpySerializer,
4848
)
4949
from sagemaker.session import production_variant, Session
50-
from sagemaker.utils import name_from_base
50+
from sagemaker.utils import name_from_base, stringify_object
5151

5252
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5353

@@ -75,6 +75,10 @@ def content_type(self) -> str:
7575
def accept(self) -> Tuple[str]:
7676
"""The content type(s) that are expected from the inference server."""
7777

78+
def __str__(self) -> str:
    """Return a human-readable description of this object.

    The text is produced by ``stringify_object``, which is imported from
    ``sagemaker.utils`` at the top of this module.
    """
    readable = stringify_object(self)
    return readable
81+
7882

7983
class Predictor(PredictorBase):
8084
"""Make prediction requests to an Amazon SageMaker endpoint."""

src/sagemaker/estimator.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
)
6464
from sagemaker.inputs import TrainingInput, FileSystemInput
6565
from sagemaker.instance_group import InstanceGroup
66+
from sagemaker.utils import instance_supports_kms
6667
from sagemaker.job import _Job
6768
from sagemaker.jumpstart.utils import (
6869
add_jumpstart_tags,
@@ -95,6 +96,7 @@
9596
)
9697
from sagemaker.workflow import is_pipeline_variable
9798
from sagemaker.workflow.entities import PipelineVariable
99+
from sagemaker.workflow.parameters import ParameterString
98100
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
99101

100102
logger = logging.getLogger(__name__)
@@ -599,10 +601,30 @@ def __init__(
599601
self.output_kms_key = resolve_value_from_config(
600602
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
601603
)
602-
self.volume_kms_key = resolve_value_from_config(
603-
volume_kms_key,
604-
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
605-
sagemaker_session=self.sagemaker_session,
604+
if instance_type is None or isinstance(instance_type, str):
605+
instance_type_for_volume_kms = instance_type
606+
elif isinstance(instance_type, ParameterString):
607+
instance_type_for_volume_kms = instance_type.default_value
608+
else:
609+
raise ValueError(f"Bad value for instance type: '{instance_type}'")
610+
611+
self.volume_kms_key = (
612+
resolve_value_from_config(
613+
volume_kms_key,
614+
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
615+
sagemaker_session=self.sagemaker_session,
616+
)
617+
if (
618+
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
619+
)
620+
or instance_groups is not None
621+
and all(
622+
[
623+
instance_supports_kms(instance_group.instance_type)
624+
for instance_group in instance_groups
625+
]
626+
)
627+
else volume_kms_key
606628
)
607629

608630
# VPC configurations

src/sagemaker/instance_types.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,3 @@ def retrieve(
119119
tolerate_vulnerable_model,
120120
tolerate_deprecated_model,
121121
)
122-
123-
124-
def volume_size_supported(instance_type: str) -> bool:
125-
"""Returns True if SageMaker allows volume_size to be used for the instance type.
126-
127-
Raises:
128-
ValueError: If the instance type is improperly formatted.
129-
"""
130-
try:
131-
parts: List[str] = instance_type.split(".")
132-
if len(parts) != 3 or parts[0] != "ml":
133-
raise ValueError("Instance type must have 2 periods and start with 'ml'.")
134-
135-
# Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5
136-
# does not support attaching an EBS volume.
137-
family = parts[1]
138-
return "d" not in family and not family.startswith("g5")
139-
except Exception as e:
140-
raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}")

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
from copy import deepcopy
1616
from typing import Optional
17-
from sagemaker.instance_types import volume_size_supported
17+
from sagemaker.utils import volume_size_supported
1818
from sagemaker.jumpstart.constants import (
1919
JUMPSTART_DEFAULT_REGION_NAME,
2020
)

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
is_valid_model_id,
3737
resolve_model_intelligent_default_field,
3838
)
39-
from sagemaker.jumpstart.utils import stringify_object
39+
from sagemaker.utils import stringify_object
4040
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
4141
from sagemaker.predictor import PredictorBase
4242

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
425425
"""Sets image uri in kwargs based on default or override, returns full kwargs."""
426426

427427
kwargs.image_uri = kwargs.image_uri or image_uris.retrieve(
428-
region=None,
428+
region=kwargs.region,
429429
framework=None,
430430
image_scope=JumpStartScriptScope.TRAINING,
431431
model_id=kwargs.model_id,

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
get_init_kwargs,
2929
)
3030
from sagemaker.jumpstart.utils import is_valid_model_id
31-
from sagemaker.jumpstart.utils import stringify_object
31+
from sagemaker.utils import stringify_object
3232
from sagemaker.model import Model
3333
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3434
from sagemaker.predictor import PredictorBase

src/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,12 +530,6 @@ def resolve_estimator_intelligent_default_field(
530530
return field_val
531531

532532

533-
def stringify_object(obj: Any) -> str:
534-
"""Returns string representation of object, returning only non-None fields."""
535-
non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None}
536-
return f"{type(obj).__name__}: {str(non_none_atts)}"
537-
538-
539533
def is_valid_model_id(
540534
model_id: Optional[str],
541535
region: Optional[str] = None,

src/sagemaker/pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.transformer import Transformer
3535
from sagemaker.workflow.entities import PipelineVariable
3636
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
37+
from sagemaker.utils import instance_supports_kms
3738

3839

3940
class PipelineModel(object):
@@ -235,8 +236,12 @@ def deploy(
235236
container_startup_health_check_timeout=container_startup_health_check_timeout,
236237
)
237238
self.endpoint_name = endpoint_name or self.name
238-
kms_key = resolve_value_from_config(
239-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
239+
kms_key = (
240+
resolve_value_from_config(
241+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
242+
)
243+
if instance_supports_kms(instance_type)
244+
else kms_key
240245
)
241246

242247
data_capture_config_dict = None

src/sagemaker/session.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import botocore.config
3030
from botocore.exceptions import ClientError
3131
import six
32+
from sagemaker.utils import instance_supports_kms
3233

3334
import sagemaker.logs
3435
from sagemaker import vpc_utils
@@ -820,6 +821,12 @@ def train( # noqa: C901
820821
inferred_resource_config = update_nested_dictionary_with_values_from_config(
821822
resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self
822823
)
824+
if (
825+
"InstanceType" in inferred_resource_config
826+
and not instance_supports_kms(inferred_resource_config["InstanceType"])
827+
and "VolumeKmsKeyId" in inferred_resource_config
828+
):
829+
del inferred_resource_config["VolumeKmsKeyId"]
823830
train_request = self._get_train_request(
824831
input_mode=input_mode,
825832
input_config=input_config,
@@ -3756,8 +3763,12 @@ def create_endpoint_config(
37563763
)
37573764
if tags is not None:
37583765
request["Tags"] = tags
3759-
kms_key = resolve_value_from_config(
3760-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
3766+
kms_key = (
3767+
resolve_value_from_config(
3768+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
3769+
)
3770+
if instance_supports_kms(instance_type)
3771+
else kms_key
37613772
)
37623773
if kms_key is not None:
37633774
request["KmsKeyId"] = kms_key
@@ -3850,7 +3861,16 @@ def create_endpoint_config_from_existing(
38503861

38513862
if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None:
38523863
request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId")
3853-
if KMS_KEY_ID not in request:
3864+
3865+
supports_kms = all(
3866+
[
3867+
instance_supports_kms(production_variant["InstanceType"])
3868+
for production_variant in production_variants
3869+
if "InstanceType" in production_variant
3870+
]
3871+
)
3872+
3873+
if KMS_KEY_ID not in request and supports_kms:
38543874
kms_key_from_config = resolve_value_from_config(
38553875
config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
38563876
)
@@ -4471,15 +4491,28 @@ def endpoint_from_production_variants(
44714491
Returns:
44724492
str: The name of the created ``Endpoint``.
44734493
"""
4494+
4495+
supports_kms = all(
4496+
[
4497+
instance_supports_kms(production_variant["InstanceType"])
4498+
for production_variant in production_variants
4499+
if "InstanceType" in production_variant
4500+
]
4501+
)
4502+
44744503
update_list_of_dicts_with_values_from_config(
44754504
production_variants,
44764505
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
44774506
required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
44784507
sagemaker_session=self,
44794508
)
44804509
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
4481-
kms_key = resolve_value_from_config(
4482-
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
4510+
kms_key = (
4511+
resolve_value_from_config(
4512+
kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
4513+
)
4514+
if supports_kms
4515+
else kms_key
44834516
)
44844517
tags = _append_project_tags(tags)
44854518
tags = self._append_sagemaker_config_tags(

src/sagemaker/session_settings.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
class SessionSettings(object):
1919
"""Optional container class for settings to apply to a SageMaker session."""
2020

21-
def __init__(self, encrypt_repacked_artifacts=True, local_download_dir=None) -> None:
21+
def __init__(
22+
self,
23+
encrypt_repacked_artifacts=True,
24+
local_download_dir=None,
25+
ignore_intelligent_defaults=False,
26+
) -> None:
2227
"""Initialize the ``SessionSettings`` of a SageMaker ``Session``.
2328
2429
Args:
@@ -27,9 +32,12 @@ def __init__(self, encrypt_repacked_artifacts=True, local_download_dir=None) ->
2732
is not provided (Default: True).
2833
local_download_dir (str): Optional. A path specifying the local directory
2934
for downloading artifacts. (Default: None).
35+
ignore_intelligent_defaults (bool): Optional. Flag to indicate whether to ignore
36+
intelligent default settings. (Default: False).
3037
"""
3138
self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
3239
self._local_download_dir = local_download_dir
40+
self._ignore_intelligent_defaults = ignore_intelligent_defaults
3341

3442
@property
3543
def encrypt_repacked_artifacts(self) -> bool:
@@ -40,3 +48,8 @@ def encrypt_repacked_artifacts(self) -> bool:
4048
def local_download_dir(self) -> str:
4149
"""Return path specifying the local directory for downloading artifacts."""
4250
return self._local_download_dir
51+
52+
@property
53+
def ignore_intelligent_defaults(self) -> bool:
54+
"""Return boolean for whether intelligent defaults should be ignored."""
55+
return self._ignore_intelligent_defaults

src/sagemaker/utils.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import tarfile
2626
import tempfile
2727
import time
28+
from typing import Any, List, Optional
2829
import json
2930
import abc
3031
import uuid
3132
from datetime import datetime
32-
from typing import List, Optional
3333

3434
from importlib import import_module
3535
import botocore
@@ -41,7 +41,6 @@
4141
from sagemaker.session_settings import SessionSettings
4242
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
4343

44-
4544
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
4645
MAX_BUCKET_PATHS_COUNT = 5
4746
S3_PREFIX = "s3://"
@@ -1066,6 +1065,15 @@ def resolve_value_from_config(
10661065
Returns:
10671066
The value that should be used by the caller
10681067
"""
1068+
1069+
if (
1070+
sagemaker_session is not None
1071+
and sagemaker_session.settings is not None
1072+
and sagemaker_session.settings.ignore_intelligent_defaults
1073+
):
1074+
logger.info("Ignoring intelligent defaults. Returning direct input.")
1075+
return direct_input
1076+
10691077
config_value = (
10701078
get_sagemaker_config_value(sagemaker_session, config_path) if config_path else None
10711079
)
@@ -1408,3 +1416,47 @@ def update_nested_dictionary_with_values_from_config(
14081416
"combined value that will be used = {}\n".format(inferred_config_dict),
14091417
)
14101418
return inferred_config_dict
1419+
1420+
1421+
def stringify_object(obj: Any) -> str:
    """Build a short textual description of ``obj``.

    The result has the form ``"<ClassName>: {<attr>: <value>, ...}"`` where the
    mapping holds every entry of ``obj.__dict__`` whose value is not ``None``,
    in the order the attributes appear on the instance.
    """
    populated = {}
    for attr_name, attr_value in obj.__dict__.items():
        # Attributes left unset (None) are omitted to keep the output compact.
        if attr_value is None:
            continue
        populated[attr_name] = attr_value

    class_name = type(obj).__name__
    return f"{class_name}: {populated}"
1425+
1426+
1427+
def volume_size_supported(instance_type: str) -> bool:
    """Returns True if SageMaker allows volume_size to be used for the instance type.

    Accepted formats are ``"ml.<family>.<size>"`` and ``"<family>.<size>"``.
    Any value starting with ``"local"`` (local mode) is reported as unsupported.

    Args:
        instance_type (str): The instance type to inspect.

    Returns:
        bool: ``False`` for local mode, for instance families containing a ``"d"``
        (i.e. c5d, p4d, etc) and for families starting with ``"g5"``;
        ``True`` otherwise.

    Raises:
        ValueError: If the instance type is not a string or is improperly formatted.
    """
    # Validate the type explicitly rather than relying on a blanket ``except``:
    # the previous catch-all also caught this function's own ValueError and
    # re-wrapped it, which duplicated the error message.
    if not isinstance(instance_type, str):
        raise ValueError(
            f"Failed to parse instance type '{instance_type}': "
            f"expected a string, got {type(instance_type).__name__}"
        )

    # local mode does not support volume size
    if instance_type.startswith("local"):
        return False

    parts = instance_type.split(".")

    # Tolerate the optional "ml" prefix so both "ml.m5.xlarge" and "m5.xlarge" parse.
    if len(parts) == 3 and parts[0] == "ml":
        parts = parts[1:]

    if len(parts) != 2:
        raise ValueError(f"Failed to parse instance type '{instance_type}'")

    # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5
    # does not support attaching an EBS volume.
    family = parts[0]
    return "d" not in family and not family.startswith("g5")
1454+
1455+
1456+
def instance_supports_kms(instance_type: str) -> bool:
    """Report whether SageMaker allows a KMS key to be attached to the instance.

    The answer is delegated to ``volume_size_supported`` and returned as-is.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    supports_volume = volume_size_supported(instance_type)
    return supports_volume

0 commit comments

Comments
 (0)