Skip to content

feature: support inter-container traffic encryption for processing jobs #1431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def create_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
self._validate_network_config(network_config_dict)

self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
Expand Down Expand Up @@ -453,6 +454,7 @@ def update_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
self._validate_network_config(network_config_dict)

self.sagemaker_session.update_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
Expand Down Expand Up @@ -961,6 +963,29 @@ def _wait_for_schedule_changes_to_apply(self):
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
break

def _validate_network_config(self, network_config_dict):
    """Validates that EnableInterContainerTrafficEncryption is not set in the provided
    NetworkConfig request dictionary.

    Args:
        network_config_dict (dict): NetworkConfig request dictionary.
            Contains parameters from :class:`~sagemaker.network.NetworkConfig` object
            that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.
            ``None`` is accepted and means there is nothing to validate.

    Raises:
        ValueError: If ``EnableInterContainerTrafficEncryption`` is present in
            ``network_config_dict``, whatever its value.
    """
    encryption_key = "EnableInterContainerTrafficEncryption"

    # Guard clause: no dictionary, or no encryption entry, means the config is acceptable.
    if network_config_dict is None or encryption_key not in network_config_dict:
        return

    # Report the value found in the dictionary under validation rather than reaching
    # back into ``self.network_config``, so the check depends only on its argument.
    message = (
        "EnableInterContainerTrafficEncryption is not supported in Model Monitor. "
        "Please ensure that encrypt_inter_container_traffic=None "
        "when creating your NetworkConfig object. "
        "Current encrypt_inter_container_traffic value: {}".format(
            network_config_dict[encryption_key]
        )
    )
    _LOGGER.info(message)
    raise ValueError(message)


class DefaultModelMonitor(ModelMonitor):
"""Sets up Amazon SageMaker Monitoring Schedules and baseline suggestions. Use this class when
Expand Down Expand Up @@ -1272,6 +1297,7 @@ def create_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)

self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
Expand Down Expand Up @@ -1429,6 +1455,7 @@ def update_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)

if role is not None:
self.role = role
Expand Down
16 changes: 15 additions & 1 deletion src/sagemaker/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ class NetworkConfig(object):
"""Accepts network configuration parameters and provides a method to turn these parameters
into a dictionary."""

def __init__(self, enable_network_isolation=False, security_group_ids=None, subnets=None):
def __init__(
self,
enable_network_isolation=False,
security_group_ids=None,
subnets=None,
encrypt_inter_container_traffic=None,
):
"""Initialize a ``NetworkConfig`` instance. NetworkConfig accepts network configuration
parameters and provides a method to turn these parameters into a dictionary.

