Skip to content

Commit 03bc337

Browse files
authored
feature: support inter container traffic encryption for processing jobs (#1431)
1 parent eda0029 commit 03bc337

File tree

6 files changed

+165
-1
lines changed

6 files changed

+165
-1
lines changed

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def create_monitoring_schedule(
304304
network_config_dict = None
305305
if self.network_config is not None:
306306
network_config_dict = self.network_config._to_request_dict()
307+
self._validate_network_config(network_config_dict)
307308

308309
self.sagemaker_session.create_monitoring_schedule(
309310
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -453,6 +454,7 @@ def update_monitoring_schedule(
453454
network_config_dict = None
454455
if self.network_config is not None:
455456
network_config_dict = self.network_config._to_request_dict()
457+
self._validate_network_config(network_config_dict)
456458

457459
self.sagemaker_session.update_monitoring_schedule(
458460
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -961,6 +963,29 @@ def _wait_for_schedule_changes_to_apply(self):
961963
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
962964
break
963965

966+
def _validate_network_config(self, network_config_dict):
967+
"""Validates that EnableInterContainerTrafficEncryption is not set in the provided
968+
NetworkConfig request dictionary.
969+
970+
Args:
971+
network_config_dict (dict): NetworkConfig request dictionary.
972+
Contains parameters from :class:`~sagemaker.network.NetworkConfig` object
973+
that configures network isolation, encryption of
974+
inter-container traffic, security group IDs, and subnets.
975+
976+
"""
977+
if "EnableInterContainerTrafficEncryption" in network_config_dict:
978+
message = (
979+
"EnableInterContainerTrafficEncryption is not supported in Model Monitor. "
980+
"Please ensure that encrypt_inter_container_traffic=None "
981+
"when creating your NetworkConfig object. "
982+
"Current encrypt_inter_container_traffic value: {}".format(
983+
self.network_config.encrypt_inter_container_traffic
984+
)
985+
)
986+
_LOGGER.info(message)
987+
raise ValueError(message)
988+
964989

965990
class DefaultModelMonitor(ModelMonitor):
966991
"""Sets up Amazon SageMaker Monitoring Schedules and baseline suggestions. Use this class when
@@ -1272,6 +1297,7 @@ def create_monitoring_schedule(
12721297
network_config_dict = None
12731298
if self.network_config is not None:
12741299
network_config_dict = self.network_config._to_request_dict()
1300+
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)
12751301

12761302
self.sagemaker_session.create_monitoring_schedule(
12771303
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -1429,6 +1455,7 @@ def update_monitoring_schedule(
14291455
network_config_dict = None
14301456
if self.network_config is not None:
14311457
network_config_dict = self.network_config._to_request_dict()
1458+
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)
14321459

14331460
if role is not None:
14341461
self.role = role

src/sagemaker/network.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ class NetworkConfig(object):
2020
"""Accepts network configuration parameters and provides a method to turn these parameters
2121
into a dictionary."""
2222

23-
def __init__(self, enable_network_isolation=False, security_group_ids=None, subnets=None):
23+
def __init__(
24+
self,
25+
enable_network_isolation=False,
26+
security_group_ids=None,
27+
subnets=None,
28+
encrypt_inter_container_traffic=None,
29+
):
2430
"""Initialize a ``NetworkConfig`` instance. NetworkConfig accepts network configuration
2531
parameters and provides a method to turn these parameters into a dictionary.
2632
@@ -29,15 +35,23 @@ def __init__(self, enable_network_isolation=False, security_group_ids=None, subn
2935
network isolation.
3036
security_group_ids ([str]): A list of strings representing security group IDs.
3137
subnets ([str]): A list of strings representing subnets.
38+
encrypt_inter_container_traffic (bool): Boolean that determines whether to
39+
encrypt inter-container traffic. Default value is None.
3240
"""
3341
self.enable_network_isolation = enable_network_isolation
3442
self.security_group_ids = security_group_ids
3543
self.subnets = subnets
44+
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
3645

3746
def _to_request_dict(self):
3847
"""Generates a request dictionary using the parameters provided to the class."""
3948
network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation}
4049

50+
if self.encrypt_inter_container_traffic is not None:
51+
network_config_request[
52+
"EnableInterContainerTrafficEncryption"
53+
] = self.encrypt_inter_container_traffic
54+
4155
if self.security_group_ids is not None or self.subnets is not None:
4256
network_config_request["VpcConfig"] = {}
4357

tests/integ/test_processing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ProcessingJob,
2828
)
2929
from sagemaker.sklearn.processing import SKLearnProcessor
30+
from sagemaker.network import NetworkConfig
3031
from tests.integ import DATA_DIR
3132
from tests.integ.kms_utils import get_or_create_kms_key
3233

@@ -643,3 +644,33 @@ def test_processor_with_custom_bucket(
643644
assert ROLE in job_description["RoleArn"]
644645

645646
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
647+
648+
649+
def test_sklearn_with_network_config(sagemaker_session, sklearn_full_version, cpu_instance_type):
650+
script_path = os.path.join(DATA_DIR, "dummy_script.py")
651+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
652+
653+
sklearn_processor = SKLearnProcessor(
654+
framework_version=sklearn_full_version,
655+
role=ROLE,
656+
instance_type=cpu_instance_type,
657+
instance_count=1,
658+
command=["python3"],
659+
sagemaker_session=sagemaker_session,
660+
base_job_name="test-sklearn-with-network-config",
661+
network_config=NetworkConfig(
662+
enable_network_isolation=True, encrypt_inter_container_traffic=True
663+
),
664+
)
665+
666+
sklearn_processor.run(
667+
code=script_path,
668+
inputs=[ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/")],
669+
wait=False,
670+
logs=False,
671+
)
672+
673+
job_description = sklearn_processor.latest_job.describe()
674+
network_config = job_description["NetworkConfig"]
675+
assert network_config["EnableInterContainerTrafficEncryption"]
676+
assert network_config["EnableNetworkIsolation"]

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
# from sagemaker.model_monitor import ModelMonitor
2121
from sagemaker.model_monitor import DefaultModelMonitor
22+
from sagemaker.model_monitor import ModelMonitor
23+
from sagemaker.model_monitor import MonitoringOutput
2224

2325
# from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor
2426
# from sagemaker.sklearn.processing import SKLearnProcessor
@@ -55,6 +57,11 @@
5557

5658
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"
5759

60+
INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG = (
61+
"EnableInterContainerTrafficEncryption is not supported in Model Monitor. Please ensure that "
62+
)
63+
"encrypt_inter_container_traffic=None when creating your NetworkConfig object."
64+
5865

5966
# TODO-reinvent-2019: Continue to flesh these out.
6067
@pytest.fixture()
@@ -128,3 +135,39 @@ def test_default_model_monitor_suggest_baseline(sagemaker_session):
128135
# processor().run.assert_called_once(
129136
#
130137
# )
138+
139+
140+
def test_default_model_monitor_with_invalid_network_config(sagemaker_session):
141+
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
142+
my_default_monitor = DefaultModelMonitor(
143+
role=ROLE, sagemaker_session=sagemaker_session, network_config=invalid_network_config
144+
)
145+
with pytest.raises(ValueError) as exception:
146+
my_default_monitor.create_monitoring_schedule(endpoint_input="test_endpoint")
147+
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
148+
149+
with pytest.raises(ValueError) as exception:
150+
my_default_monitor.update_monitoring_schedule()
151+
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
152+
153+
154+
def test_model_monitor_with_invalid_network_config(sagemaker_session):
155+
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
156+
my_model_monitor = ModelMonitor(
157+
role=ROLE,
158+
image_uri=CUSTOM_IMAGE_URI,
159+
sagemaker_session=sagemaker_session,
160+
network_config=invalid_network_config,
161+
)
162+
with pytest.raises(ValueError) as exception:
163+
my_model_monitor.create_monitoring_schedule(
164+
endpoint_input="test_endpoint",
165+
output=MonitoringOutput(
166+
source="/opt/ml/processing/output", destination="/opt/ml/processing/output"
167+
),
168+
)
169+
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
170+
171+
with pytest.raises(ValueError) as exception:
172+
my_model_monitor.update_monitoring_schedule()
173+
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)

tests/unit/test_processing.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def test_sklearn_with_all_parameters(exists_mock, isfile_mock, ecr_prefix, sagem
141141
subnets=["my_subnet_id"],
142142
security_group_ids=["my_security_group_id"],
143143
enable_network_isolation=True,
144+
encrypt_inter_container_traffic=True,
144145
),
145146
sagemaker_session=sagemaker_session,
146147
)
@@ -330,6 +331,7 @@ def test_script_processor_with_all_parameters(exists_mock, isfile_mock, sagemake
330331
subnets=["my_subnet_id"],
331332
security_group_ids=["my_security_group_id"],
332333
enable_network_isolation=True,
334+
encrypt_inter_container_traffic=True,
333335
),
334336
sagemaker_session=sagemaker_session,
335337
)
@@ -386,6 +388,49 @@ def test_processor_with_required_parameters(sagemaker_session):
386388
sagemaker_session.process.assert_called_with(**expected_args)
387389

388390

391+
def test_processor_with_missing_network_config_parameters(sagemaker_session):
392+
processor = Processor(
393+
role=ROLE,
394+
image_uri=CUSTOM_IMAGE_URI,
395+
instance_count=1,
396+
instance_type="ml.m4.xlarge",
397+
sagemaker_session=sagemaker_session,
398+
network_config=NetworkConfig(enable_network_isolation=True),
399+
)
400+
401+
processor.run()
402+
403+
expected_args = _get_expected_args(processor._current_job_name)
404+
del expected_args["app_specification"]["ContainerEntrypoint"]
405+
expected_args["inputs"] = []
406+
expected_args["network_config"] = {"EnableNetworkIsolation": True}
407+
408+
sagemaker_session.process.assert_called_with(**expected_args)
409+
410+
411+
def test_processor_with_encryption_parameter_in_network_config(sagemaker_session):
412+
processor = Processor(
413+
role=ROLE,
414+
image_uri=CUSTOM_IMAGE_URI,
415+
instance_count=1,
416+
instance_type="ml.m4.xlarge",
417+
sagemaker_session=sagemaker_session,
418+
network_config=NetworkConfig(encrypt_inter_container_traffic=False),
419+
)
420+
421+
processor.run()
422+
423+
expected_args = _get_expected_args(processor._current_job_name)
424+
del expected_args["app_specification"]["ContainerEntrypoint"]
425+
expected_args["inputs"] = []
426+
expected_args["network_config"] = {
427+
"EnableNetworkIsolation": False,
428+
"EnableInterContainerTrafficEncryption": False,
429+
}
430+
431+
sagemaker_session.process.assert_called_with(**expected_args)
432+
433+
389434
def test_processor_with_all_parameters(sagemaker_session):
390435
processor = Processor(
391436
role=ROLE,
@@ -405,6 +450,7 @@ def test_processor_with_all_parameters(sagemaker_session):
405450
subnets=["my_subnet_id"],
406451
security_group_ids=["my_security_group_id"],
407452
enable_network_isolation=True,
453+
encrypt_inter_container_traffic=True,
408454
),
409455
)
410456

@@ -580,6 +626,7 @@ def _get_expected_args_all_parameters(job_name):
580626
"environment": {"my_env_variable": "my_env_variable_value"},
581627
"network_config": {
582628
"EnableNetworkIsolation": True,
629+
"EnableInterContainerTrafficEncryption": True,
583630
"VpcConfig": {
584631
"SecurityGroupIds": ["my_security_group_id"],
585632
"Subnets": ["my_subnet_id"],

tests/unit/test_session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_process(boto_session):
131131
},
132132
"environment": {"my_env_variable": 20},
133133
"network_config": {
134+
"EnableInterContainerTrafficEncryption": True,
134135
"EnableNetworkIsolation": True,
135136
"VpcConfig": {
136137
"SecurityGroupIds": ["my_security_group_id"],
@@ -219,6 +220,7 @@ def test_process(boto_session):
219220
},
220221
"Environment": {"my_env_variable": 20},
221222
"NetworkConfig": {
223+
"EnableInterContainerTrafficEncryption": True,
222224
"EnableNetworkIsolation": True,
223225
"VpcConfig": {
224226
"SecurityGroupIds": ["my_security_group_id"],

0 commit comments

Comments
 (0)