From d7e5fe73caa83308ccdd3d4b342dad7e0c18e7fb Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 19 Sep 2019 15:55:25 -0700 Subject: [PATCH] feature: enable sklearn for network isolation mode --- src/sagemaker/model.py | 5 ++++- src/sagemaker/sklearn/model.py | 7 +++++-- tests/unit/test_model.py | 19 +++++++++++++++++++ tests/unit/test_sklearn.py | 20 ++++++++++++++++++++ 4 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index d2faf6fa30..15f5de0d91 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -810,7 +810,10 @@ def _framework_env_vars(self): """Placeholder docstring""" if self.uploaded_code: script_name = self.uploaded_code.script_name - dir_name = self.uploaded_code.s3_prefix + if self.enable_network_isolation(): + dir_name = "/opt/ml/model/code" + else: + dir_name = self.uploaded_code.s3_prefix else: script_name = self.entry_point dir_name = "file://" + self.source_dir diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 5d02044302..d18be45770 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -135,10 +135,13 @@ def prepare_container_def(self, instance_type, accelerator_type=None): ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) - self._upload_code(deploy_key_prefix) + self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation()) deploy_env = dict(self.env) deploy_env.update(self._framework_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) - return sagemaker.container_def(deploy_image, self.model_data, deploy_env) + model_data_uri = ( + self.repacked_model_data if self.enable_network_isolation() else self.model_data + ) + return sagemaker.container_def(deploy_image, model_data_uri, deploy_env) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index b76b3cf922..5f9fb868fa 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -154,6 +154,25 @@ def test_prepare_container_def(time, sagemaker_session): } +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", return_value=TIMESTAMP) +def test_prepare_container_def_with_network_isolation(time, sagemaker_session): + model = DummyFrameworkModel(sagemaker_session, enable_network_isolation=True) + assert model.prepare_container_def(INSTANCE_TYPE) == { + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + @patch("shutil.rmtree", MagicMock()) @patch("tarfile.open", MagicMock()) @patch("os.path.exists", MagicMock(return_value=True)) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index dd98718884..6acdf40592 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -23,6 +23,7 @@ from sagemaker.sklearn import defaults from sagemaker.sklearn import SKLearn from sagemaker.sklearn import SKLearnPredictor, SKLearnModel +from sagemaker.fw_utils import UploadedCode DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -168,6 +169,25 @@ def test_create_model(sagemaker_session): assert model_values["Image"] == default_image_uri +@patch("sagemaker.model.FrameworkModel._upload_code") +def test_create_model_with_network_isolation(upload, sagemaker_session): + source_dir = "s3://mybucket/source" + repacked_model_data = "s3://mybucket/prefix/model.tar.gz" + + sklearn_model = SKLearnModel( + model_data=source_dir, + role=ROLE, + sagemaker_session=sagemaker_session, + entry_point=SCRIPT_PATH, + enable_network_isolation=True, + ) + sklearn_model.uploaded_code = UploadedCode(s3_prefix=repacked_model_data, script_name="script") + sklearn_model.repacked_model_data = repacked_model_data + model_values = sklearn_model.prepare_container_def(CPU) + assert model_values["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code" + assert model_values["ModelDataUrl"] == repacked_model_data + + def test_create_model_from_estimator(sagemaker_session, sklearn_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source"