Expand All @@ -29,15 +35,23 @@ def __init__(self, enable_network_isolation=False, security_group_ids=None, subn
network isolation.
security_group_ids ([str]): A list of strings representing security group IDs.
subnets ([str]): A list of strings representing subnets.
encrypt_inter_container_traffic (bool): Boolean that determines whether to
encrypt inter-container traffic. Default value is None.
"""
self.enable_network_isolation = enable_network_isolation
self.security_group_ids = security_group_ids
self.subnets = subnets
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation}

if self.encrypt_inter_container_traffic is not None:
network_config_request[
"EnableInterContainerTrafficEncryption"
] = self.encrypt_inter_container_traffic

if self.security_group_ids is not None or self.subnets is not None:
network_config_request["VpcConfig"] = {}

Expand Down
31 changes: 31 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ProcessingJob,
)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.network import NetworkConfig
from tests.integ import DATA_DIR
from tests.integ.kms_utils import get_or_create_kms_key

Expand Down Expand Up @@ -643,3 +644,33 @@ def test_processor_with_custom_bucket(
assert ROLE in job_description["RoleArn"]

assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_network_config(sagemaker_session, sklearn_full_version, cpu_instance_type):
    """Submit an SKLearn processing job whose NetworkConfig enables both network
    isolation and inter-container traffic encryption, then check that the described
    job reports both flags as truthy."""
    network_config = NetworkConfig(
        enable_network_isolation=True, encrypt_inter_container_traffic=True
    )
    processor_kwargs = {
        "framework_version": sklearn_full_version,
        "role": ROLE,
        "instance_type": cpu_instance_type,
        "instance_count": 1,
        "command": ["python3"],
        "sagemaker_session": sagemaker_session,
        "base_job_name": "test-sklearn-with-network-config",
        "network_config": network_config,
    }
    sklearn_processor = SKLearnProcessor(**processor_kwargs)

    dummy_input = ProcessingInput(
        source=os.path.join(DATA_DIR, "dummy_input.txt"),
        destination="/opt/ml/processing/inputs/",
    )
    # The job is submitted without waiting or streaming logs; only its description is checked.
    sklearn_processor.run(
        code=os.path.join(DATA_DIR, "dummy_script.py"),
        inputs=[dummy_input],
        wait=False,
        logs=False,
    )

    described_network_config = sklearn_processor.latest_job.describe()["NetworkConfig"]
    assert described_network_config["EnableInterContainerTrafficEncryption"]
    assert described_network_config["EnableNetworkIsolation"]
43 changes: 43 additions & 0 deletions tests/unit/sagemaker/monitor/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

# from sagemaker.model_monitor import ModelMonitor
from sagemaker.model_monitor import DefaultModelMonitor
from sagemaker.model_monitor import ModelMonitor
from sagemaker.model_monitor import MonitoringOutput

# from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor
# from sagemaker.sklearn.processing import SKLearnProcessor
Expand Down Expand Up @@ -55,6 +57,11 @@

CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"

# Leading portion of the ValueError message expected when a NetworkConfig that sets
# encrypt_inter_container_traffic is used with Model Monitor. Both literals must sit
# inside the parentheses to be implicitly concatenated; a literal placed after the
# closing parenthesis is a separate, discarded expression statement.
INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG = (
    "EnableInterContainerTrafficEncryption is not supported in Model Monitor. Please ensure that "
    "encrypt_inter_container_traffic=None when creating your NetworkConfig object."
)


# TODO-reinvent-2019: Continue to flesh these out.
@pytest.fixture()
Expand Down Expand Up @@ -128,3 +135,39 @@ def test_default_model_monitor_suggest_baseline(sagemaker_session):
# processor().run.assert_called_once(
#
# )


def test_default_model_monitor_with_invalid_network_config(sagemaker_session):
    """Check that a DefaultModelMonitor built with a NetworkConfig that sets
    encrypt_inter_container_traffic raises ValueError from both
    create_monitoring_schedule and update_monitoring_schedule."""
    monitor = DefaultModelMonitor(
        role=ROLE,
        sagemaker_session=sagemaker_session,
        network_config=NetworkConfig(encrypt_inter_container_traffic=False),
    )

    # Each message assertion sits after its ``with`` block so that it runs once the
    # expected exception has been captured.
    with pytest.raises(ValueError) as create_error:
        monitor.create_monitoring_schedule(endpoint_input="test_endpoint")
    assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(create_error.value)

    with pytest.raises(ValueError) as update_error:
        monitor.update_monitoring_schedule()
    assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(update_error.value)


def test_model_monitor_with_invalid_network_config(sagemaker_session):
    """Check that a ModelMonitor built with a NetworkConfig that sets
    encrypt_inter_container_traffic raises ValueError from both
    create_monitoring_schedule and update_monitoring_schedule."""
    monitor = ModelMonitor(
        role=ROLE,
        image_uri=CUSTOM_IMAGE_URI,
        sagemaker_session=sagemaker_session,
        network_config=NetworkConfig(encrypt_inter_container_traffic=False),
    )

    # Each message assertion sits after its ``with`` block so that it runs once the
    # expected exception has been captured.
    with pytest.raises(ValueError) as create_error:
        monitor.create_monitoring_schedule(
            endpoint_input="test_endpoint",
            output=MonitoringOutput(
                source="/opt/ml/processing/output", destination="/opt/ml/processing/output"
            ),
        )
    assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(create_error.value)

    with pytest.raises(ValueError) as update_error:
        monitor.update_monitoring_schedule()
    assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(update_error.value)
47 changes: 47 additions & 0 deletions tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_sklearn_with_all_parameters(exists_mock, isfile_mock, ecr_prefix, sagem
subnets=["my_subnet_id"],
security_group_ids=["my_security_group_id"],
enable_network_isolation=True,
encrypt_inter_container_traffic=True,
),
sagemaker_session=sagemaker_session,
)
Expand Down Expand Up @@ -330,6 +331,7 @@ def test_script_processor_with_all_parameters(exists_mock, isfile_mock, sagemake
subnets=["my_subnet_id"],
security_group_ids=["my_security_group_id"],
enable_network_isolation=True,
encrypt_inter_container_traffic=True,
),
sagemaker_session=sagemaker_session,
)
Expand Down Expand Up @@ -386,6 +388,49 @@ def test_processor_with_required_parameters(sagemaker_session):
sagemaker_session.process.assert_called_with(**expected_args)


def test_processor_with_missing_network_config_parameters(sagemaker_session):
    """Check that a NetworkConfig given only enable_network_isolation produces a
    request whose network_config holds just EnableNetworkIsolation."""
    isolation_only_config = NetworkConfig(enable_network_isolation=True)
    processor = Processor(
        role=ROLE,
        image_uri=CUSTOM_IMAGE_URI,
        instance_count=1,
        instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
        network_config=isolation_only_config,
    )

    processor.run()

    # Start from the shared baseline arguments and adjust the parts that differ here.
    expected_args = _get_expected_args(processor._current_job_name)
    expected_args["app_specification"].pop("ContainerEntrypoint")
    expected_args["inputs"] = []
    expected_args["network_config"] = {"EnableNetworkIsolation": True}

    sagemaker_session.process.assert_called_with(**expected_args)


def test_processor_with_encryption_parameter_in_network_config(sagemaker_session):
    """Check that an explicit encrypt_inter_container_traffic=False is forwarded in
    the request as EnableInterContainerTrafficEncryption alongside the default
    EnableNetworkIsolation value."""
    encryption_off_config = NetworkConfig(encrypt_inter_container_traffic=False)
    processor = Processor(
        role=ROLE,
        image_uri=CUSTOM_IMAGE_URI,
        instance_count=1,
        instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
        network_config=encryption_off_config,
    )

    processor.run()

    # Start from the shared baseline arguments and adjust the parts that differ here.
    expected_args = _get_expected_args(processor._current_job_name)
    expected_args["app_specification"].pop("ContainerEntrypoint")
    expected_args["inputs"] = []
    expected_network_config = {
        "EnableNetworkIsolation": False,
        "EnableInterContainerTrafficEncryption": False,
    }
    expected_args["network_config"] = expected_network_config

    sagemaker_session.process.assert_called_with(**expected_args)


def test_processor_with_all_parameters(sagemaker_session):
processor = Processor(
role=ROLE,
Expand All @@ -405,6 +450,7 @@ def test_processor_with_all_parameters(sagemaker_session):
subnets=["my_subnet_id"],
security_group_ids=["my_security_group_id"],
enable_network_isolation=True,
encrypt_inter_container_traffic=True,
),
)

Expand Down Expand Up @@ -580,6 +626,7 @@ def _get_expected_args_all_parameters(job_name):
"environment": {"my_env_variable": "my_env_variable_value"},
"network_config": {
"EnableNetworkIsolation": True,
"EnableInterContainerTrafficEncryption": True,
"VpcConfig": {
"SecurityGroupIds": ["my_security_group_id"],
"Subnets": ["my_subnet_id"],
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_process(boto_session):
},
"environment": {"my_env_variable": 20},
"network_config": {
"EnableInterContainerTrafficEncryption": True,
"EnableNetworkIsolation": True,
"VpcConfig": {
"SecurityGroupIds": ["my_security_group_id"],
Expand Down Expand Up @@ -219,6 +220,7 @@ def test_process(boto_session):
},
"Environment": {"my_env_variable": 20},
"NetworkConfig": {
"EnableInterContainerTrafficEncryption": True,
"EnableNetworkIsolation": True,
"VpcConfig": {
"SecurityGroupIds": ["my_security_group_id"],
Expand Down