Skip to content

Commit 5fe413e

Browse files
awsbmillareBrent Millareahsan-z-khanshreyapanditnavinsoni
authored andcommitted
feature: network isolation mode for xgboost (aws#2626)
* feature: network isolation mode for xgboost * Add integ tests for xgboost net iso * fix import * fix job_name name * fix black-check * Revert local test setup changes Co-authored-by: Brent Millare <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Shreya Pandit <[email protected]>
1 parent f6bb190 commit 5fe413e

File tree

4 files changed

+110
-2
lines changed

4 files changed

+110
-2
lines changed

src/sagemaker/xgboost/model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,16 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
145145
)
146146

147147
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
148-
self._upload_code(deploy_key_prefix)
148+
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
149149
deploy_env = dict(self.env)
150150
deploy_env.update(self._framework_env_vars())
151151

152152
if self.model_server_workers:
153153
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
154-
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
154+
model_data = (
155+
self.repacked_model_data if self.enable_network_isolation() else self.model_data
156+
)
157+
return sagemaker.container_def(deploy_image, model_data, deploy_env)
155158

156159
def serving_image_uri(self, region_name, instance_type):
157160
"""Create a URI for the serving image.

tests/data/xgboost_abalone/abalone.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import argparse
2+
import os
3+
4+
from sagemaker_xgboost_container.data_utils import get_dmatrix
5+
6+
import xgboost as xgb
7+
8+
model_filename = "xgboost-model"
9+
10+
if __name__ == "__main__":
11+
parser = argparse.ArgumentParser()
12+
13+
# Sagemaker specific arguments. Defaults are set in the environment variables.
14+
parser.add_argument(
15+
"--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
16+
)
17+
parser.add_argument(
18+
"--train",
19+
type=str,
20+
default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/abalone"),
21+
)
22+
23+
args, _ = parser.parse_known_args()
24+
25+
dtrain = get_dmatrix(args.train, "libsvm")
26+
27+
params = {
28+
"max_depth": 5,
29+
"eta": 0.2,
30+
"gamma": 4,
31+
"min_child_weight": 6,
32+
"subsample": 0.7,
33+
"verbosity": 2,
34+
"objective": "reg:squarederror",
35+
"tree_method": "auto",
36+
"predictor": "auto",
37+
}
38+
39+
booster = xgb.train(params=params, dtrain=dtrain, num_boost_round=50)
40+
booster.save_model(args.model_dir + "/" + model_filename)
41+
42+
43+
def model_fn(model_dir):
44+
"""Deserialize and return fitted model.
45+
46+
Note that this should have the same name as the serialized model in the _xgb_train method
47+
"""
48+
booster = xgb.Booster()
49+
booster.load_model(os.path.join(model_dir, model_filename))
50+
return booster

tests/integ/test_xgboost.py

+34
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import os
1616
import pytest
17+
from sagemaker.utils import unique_name_from_base
18+
from sagemaker.xgboost import XGBoost
1719
from sagemaker.xgboost.processing import XGBoostProcessor
1820
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
1921
from tests.integ.timeout import timeout
@@ -48,3 +50,35 @@ def test_framework_processing_job_with_deps(
4850
inputs=[],
4951
wait=True,
5052
)
53+
54+
55+
def test_training_with_network_isolation(
56+
sagemaker_session,
57+
xgboost_latest_version,
58+
xgboost_latest_py_version,
59+
cpu_instance_type,
60+
):
61+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
62+
base_job_name = "test-network-isolation-xgboost"
63+
64+
xgboost = XGBoost(
65+
entry_point=os.path.join(DATA_DIR, "xgboost_abalone", "abalone.py"),
66+
role=ROLE,
67+
instance_type=cpu_instance_type,
68+
instance_count=1,
69+
framework_version=xgboost_latest_version,
70+
py_version=xgboost_latest_py_version,
71+
base_job_name=base_job_name,
72+
sagemaker_session=sagemaker_session,
73+
enable_network_isolation=True,
74+
)
75+
76+
train_input = xgboost.sagemaker_session.upload_data(
77+
path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
78+
key_prefix="integ-test-data/xgboost_abalone/abalone",
79+
)
80+
job_name = unique_name_from_base(base_job_name)
81+
xgboost.fit(inputs={"train": train_input}, job_name=job_name)
82+
assert sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)[
83+
"EnableNetworkIsolation"
84+
]

tests/unit/test_xgboost.py

+21
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from packaging.version import Version
2323

2424

25+
from sagemaker.fw_utils import UploadedCode
2526
from sagemaker.xgboost import XGBoost, XGBoostModel, XGBoostPredictor
2627

2728

@@ -180,6 +181,26 @@ def test_create_model(sagemaker_session, xgboost_framework_version):
180181
assert model_values["Image"] == default_image_uri
181182

182183

184+
@patch("sagemaker.model.FrameworkModel._upload_code")
185+
def test_create_model_with_network_isolation(upload, sagemaker_session, xgboost_framework_version):
186+
source_dir = "s3://mybucket/source"
187+
repacked_model_data = "s3://mybucket/prefix/model.tar.gz"
188+
189+
xgboost_model = XGBoostModel(
190+
model_data=source_dir,
191+
role=ROLE,
192+
sagemaker_session=sagemaker_session,
193+
entry_point=SCRIPT_PATH,
194+
framework_version=xgboost_framework_version,
195+
enable_network_isolation=True,
196+
)
197+
xgboost_model.uploaded_code = UploadedCode(s3_prefix=repacked_model_data, script_name="script")
198+
xgboost_model.repacked_model_data = repacked_model_data
199+
model_values = xgboost_model.prepare_container_def(CPU)
200+
assert model_values["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code"
201+
assert model_values["ModelDataUrl"] == repacked_model_data
202+
203+
183204
@patch("sagemaker.estimator.name_from_base")
184205
def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_framework_version):
185206
container_log_level = '"logging.INFO"'

0 commit comments

Comments
 (0)