Skip to content

fix: intelligent defaults for volume size, JS Estimator image uri region, Predictor str method #3870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,22 +608,24 @@ def __init__(
else:
raise ValueError(f"Bad value for instance type: '{instance_type}'")

self.volume_kms_key = (
resolve_value_from_config(
volume_kms_key,
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
sagemaker_session=self.sagemaker_session,
)
if (
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
)
use_config_condition = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for making the change, could you add an inline comment here to explain what you are doing?

Most readers won't be familiar with instance_group and the details of which hardware supports attaching an EBS volume that can be encrypted and which does not.

Also nit: could you rename the variable use_volume_kms_config please?

(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
or instance_groups is not None
and any(
[
instance_supports_kms(instance_group.instance_type)
for instance_group in instance_groups
]
)
)

self.volume_kms_key = (
resolve_value_from_config(
volume_kms_key,
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
sagemaker_session=self.sagemaker_session,
)
if use_config_condition
else volume_kms_key
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I find ternary hard to read in this context, could you do:

if <condition>:
  self.volume_kms_key = resolve_value_from_config(...)
else:
  self.volume_kms_key = volume_kms_key

More importantly, from a security/reasonable-expectation POV: should we let this code throw if some of the instances in the group support KMS, but others do not?

If we silently remove it as you do here, the result is that there will be instances with EBS volumes out there that aren't encrypted with the default KMS key provided by admins. I don't think that's desirable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a valid concern with respect to instance_groups; let's involve the instance_groups feature team to weigh in on this.


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/remote_function/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mock import patch, Mock, ANY

from sagemaker.config import load_sagemaker_config
from sagemaker.session_settings import SessionSettings
from tests.unit import DATA_DIR
from sagemaker.remote_function.job import (
_JobSettings,
Expand Down Expand Up @@ -98,7 +99,7 @@ def mock_session():
session = Mock()
session.sagemaker_client.create_training_job.return_value = {"TrainingJobArn": TRAINING_JOB_ARN}
session.sagemaker_client.describe_training_job.return_value = COMPLETED_TRAINING_JOB

session.settings = SessionSettings()
session.default_bucket.return_value = BUCKET
session.expand_role.return_value = ROLE_ARN
session.boto_region_name = TEST_REGION
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mock import Mock, patch

from sagemaker import s3
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionEquals
from sagemaker.workflow.execution_variables import ExecutionVariables
Expand Down Expand Up @@ -79,6 +80,7 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock
sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = {
"PipelineArn": "pipeline-arn"
}
sagemaker_session_mock.settings = SessionSettings()
pipeline = Pipeline(
name="MyPipeline",
parameters=[],
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,53 @@ def test_estimator_initialization_with_sagemaker_config_injection_no_kms_support
assert estimator.subnets == expected_subnets
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a test case where the instance group has types that support attaching a volume and types that do not support it.



def test_estimator_initialization_with_sagemaker_config_injection_partial_kms_support(
    sagemaker_session,
):
    """Check config injection for an estimator whose instance groups mix instance types.

    The estimator is created from two instance groups (``ml.p2.xlarge`` and
    ``ml.g5.2xlarge``) without an explicit role, KMS key or VPC setting, so every
    asserted attribute is expected to equal the matching entry of
    ``SAGEMAKER_CONFIG_TRAINING_JOB`` -- including ``volume_kms_key``.
    """
    # NOTE(review): going by the test name, presumably only one of the two instance
    # types supports a volume KMS key -- confirm against instance_supports_kms.
    sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB

    mixed_groups = [
        InstanceGroup("group1", "ml.p2.xlarge", 1),
        InstanceGroup("group2", "ml.g5.2xlarge", 2),
    ]
    estimator = Estimator(
        image_uri="some-image",
        instance_groups=mixed_groups,
        sagemaker_session=sagemaker_session,
        base_job_name="base_job_name",
    )

    # Every expected value is read from the same TrainingJob section of the config.
    training_job_config = SAGEMAKER_CONFIG_TRAINING_JOB["SageMaker"]["TrainingJob"]
    wanted_volume_kms_key = training_job_config["ResourceConfig"]["VolumeKmsKeyId"]
    wanted_role = training_job_config["RoleArn"]
    wanted_output_kms_key = training_job_config["OutputDataConfig"]["KmsKeyId"]
    wanted_subnets = training_job_config["VpcConfig"]["Subnets"]
    wanted_security_groups = training_job_config["VpcConfig"]["SecurityGroupIds"]
    wanted_network_isolation = training_job_config["EnableNetworkIsolation"]
    wanted_traffic_encryption = training_job_config["EnableInterContainerTrafficEncryption"]

    assert estimator.role == wanted_role
    assert estimator.enable_network_isolation() == wanted_network_isolation
    assert estimator.encrypt_inter_container_traffic == wanted_traffic_encryption
    assert estimator.output_kms_key == wanted_output_kms_key
    assert estimator.volume_kms_key == wanted_volume_kms_key
    assert estimator.security_group_ids == wanted_security_groups
    assert estimator.subnets == wanted_subnets


def test_framework_with_heterogeneous_cluster(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
Expand Down
236 changes: 236 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2870,6 +2870,121 @@ def test_create_endpoint_config_from_existing_with_sagemaker_config_injection(
)


def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_partial_kms_support(
sagemaker_session,
):
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG

pvs = [
sagemaker.production_variant("A", "ml.g5.2xlarge"),
sagemaker.production_variant("B", "ml.p2.xlarge"),
sagemaker.production_variant("C", "ml.p2.xlarge"),
]
# Add DestinationS3Uri to only one production variant
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo"
existing_endpoint_name = "foo"
new_endpoint_name = "new-foo"
sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = {
"ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")],
"EndpointConfigArn": existing_endpoint_arn,
"AsyncInferenceConfig": {},
}
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []}

sagemaker_session.create_endpoint_config_from_existing(
existing_endpoint_name, new_endpoint_name, new_production_variants=pvs
)

expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
"EndpointConfig"
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
"AsyncInferenceConfig"
]["OutputConfig"]["KmsKeyId"]
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional: how about renaming this variable expected_volume_kms_key_id for better readability?

"KmsKeyId"
]
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]

sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName=new_endpoint_name,
ProductionVariants=[
{
"CoreDumpConfig": {
"KmsKeyId": expected_production_variant_0_kms_key_id,
"DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"],
},
**sagemaker.production_variant("A", "ml.g5.2xlarge"),
},
{
# Merge shouldn't happen because input for this index doesn't have DestinationS3Uri
**sagemaker.production_variant("B", "ml.p2.xlarge"),
},
sagemaker.production_variant("C", "ml.p2.xlarge"),
],
KmsKeyId=expected_kms_key_id, # from config
Tags=expected_tags, # from config
AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": expected_inference_kms_key_id}},
)


def test_create_endpoint_config_from_existing_with_sagemaker_config_injection_no_kms_support(
    sagemaker_session,
):
    """Check the request built from config when every new variant is an ``ml.g5`` type.

    Only the first variant carries a ``CoreDumpConfig`` destination, so it alone is
    expected to receive the config's core-dump ``KmsKeyId``. The expected
    ``create_endpoint_config`` call takes its tags and the async-inference
    ``KmsKeyId`` from config, and has no top-level ``KmsKeyId`` argument.
    """
    # NOTE(review): going by the test name, the top-level KmsKeyId is presumably
    # omitted because none of these instance types supports it -- confirm.
    sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG

    variant_specs = [("A", "ml.g5.2xlarge"), ("B", "ml.g5.xlarge"), ("C", "ml.g5.xlarge")]
    pvs = [
        sagemaker.production_variant(variant_name, instance_type)
        for variant_name, instance_type in variant_specs
    ]
    # Add DestinationS3Uri to only one production variant
    pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}

    source_config_name = "foo"
    target_config_name = "new-foo"
    client = sagemaker_session.sagemaker_client
    client.describe_endpoint_config.return_value = {
        "ProductionVariants": [sagemaker.production_variant("A", "ml.m4.xlarge")],
        "EndpointConfigArn": "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo",
        "AsyncInferenceConfig": {},
    }
    client.list_tags.return_value = {"Tags": []}

    sagemaker_session.create_endpoint_config_from_existing(
        source_config_name, target_config_name, new_production_variants=pvs
    )

    endpoint_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]
    core_dump_kms_key_id = endpoint_config["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
    async_output_kms_key_id = endpoint_config["AsyncInferenceConfig"]["OutputConfig"]["KmsKeyId"]
    config_tags = endpoint_config["Tags"]

    first_expected_variant = {
        "CoreDumpConfig": {
            "KmsKeyId": core_dump_kms_key_id,
            "DestinationS3Uri": pvs[0]["CoreDumpConfig"]["DestinationS3Uri"],
        },
        **sagemaker.production_variant("A", "ml.g5.2xlarge"),
    }
    # Merge shouldn't happen for the remaining variants because their input
    # doesn't have DestinationS3Uri
    remaining_expected_variants = [
        sagemaker.production_variant("B", "ml.g5.xlarge"),
        sagemaker.production_variant("C", "ml.g5.xlarge"),
    ]

    sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
        EndpointConfigName=target_config_name,
        ProductionVariants=[first_expected_variant, *remaining_expected_variants],
        Tags=config_tags,  # from config
        AsyncInferenceConfig={"OutputConfig": {"KmsKeyId": async_output_kms_key_id}},
    )


