-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870
Changes from 3 commits
3456684
d42cad6
8c379c1
fc4096b
55e93d9
fe4092a
60e1184
da1ad6b
71714b6
b0b6edb
22d27ee
84969d9
84e88c0
b6e5f8c
7709a87
c19fc9d
9f0533f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -608,22 +608,24 @@ def __init__( | |
else: | ||
raise ValueError(f"Bad value for instance type: '{instance_type}'") | ||
|
||
self.volume_kms_key = ( | ||
resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
if ( | ||
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms) | ||
) | ||
use_config_condition = ( | ||
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)) | ||
or instance_groups is not None | ||
and any( | ||
[ | ||
instance_supports_kms(instance_group.instance_type) | ||
for instance_group in instance_groups | ||
] | ||
) | ||
) | ||
|
||
self.volume_kms_key = ( | ||
resolve_value_from_config( | ||
volume_kms_key, | ||
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
if use_config_condition | ||
else volume_kms_key | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I find ternary hard to read in this context, could you do: if <condition>:
self.volume_kms_key = resolve_value_from_config(...)
else:
self.volume_kms_key = volume_kms_key More importantly, from a security/reasonable expection POV: should we let this code throw if any of the instance in the group does support KMS, but some do not? If we silently remove it as you do here, the result is that there will be instances with EBS volumes out there that aren't encrypted with the default KMS key provided by admins. I don't think that's desirable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a valid concern with respect to instance_groups, lets involve the instance_groups feature team to weigh in on this. |
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -504,6 +504,53 @@ def test_estimator_initialization_with_sagemaker_config_injection_no_kms_support | |
assert estimator.subnets == expected_subnets | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a test case where the instance group has types that support attaching a volume and types that do not support it. |
||
|
||
|
||
def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_support( | ||
sagemaker_session, | ||
): | ||
|
||
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB | ||
|
||
estimator = Estimator( | ||
image_uri="some-image", | ||
instance_groups=[ | ||
InstanceGroup("group1", "ml.p2.xlarge", 1), | ||
InstanceGroup("group2", "ml.g5.2xlarge", 2), | ||
], | ||
sagemaker_session=sagemaker_session, | ||
base_job_name="base_job_name", | ||
) | ||
|
||
expected_volume_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ | ||
"ResourceConfig" | ||
]["VolumeKmsKeyId"] | ||
expected_role_arn = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["RoleArn"] | ||
expected_kms_key_id = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ | ||
"OutputDataConfig" | ||
]["KmsKeyId"] | ||
expected_subnets = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]["VpcConfig"][ | ||
"Subnets" | ||
] | ||
expected_security_groups = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ | ||
"VpcConfig" | ||
]["SecurityGroupIds"] | ||
expected_enable_network_isolation = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"][ | ||
"EnableNetworkIsolation" | ||
] | ||
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"][ | ||
"TrainingJob" | ||
]["EnableInterContainerTrafficEncryption"] | ||
assert estimator.role == expected_role_arn | ||
assert estimator.enable_network_isolation() == expected_enable_network_isolation | ||
assert ( | ||
estimator.encrypt_inter_container_traffic | ||
== expected_enable_inter_container_traffic_encryption | ||
) | ||
assert estimator.output_kms_key == expected_kms_key_id | ||
assert estimator.volume_kms_key == expected_volume_kms_key_id | ||
assert estimator.security_group_ids == expected_security_groups | ||
assert estimator.subnets == expected_subnets | ||
|
||
|
||
def test_framework_with_heterogeneous_cluster(sagemaker_session): | ||
f = DummyFramework( | ||
entry_point=SCRIPT_PATH, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2870,6 +2870,121 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection( | |
) | ||
|
||
|
||
def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_partial_kms_support( | ||
sagemaker_session, | ||
): | ||
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG | ||
|
||
pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.2xlarge"), | ||
sagemaker.production_variant("B", "ml.p2.xlarge"), | ||
sagemaker.production_variant("C", "ml.p2.xlarge"), | ||
] | ||
# Add DestinationS3Uri to only one production variant | ||
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} | ||
existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" | ||
existing_endpoint_name = "foo" | ||
new_endpoint_name = "new-foo" | ||
sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = { | ||
"ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")], | ||
"EndpointConfigArn": existing_endpoint_arn, | ||
"AsyncInferenceConfig": {}, | ||
} | ||
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []} | ||
|
||
sagemaker_session.create_endpoint_config_from_existing( | ||
existing_endpoint_name, new_endpoint_name, new_production_variants=pvs | ||
) | ||
|
||
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
"EndpointConfig" | ||
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] | ||
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
"AsyncInferenceConfig" | ||
]["OutputConfig"]["KmsKeyId"] | ||
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. optional: how about renaming this variable |
||
"KmsKeyId" | ||
] | ||
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] | ||
|
||
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( | ||
EndpointConfigName=new_endpoint_name, | ||
ProductionVariants=[ | ||
{ | ||
"CoreDumpConfig": { | ||
"KmsKeyId": expected_production_variant_0_kms_key_id, | ||
"DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"], | ||
}, | ||
**sagemaker.production_variant("A", "ml.g5.2xlarge"), | ||
}, | ||
{ | ||
# Merge shouldn't happen because input for this index doesn't have DestinationS3Uri | ||
**sagemaker.production_variant("B", "ml.p2.xlarge"), | ||
}, | ||
sagemaker.production_variant("C", "ml.p2.xlarge"), | ||
], | ||
KmsKeyId=expected_kms_key_id, # from config | ||
Tags=expected_tags, # from config | ||
AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}}, | ||
) | ||
|
||
|
||
def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_no_kms_support( | ||
sagemaker_session, | ||
): | ||
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG | ||
|
||
pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.2xlarge"), | ||
sagemaker.production_variant("B", "ml.g5.xlarge"), | ||
sagemaker.production_variant("C", "ml.g5.xlarge"), | ||
] | ||
# Add DestinationS3Uri to only one production variant | ||
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} | ||
existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" | ||
existing_endpoint_name = "foo" | ||
new_endpoint_name = "new-foo" | ||
sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = { | ||
"ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")], | ||
"EndpointConfigArn": existing_endpoint_arn, | ||
"AsyncInferenceConfig": {}, | ||
} | ||
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []} | ||
|
||
sagemaker_session.create_endpoint_config_from_existing( | ||
existing_endpoint_name, new_endpoint_name, new_production_variants=pvs | ||
) | ||
|
||
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
"EndpointConfig" | ||
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] | ||
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
"AsyncInferenceConfig" | ||
]["OutputConfig"]["KmsKeyId"] | ||
|
||
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] | ||
|
||
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( | ||
EndpointConfigName=new_endpoint_name, | ||
ProductionVariants=[ | ||
{ | ||
"CoreDumpConfig": { | ||
"KmsKeyId": expected_production_variant_0_kms_key_id, | ||
"DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"], | ||
}, | ||
**sagemaker.production_variant("A", "ml.g5.2xlarge"), | ||
}, | ||
{ | ||
# Merge shouldn't happen because input for this index doesn't have DestinationS3Uri | ||
**sagemaker.production_variant("B", "ml.g5.xlarge"), | ||
}, | ||
sagemaker.production_variant("C", "ml.g5.xlarge"), | ||
], | ||
Tags=expected_tags, # from config | ||
AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}}, | ||
) | ||
|
||
|
||
def test_endpoint_from_production_variants_with_sagemaker_config_injection( | ||
sagemaker_session, | ||
): | ||
|
@@ -2932,6 +3047,127 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection( | |
) | ||
|
||
|
||
def test_endpoint_from_production_variants_with_sagemaker_config_injection_partial_kms_support( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding these tests. Optional but I would prefer one or more comments for each test that make it clear what is being verified. Based on the title alone I can imagine it being difficult for someone reading through later to find whats different in the contents of one test vs another. Also, do we need to verify all the configs for these kms_support tests? I wonder if we could get rid of all the asserts that dont relate to the thing being tested? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
sagemaker_session, | ||
): | ||
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG | ||
|
||
sagemaker_session.sagemaker_client.describe_endpoint = Mock( | ||
return_value={"EndpointStatus": "InService"} | ||
) | ||
pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.xlarge"), | ||
sagemaker.production_variant("B", "ml.p2.xlarge"), | ||
sagemaker.production_variant("C", "ml.p2.xlarge"), | ||
] | ||
# Add DestinationS3Uri to only one production variant | ||
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} | ||
sagemaker_session.endpoint_from_production_variants( | ||
"some-endpoint", | ||
pvs, | ||
data_capture_config_dict={}, | ||
async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(), | ||
) | ||
expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
"EndpointConfig" | ||
]["DataCaptureConfig"]["KmsKeyId"] | ||
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
"AsyncInferenceConfig" | ||
]["OutputConfig"]["KmsKeyId"] | ||
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
"KmsKeyId" | ||
] | ||
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] | ||
|
||
expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict() | ||
expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id | ||
expected_pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.xlarge"), | ||
sagemaker.production_variant("B", "ml.p2.xlarge"), | ||
sagemaker.production_variant("C", "ml.p2.xlarge"), | ||
] | ||
# Add DestinationS3Uri, KmsKeyId to only one production variant | ||
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment please |
||
"EndpointConfig" | ||
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] | ||
expected_pvs[0]["CoreDumpConfig"] = { | ||
"DestinationS3Uri": "s3://test", | ||
"KmsKeyId": expected_production_variant_0_kms_key_id, | ||
} | ||
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( | ||
EndpointConfigName="some-endpoint", | ||
ProductionVariants=expected_pvs, | ||
Tags=expected_tags, # from config | ||
KmsKeyId=expected_kms_key_id, # from config | ||
AsyncInferenceConfig=expected_async_inference_config_dict, | ||
DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id}, | ||
) | ||
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( | ||
EndpointConfigName="some-endpoint", | ||
EndpointName="some-endpoint", | ||
Tags=expected_tags, # from config | ||
) | ||
|
||
|
||
def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_kms_support( | ||
sagemaker_session, | ||
): | ||
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG | ||
|
||
sagemaker_session.sagemaker_client.describe_endpoint = Mock( | ||
return_value={"EndpointStatus": "InService"} | ||
) | ||
pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.xlarge"), | ||
sagemaker.production_variant("B", "ml.g5.xlarge"), | ||
sagemaker.production_variant("C", "ml.g5.xlarge"), | ||
] | ||
# Add DestinationS3Uri to only one production variant | ||
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"} | ||
sagemaker_session.endpoint_from_production_variants( | ||
"some-endpoint", | ||
pvs, | ||
data_capture_config_dict={}, | ||
async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(), | ||
) | ||
expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
"EndpointConfig" | ||
]["DataCaptureConfig"]["KmsKeyId"] | ||
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][ | ||
"AsyncInferenceConfig" | ||
]["OutputConfig"]["KmsKeyId"] | ||
|
||
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"] | ||
|
||
expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict() | ||
expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id | ||
expected_pvs = [ | ||
sagemaker.production_variant("A", "ml.g5.xlarge"), | ||
sagemaker.production_variant("B", "ml.g5.xlarge"), | ||
sagemaker.production_variant("C", "ml.g5.xlarge"), | ||
] | ||
# Add DestinationS3Uri, KmsKeyId to only one production variant | ||
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][ | ||
"EndpointConfig" | ||
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"] | ||
expected_pvs[0]["CoreDumpConfig"] = { | ||
"DestinationS3Uri": "s3://test", | ||
"KmsKeyId": expected_production_variant_0_kms_key_id, | ||
} | ||
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( | ||
EndpointConfigName="some-endpoint", | ||
ProductionVariants=expected_pvs, | ||
Tags=expected_tags, # from config | ||
AsyncInferenceConfig=expected_async_inference_config_dict, | ||
DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id}, | ||
) | ||
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( | ||
EndpointConfigName="some-endpoint", | ||
EndpointName="some-endpoint", | ||
Tags=expected_tags, # from config | ||
) | ||
|
||
|
||
def test_create_endpoint_config_with_tags(sagemaker_session): | ||
tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for making the change, could you add an inline comment here to explain what you are doing?
Most readers won't be familiar with
instance_group
and the details of which hardware supports attaching an EBS volume that can be encrypted and which do not.Also nit: could you rename the variable
use_volume_kms_config
please?