diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index bd7bf9fe99..fae081d8b6 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -304,6 +304,7 @@ def create_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + self._validate_network_config(network_config_dict) self.sagemaker_session.create_monitoring_schedule( monitoring_schedule_name=self.monitoring_schedule_name, @@ -453,6 +454,7 @@ def update_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + self._validate_network_config(network_config_dict) self.sagemaker_session.update_monitoring_schedule( monitoring_schedule_name=self.monitoring_schedule_name, @@ -961,6 +963,29 @@ def _wait_for_schedule_changes_to_apply(self): if schedule_desc["MonitoringScheduleStatus"] != "Pending": break + def _validate_network_config(self, network_config_dict): + """Validates that EnableInterContainerTrafficEncryption is not set in the provided + NetworkConfig request dictionary. + + Args: + network_config_dict (dict): NetworkConfig request dictionary. + Contains parameters from :class:`~sagemaker.network.NetworkConfig` object + that configures network isolation, encryption of + inter-container traffic, security group IDs, and subnets. + + """ + if "EnableInterContainerTrafficEncryption" in network_config_dict: + message = ( + "EnableInterContainerTrafficEncryption is not supported in Model Monitor. " + "Please ensure that encrypt_inter_container_traffic=None " + "when creating your NetworkConfig object. " + "Current encrypt_inter_container_traffic value: {}".format( + self.network_config.encrypt_inter_container_traffic + ) + ) + _LOGGER.info(message) + raise ValueError(message) + class DefaultModelMonitor(ModelMonitor): """Sets up Amazon SageMaker Monitoring Schedules and baseline suggestions. Use this class when @@ -1272,6 +1297,7 @@ def create_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + super(DefaultModelMonitor, self)._validate_network_config(network_config_dict) self.sagemaker_session.create_monitoring_schedule( monitoring_schedule_name=self.monitoring_schedule_name, @@ -1429,6 +1455,7 @@ def update_monitoring_schedule( network_config_dict = None if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + super(DefaultModelMonitor, self)._validate_network_config(network_config_dict) if role is not None: self.role = role diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index 569f87fe6a..e82acd703e 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -20,7 +20,13 @@ class NetworkConfig(object): """Accepts network configuration parameters and provides a method to turn these parameters into a dictionary.""" - def __init__(self, enable_network_isolation=False, security_group_ids=None, subnets=None): + def __init__( + self, + enable_network_isolation=False, + security_group_ids=None, + subnets=None, + encrypt_inter_container_traffic=None, + ): """Initialize a ``NetworkConfig`` instance. NetworkConfig accepts network configuration parameters and provides a method to turn these parameters into a dictionary. @@ -29,15 +35,23 @@ def __init__(self, enable_network_isolation=False, security_group_ids=None, subn network isolation. security_group_ids ([str]): A list of strings representing security group IDs. subnets ([str]): A list of strings representing subnets. + encrypt_inter_container_traffic (bool): Boolean that determines whether to + encrypt inter-container traffic. Default value is None. """ self.enable_network_isolation = enable_network_isolation self.security_group_ids = security_group_ids self.subnets = subnets + self.encrypt_inter_container_traffic = encrypt_inter_container_traffic def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation} + if self.encrypt_inter_container_traffic is not None: + network_config_request[ + "EnableInterContainerTrafficEncryption" + ] = self.encrypt_inter_container_traffic + if self.security_group_ids is not None or self.subnets is not None: network_config_request["VpcConfig"] = {} diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index 9773b870ec..1456f2a0f7 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -27,6 +27,7 @@ ProcessingJob, ) from sagemaker.sklearn.processing import SKLearnProcessor +from sagemaker.network import NetworkConfig from tests.integ import DATA_DIR from tests.integ.kms_utils import get_or_create_kms_key @@ -643,3 +644,33 @@ def test_processor_with_custom_bucket( assert ROLE in job_description["RoleArn"] assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + +def test_sklearn_with_network_config(sagemaker_session, sklearn_full_version, cpu_instance_type): + script_path = os.path.join(DATA_DIR, "dummy_script.py") + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_full_version, + role=ROLE, + instance_type=cpu_instance_type, + instance_count=1, + command=["python3"], + sagemaker_session=sagemaker_session, + base_job_name="test-sklearn-with-network-config", + network_config=NetworkConfig( + enable_network_isolation=True, encrypt_inter_container_traffic=True + ), + ) + + sklearn_processor.run( + code=script_path, + inputs=[ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/")], + wait=False, + logs=False, + ) + + job_description = sklearn_processor.latest_job.describe() + network_config = job_description["NetworkConfig"] + assert network_config["EnableInterContainerTrafficEncryption"] + assert network_config["EnableNetworkIsolation"] diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index d064600ca9..4450b0b7e8 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -19,6 +19,8 @@ # from sagemaker.model_monitor import ModelMonitor from sagemaker.model_monitor import DefaultModelMonitor +from sagemaker.model_monitor import ModelMonitor +from sagemaker.model_monitor import MonitoringOutput # from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor # from sagemaker.sklearn.processing import SKLearnProcessor @@ -55,6 +57,11 @@ CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri" +INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG = ( + "EnableInterContainerTrafficEncryption is not supported in Model Monitor. Please ensure that " +) +"encrypt_inter_container_traffic=None when creating your NetworkConfig object." + # TODO-reinvent-2019: Continue to flesh these out. @pytest.fixture() @@ -128,3 +135,39 @@ def test_default_model_monitor_suggest_baseline(sagemaker_session): # processor().run.assert_called_once( # # ) + + +def test_default_model_monitor_with_invalid_network_config(sagemaker_session): + invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False) + my_default_monitor = DefaultModelMonitor( + role=ROLE, sagemaker_session=sagemaker_session, network_config=invalid_network_config + ) + with pytest.raises(ValueError) as exception: + my_default_monitor.create_monitoring_schedule(endpoint_input="test_endpoint") + assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value) + + with pytest.raises(ValueError) as exception: + my_default_monitor.update_monitoring_schedule() + assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value) + + +def test_model_monitor_with_invalid_network_config(sagemaker_session): + invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False) + my_model_monitor = ModelMonitor( + role=ROLE, + image_uri=CUSTOM_IMAGE_URI, + sagemaker_session=sagemaker_session, + network_config=invalid_network_config, + ) + with pytest.raises(ValueError) as exception: + my_model_monitor.create_monitoring_schedule( + endpoint_input="test_endpoint", + output=MonitoringOutput( + source="/opt/ml/processing/output", destination="/opt/ml/processing/output" + ), + ) + assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value) + + with pytest.raises(ValueError) as exception: + my_model_monitor.update_monitoring_schedule() + assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 672bf85f86..860b664df9 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -141,6 +141,7 @@ def test_sklearn_with_all_parameters(exists_mock, isfile_mock, ecr_prefix, sagem subnets=["my_subnet_id"], security_group_ids=["my_security_group_id"], enable_network_isolation=True, + encrypt_inter_container_traffic=True, ), sagemaker_session=sagemaker_session, ) @@ -330,6 +331,7 @@ def test_script_processor_with_all_parameters(exists_mock, isfile_mock, sagemake subnets=["my_subnet_id"], security_group_ids=["my_security_group_id"], enable_network_isolation=True, + encrypt_inter_container_traffic=True, ), sagemaker_session=sagemaker_session, ) @@ -386,6 +388,49 @@ def test_processor_with_required_parameters(sagemaker_session): sagemaker_session.process.assert_called_with(**expected_args) +def test_processor_with_missing_network_config_parameters(sagemaker_session): + processor = Processor( + role=ROLE, + image_uri=CUSTOM_IMAGE_URI, + instance_count=1, + instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + network_config=NetworkConfig(enable_network_isolation=True), + ) + + processor.run() + + expected_args = _get_expected_args(processor._current_job_name) + del expected_args["app_specification"]["ContainerEntrypoint"] + expected_args["inputs"] = [] + expected_args["network_config"] = {"EnableNetworkIsolation": True} + + sagemaker_session.process.assert_called_with(**expected_args) + + +def test_processor_with_encryption_parameter_in_network_config(sagemaker_session): + processor = Processor( + role=ROLE, + image_uri=CUSTOM_IMAGE_URI, + instance_count=1, + instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + network_config=NetworkConfig(encrypt_inter_container_traffic=False), + ) + + processor.run() + + expected_args = _get_expected_args(processor._current_job_name) + del expected_args["app_specification"]["ContainerEntrypoint"] + expected_args["inputs"] = [] + expected_args["network_config"] = { + "EnableNetworkIsolation": False, + "EnableInterContainerTrafficEncryption": False, + } + + sagemaker_session.process.assert_called_with(**expected_args) + + def test_processor_with_all_parameters(sagemaker_session): processor = Processor( role=ROLE, @@ -405,6 +450,7 @@ def test_processor_with_all_parameters(sagemaker_session): subnets=["my_subnet_id"], security_group_ids=["my_security_group_id"], enable_network_isolation=True, + encrypt_inter_container_traffic=True, ), ) @@ -580,6 +626,7 @@ def _get_expected_args_all_parameters(job_name): "environment": {"my_env_variable": "my_env_variable_value"}, "network_config": { "EnableNetworkIsolation": True, + "EnableInterContainerTrafficEncryption": True, "VpcConfig": { "SecurityGroupIds": ["my_security_group_id"], "Subnets": ["my_subnet_id"], diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f2292e3b2..67fffc118c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -131,6 +131,7 @@ def test_process(boto_session): }, "environment": {"my_env_variable": 20}, "network_config": { + "EnableInterContainerTrafficEncryption": True, "EnableNetworkIsolation": True, "VpcConfig": { "SecurityGroupIds": ["my_security_group_id"], @@ -219,6 +220,7 @@ def test_process(boto_session): }, "Environment": {"my_env_variable": 20}, "NetworkConfig": { + "EnableInterContainerTrafficEncryption": True, "EnableNetworkIsolation": True, "VpcConfig": { "SecurityGroupIds": ["my_security_group_id"],