def test_endpoint_from_production_variants_with_sagemaker_config_injection(
sagemaker_session,
):
Expand Down Expand Up @@ -2932,6 +3047,127 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection(
)


def test_endpoint_from_production_variants_with_sagemaker_config_injection_partial_kms_support(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding these tests. Optional, but I would prefer one or more comments for each test that make it clear what is being verified. Based on the title alone, I can imagine it being difficult for someone reading through later to find what's different in the contents of one test vs. another.

Also, do we need to verify all the configs for these kms_support tests? I wonder if we could get rid of all the asserts that don't relate to the thing being tested?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

sagemaker_session,
):
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG

sagemaker_session.sagemaker_client.describe_endpoint = Mock(
return_value={"EndpointStatus": "InService"}
)
pvs = [
sagemaker.production_variant("A", "ml.g5.xlarge"),
sagemaker.production_variant("B", "ml.p2.xlarge"),
sagemaker.production_variant("C", "ml.p2.xlarge"),
]
# Add DestinationS3Uri to only one production variant
pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}
sagemaker_session.endpoint_from_production_variants(
"some-endpoint",
pvs,
data_capture_config_dict={},
async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(),
)
expected_data_capture_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
"EndpointConfig"
]["DataCaptureConfig"]["KmsKeyId"]
expected_inference_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
"AsyncInferenceConfig"
]["OutputConfig"]["KmsKeyId"]
expected_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"][
"KmsKeyId"
]
expected_tags = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]["Tags"]

expected_async_inference_config_dict = AsyncInferenceConfig()._to_request_dict()
expected_async_inference_config_dict["OutputConfig"]["KmsKeyId"] = expected_inference_kms_key_id
expected_pvs = [
sagemaker.production_variant("A", "ml.g5.xlarge"),
sagemaker.production_variant("B", "ml.p2.xlarge"),
sagemaker.production_variant("C", "ml.p2.xlarge"),
]
# Add DestinationS3Uri, KmsKeyId to only one production variant
expected_production_variant_0_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"][
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment please

