Skip to content

Commit 308f121

Browse files
icywang86ruilaurenyu
authored andcommitted
feature: enable sklearn for network isolation mode (#1051)
1 parent 7dbb149 commit 308f121

File tree

4 files changed

+48
-3
lines changed

4 files changed

+48
-3
lines changed

src/sagemaker/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,10 @@ def _framework_env_vars(self):
810810
"""Placeholder docstring"""
811811
if self.uploaded_code:
812812
script_name = self.uploaded_code.script_name
813-
dir_name = self.uploaded_code.s3_prefix
813+
if self.enable_network_isolation():
814+
dir_name = "/opt/ml/model/code"
815+
else:
816+
dir_name = self.uploaded_code.s3_prefix
814817
else:
815818
script_name = self.entry_point
816819
dir_name = "file://" + self.source_dir

src/sagemaker/sklearn/model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,13 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
135135
)
136136

137137
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
138-
self._upload_code(deploy_key_prefix)
138+
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
139139
deploy_env = dict(self.env)
140140
deploy_env.update(self._framework_env_vars())
141141

142142
if self.model_server_workers:
143143
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
144-
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
144+
model_data_uri = (
145+
self.repacked_model_data if self.enable_network_isolation() else self.model_data
146+
)
147+
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)

tests/unit/test_model.py

+19
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,25 @@ def test_prepare_container_def(time, sagemaker_session):
154154
}
155155

156156

157+
@patch("shutil.rmtree", MagicMock())
158+
@patch("tarfile.open", MagicMock())
159+
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
160+
@patch("time.strftime", return_value=TIMESTAMP)
161+
def test_prepare_container_def_with_network_isolation(time, sagemaker_session):
162+
model = DummyFrameworkModel(sagemaker_session, enable_network_isolation=True)
163+
assert model.prepare_container_def(INSTANCE_TYPE) == {
164+
"Environment": {
165+
"SAGEMAKER_PROGRAM": ENTRY_POINT,
166+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
167+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
168+
"SAGEMAKER_REGION": REGION,
169+
"SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
170+
},
171+
"Image": MODEL_IMAGE,
172+
"ModelDataUrl": MODEL_DATA,
173+
}
174+
175+
157176
@patch("shutil.rmtree", MagicMock())
158177
@patch("tarfile.open", MagicMock())
159178
@patch("os.path.exists", MagicMock(return_value=True))

tests/unit/test_sklearn.py

+20
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.sklearn import defaults
2424
from sagemaker.sklearn import SKLearn
2525
from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
26+
from sagemaker.fw_utils import UploadedCode
2627

2728
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2829
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -168,6 +169,25 @@ def test_create_model(sagemaker_session):
168169
assert model_values["Image"] == default_image_uri
169170

170171

172+
@patch("sagemaker.model.FrameworkModel._upload_code")
173+
def test_create_model_with_network_isolation(upload, sagemaker_session):
174+
source_dir = "s3://mybucket/source"
175+
repacked_model_data = "s3://mybucket/prefix/model.tar.gz"
176+
177+
sklearn_model = SKLearnModel(
178+
model_data=source_dir,
179+
role=ROLE,
180+
sagemaker_session=sagemaker_session,
181+
entry_point=SCRIPT_PATH,
182+
enable_network_isolation=True,
183+
)
184+
sklearn_model.uploaded_code = UploadedCode(s3_prefix=repacked_model_data, script_name="script")
185+
sklearn_model.repacked_model_data = repacked_model_data
186+
model_values = sklearn_model.prepare_container_def(CPU)
187+
assert model_values["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code"
188+
assert model_values["ModelDataUrl"] == repacked_model_data
189+
190+
171191
def test_create_model_from_estimator(sagemaker_session, sklearn_version):
172192
container_log_level = '"logging.INFO"'
173193
source_dir = "s3://mybucket/source"

0 commit comments

Comments
 (0)