Skip to content

Commit fc4096b

Browse files
committed
chore: address PR comments
1 parent 8c379c1 commit fc4096b

File tree

4 files changed

+297
-9
lines changed

4 files changed

+297
-9
lines changed

src/sagemaker/estimator.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -608,22 +608,24 @@ def __init__(
608608
else:
609609
raise ValueError(f"Bad value for instance type: '{instance_type}'")
610610

611-
self.volume_kms_key = (
612-
resolve_value_from_config(
613-
volume_kms_key,
614-
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
615-
sagemaker_session=self.sagemaker_session,
616-
)
617-
if (
618-
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
619-
)
611+
use_config_condition = (
612+
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
620613
or instance_groups is not None
621614
and any(
622615
[
623616
instance_supports_kms(instance_group.instance_type)
624617
for instance_group in instance_groups
625618
]
626619
)
620+
)
621+
622+
self.volume_kms_key = (
623+
resolve_value_from_config(
624+
volume_kms_key,
625+
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
626+
sagemaker_session=self.sagemaker_session,
627+
)
628+
if use_config_condition
627629
else volume_kms_key
628630
)
629631

tests/unit/test_estimator.py

+47
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,53 @@ def test_estimator_initialization_with_sagemaker_config_injection_no_kms_support
497497
assert estimator.subnets == expected_subnets
498498

499499

500+
def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_support(
501+
sagemaker_session,
502+
):
503+
504+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB
505+
506+
estimator = Estimator(
507+
image_uri="some-image",
508+
instance_groups=[
509+
InstanceGroup("group1", "ml.p2.xlarge", 1),
510+
InstanceGroup("group2", "ml.g5.2xlarge", 2),
511+
],
512+
sagemaker_session=sagemaker_session,
513+
base_job_name="base_job_name",
514+
)
515+
516+
expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
517+
"ResourceConfig"
518+
]["VolumeKmsKeyId"]
519+
expected_role_arn = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["RoleArn"]
520+
expected_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
521+
"OutputDataConfig"
522+
]["KmsKeyId"]
523+
expected_subnets = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"][
524+
"Subnets"
525+
]
526+
expected_security_groups = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
527+
"VpcConfig"
528+
]["SecurityGroupIds"]
529+
expected_enable_network_isolation = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][
530+
"EnableNetworkIsolation"
531+
]
532+
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][
533+
"TrainingJob"
534+
]["EnableInterContainerTrafficEncryption"]
535+
assert estimator.role == expected_role_arn
536+
assert estimator.enable_network_isolation() == expected_enable_network_isolation
537+
assert (
538+
estimator.encrypt_inter_container_traffic
539+
== expected_enable_inter_container_traffic_encryption
540+
)
541+
assert estimator.output_kms_key == expected_kms_key_id
542+
assert estimator.volume_kms_key == expected_volume_kms_key_id
543+
assert estimator.security_group_ids == expected_security_groups
544+
assert estimator.subnets == expected_subnets
545+
546+
500547
def test_framework_with_heterogeneous_cluster(sagemaker_session):
501548
f = DummyFramework(
502549
entry_point=SCRIPT_PATH,

tests/unit/test_session.py

+236
Original file line numberDiff line numberDiff line change
@@ -2786,6 +2786,121 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection(
27862786
)
27872787

27882788

2789+
def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_partial_kms_support(
2790+
sagemaker_session,
2791+
):
2792+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2793+
2794+
pvs = [
2795+
sagemaker.production_variant("A", "ml.g5.2xlarge"),
2796+
sagemaker.production_variant("B", "ml.p2.xlarge"),
2797+
sagemaker.production_variant("C", "ml.p2.xlarge"),
2798+
]
2799+
# Add DestinationS3Uri to only one production variant
2800+
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
2801+
existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo"
2802+
existing_endpoint_name = "foo"
2803+
new_endpoint_name = "new-foo"
2804+
sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = {
2805+
"ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")],
2806+
"EndpointConfigArn": existing_endpoint_arn,
2807+
"AsyncInferenceConfig": {},
2808+
}
2809+
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []}
2810+
2811+
sagemaker_session.create_endpoint_config_from_existing(
2812+
existing_endpoint_name, new_endpoint_name, new_production_variants=pvs
2813+
)
2814+
2815+
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
2816+
"EndpointConfig"
2817+
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
2818+
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
2819+
"AsyncInferenceConfig"
2820+
]["OutputConfig"]["KmsKeyId"]
2821+
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
2822+
"KmsKeyId"
2823+
]
2824+
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]
2825+
2826+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
2827+
EndpointConfigName=new_endpoint_name,
2828+
ProductionVariants=[
2829+
{
2830+
"CoreDumpConfig": {
2831+
"KmsKeyId": expected_production_variant_0_kms_key_id,
2832+
"DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"],
2833+
},
2834+
**sagemaker.production_variant("A", "ml.g5.2xlarge"),
2835+
},
2836+
{
2837+
# Merge shouldn't happen because input for this index doesn't have DestinationS3Uri
2838+
**sagemaker.production_variant("B", "ml.p2.xlarge"),
2839+
},
2840+
sagemaker.production_variant("C", "ml.p2.xlarge"),
2841+
],
2842+
KmsKeyId=expected_kms_key_id, # from config
2843+
Tags=expected_tags, # from config
2844+
AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}},
2845+
)
2846+
2847+
2848+
def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_no_kms_support(
2849+
sagemaker_session,
2850+
):
2851+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2852+
2853+
pvs = [
2854+
sagemaker.production_variant("A", "ml.g5.2xlarge"),
2855+
sagemaker.production_variant("B", "ml.g5.xlarge"),
2856+
sagemaker.production_variant("C", "ml.g5.xlarge"),
2857+
]
2858+
# Add DestinationS3Uri to only one production variant
2859+
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
2860+
existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo"
2861+
existing_endpoint_name = "foo"
2862+
new_endpoint_name = "new-foo"
2863+
sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = {
2864+
"ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")],
2865+
"EndpointConfigArn": existing_endpoint_arn,
2866+
"AsyncInferenceConfig": {},
2867+
}
2868+
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []}
2869+
2870+
sagemaker_session.create_endpoint_config_from_existing(
2871+
existing_endpoint_name, new_endpoint_name, new_production_variants=pvs
2872+
)
2873+
2874+
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
2875+
"EndpointConfig"
2876+
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
2877+
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
2878+
"AsyncInferenceConfig"
2879+
]["OutputConfig"]["KmsKeyId"]
2880+
2881+
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]
2882+
2883+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
2884+
EndpointConfigName=new_endpoint_name,
2885+
ProductionVariants=[
2886+
{
2887+
"CoreDumpConfig": {
2888+
"KmsKeyId": expected_production_variant_0_kms_key_id,
2889+
"DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"],
2890+
},
2891+
**sagemaker.production_variant("A", "ml.g5.2xlarge"),
2892+
},
2893+
{
2894+
# Merge shouldn't happen because input for this index doesn't have DestinationS3Uri
2895+
**sagemaker.production_variant("B", "ml.g5.xlarge"),
2896+
},
2897+
sagemaker.production_variant("C", "ml.g5.xlarge"),
2898+
],
2899+
Tags=expected_tags, # from config
2900+
AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}},
2901+
)
2902+
2903+
27892904
def test_endpoint_from_production_variants_with_sagemaker_config_injection(
27902905
sagemaker_session,
27912906
):
@@ -2848,6 +2963,127 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection(
28482963
)
28492964