"EndpointConfig"
]["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
expected_pvs[0]["CoreDumpConfig"] = {
"DestinationS3Uri": "s3://test",
"KmsKeyId": expected_production_variant_0_kms_key_id,
}
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName="some-endpoint",
ProductionVariants=expected_pvs,
Tags=expected_tags, # from config
KmsKeyId=expected_kms_key_id, # from config
AsyncInferenceConfig=expected_async_inference_config_dict,
DataCaptureConfig={"KmsKeyId": expected_data_capture_kms_key_id},
)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
EndpointConfigName="some-endpoint",
EndpointName="some-endpoint",
Tags=expected_tags, # from config
)


def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_kms_support(
    sagemaker_session,
):
    """Check the requests built from config when every variant is ``ml.g5.xlarge``.

    Only the first variant carries a ``CoreDumpConfig`` destination, so it alone is
    expected to receive the config's core-dump ``KmsKeyId``. The expected
    ``create_endpoint_config`` call takes tags, the async-inference ``KmsKeyId`` and
    the data-capture ``KmsKeyId`` from config, and has no top-level ``KmsKeyId``
    argument; ``create_endpoint`` is expected to receive the config tags.
    """
    # NOTE(review): going by the test name, the top-level KmsKeyId is presumably
    # omitted because this instance type does not support it -- confirm.
    sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG

    sagemaker_session.sagemaker_client.describe_endpoint = Mock(
        return_value={"EndpointStatus": "InService"}
    )

    variant_names = ("A", "B", "C")
    pvs = [sagemaker.production_variant(name, "ml.g5.xlarge") for name in variant_names]
    # Add DestinationS3Uri to only one production variant
    pvs[0]["CoreDumpConfig"] = {"DestinationS3Uri": "s3://test"}

    sagemaker_session.endpoint_from_production_variants(
        "some-endpoint",
        pvs,
        data_capture_config_dict={},
        async_inference_config_dict=AsyncInferenceConfig()._to_request_dict(),
    )

    endpoint_config = SAGEMAKER_CONFIG_ENDPOINT_CONFIG["SageMaker"]["EndpointConfig"]
    data_capture_kms_key_id = endpoint_config["DataCaptureConfig"]["KmsKeyId"]
    async_output_kms_key_id = endpoint_config["AsyncInferenceConfig"]["OutputConfig"]["KmsKeyId"]
    config_tags = endpoint_config["Tags"]

    wanted_async_config = AsyncInferenceConfig()._to_request_dict()
    wanted_async_config["OutputConfig"]["KmsKeyId"] = async_output_kms_key_id

    wanted_variants = [
        sagemaker.production_variant(name, "ml.g5.xlarge") for name in variant_names
    ]
    # Add DestinationS3Uri, KmsKeyId to only one production variant
    core_dump_kms_key_id = endpoint_config["ProductionVariants"][0]["CoreDumpConfig"]["KmsKeyId"]
    wanted_variants[0]["CoreDumpConfig"] = {
        "DestinationS3Uri": "s3://test",
        "KmsKeyId": core_dump_kms_key_id,
    }

    client = sagemaker_session.sagemaker_client
    client.create_endpoint_config.assert_called_with(
        EndpointConfigName="some-endpoint",
        ProductionVariants=wanted_variants,
        Tags=config_tags,  # from config
        AsyncInferenceConfig=wanted_async_config,
        DataCaptureConfig={"KmsKeyId": data_capture_kms_key_id},
    )
    client.create_endpoint.assert_called_with(
        EndpointConfigName="some-endpoint",
        EndpointName="some-endpoint",
        Tags=config_tags,  # from config
    )


def test_create_endpoint_config_with_tags(sagemaker_session):
tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pytest
from mock import MagicMock, Mock, patch, PropertyMock
from sagemaker.session_settings import SessionSettings

from sagemaker.transformer import _TransformJob, Transformer
from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
Expand Down Expand Up @@ -111,6 +112,8 @@ def transformer(sagemaker_session):
def test_transform_with_sagemaker_config_injection(start_new_job, sagemaker_session):
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRANSFORM_JOB

sagemaker_session.settings = SessionSettings()

transformer = Transformer(
MODEL_NAME,
INSTANCE_COUNT,
Expand Down