28502965

2966+
def test_endpoint_from_production_variants_with_sagemaker_config_injection_partial_kms_support(
2967+
sagemaker_session,
2968+
):
2969+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
2970+
2971+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(
2972+
return_value={"EndpointStatus": "InService"}
2973+
)
2974+
pvs = [
2975+
sagemaker.production_variant("A", "ml.g5.xlarge"),
2976+
sagemaker.production_variant("B", "ml.p2.xlarge"),
2977+
sagemaker.production_variant("C", "ml.p2.xlarge"),
2978+
]
2979+
# Add DestinationS3Uri to only one production variant
2980+
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
2981+
sagemaker_session.endpoint_from_production_variants(
2982+
"some-endpoint",
2983+
pvs,
2984+
data_capture_config_dict={},
2985+
async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(),
2986+
)
2987+
expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
2988+
"EndpointConfig"
2989+
]["DataCaptureConfig"]["KmsKeyId"]
2990+
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
2991+
"AsyncInferenceConfig"
2992+
]["OutputConfig"]["KmsKeyId"]
2993+
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
2994+
"KmsKeyId"
2995+
]
2996+
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]
2997+
2998+
expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict()
2999+
expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id
3000+
expected_pvs = [
3001+
sagemaker.production_variant("A", "ml.g5.xlarge"),
3002+
sagemaker.production_variant("B", "ml.p2.xlarge"),
3003+
sagemaker.production_variant("C", "ml.p2.xlarge"),
3004+
]
3005+
# Add DestinationS3Uri, KmsKeyId to only one production variant
3006+
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
3007+
"EndpointConfig"
3008+
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
3009+
expected_pvs[0]["CoreDumpConfig"] = {
3010+
"DestinationS3Uri": "s3://test",
3011+
"KmsKeyId": expected_production_variant_0_kms_key_id,
3012+
}
3013+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3014+
EndpointConfigName="some-endpoint",
3015+
ProductionVariants=expected_pvs,
3016+
Tags=expected_tags, # from config
3017+
KmsKeyId=expected_kms_key_id, # from config
3018+
AsyncInferenceConfig=expected_async_inference_config_dict,
3019+
DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id},
3020+
)
3021+
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
3022+
EndpointConfigName="some-endpoint",
3023+
EndpointName="some-endpoint",
3024+
Tags=expected_tags, # from config
3025+
)
3026+
3027+
3028+
def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_kms_support(
3029+
sagemaker_session,
3030+
):
3031+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG
3032+
3033+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(
3034+
return_value={"EndpointStatus": "InService"}
3035+
)
3036+
pvs = [
3037+
sagemaker.production_variant("A", "ml.g5.xlarge"),
3038+
sagemaker.production_variant("B", "ml.g5.xlarge"),
3039+
sagemaker.production_variant("C", "ml.g5.xlarge"),
3040+
]
3041+
# Add DestinationS3Uri to only one production variant
3042+
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
3043+
sagemaker_session.endpoint_from_production_variants(
3044+
"some-endpoint",
3045+
pvs,
3046+
data_capture_config_dict={},
3047+
async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(),
3048+
)
3049+
expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
3050+
"EndpointConfig"
3051+
]["DataCaptureConfig"]["KmsKeyId"]
3052+
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
3053+
"AsyncInferenceConfig"
3054+
]["OutputConfig"]["KmsKeyId"]
3055+
3056+
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]
3057+
3058+
expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict()
3059+
expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id
3060+
expected_pvs = [
3061+
sagemaker.production_variant("A", "ml.g5.xlarge"),
3062+
sagemaker.production_variant("B", "ml.g5.xlarge"),
3063+
sagemaker.production_variant("C", "ml.g5.xlarge"),
3064+
]
3065+
# Add DestinationS3Uri, KmsKeyId to only one production variant
3066+
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
3067+
"EndpointConfig"
3068+
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
3069+
expected_pvs[0]["CoreDumpConfig"] = {
3070+
"DestinationS3Uri": "s3://test",
3071+
"KmsKeyId": expected_production_variant_0_kms_key_id,
3072+
}
3073+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3074+
EndpointConfigName="some-endpoint",
3075+
ProductionVariants=expected_pvs,
3076+
Tags=expected_tags, # from config
3077+
AsyncInferenceConfig=expected_async_inference_config_dict,
3078+
DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id},
3079+
)
3080+
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
3081+
EndpointConfigName="some-endpoint",
3082+
EndpointName="some-endpoint",
3083+
Tags=expected_tags, # from config
3084+
)
3085+
3086+
28513087
def test_create_endpoint_config_with_tags(sagemaker_session):
28523088
tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
28533089

tests/unit/test_transformer.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616
from mock import MagicMock, Mock, patch, PropertyMock
17+
from sagemaker.session_settings import SessionSettings
1718

1819
from sagemaker.transformer import _TransformJob, Transformer
1920
from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
@@ -106,6 +107,8 @@ def transformer(sagemaker_session):
106107
def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_session):
107108
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRANSFORM_JOB
108109

110+
sagemaker_session.settings = SessionSettings()
111+
109112
transformer = Transformer(
110113
MODEL_NAME,
111114
INSTANCE_COUNT,

0 commit comments

Comments
 (0)