Skip to content

Commit f1d34ad

Browse files
authored
feature: network isolation mode in training (aws#791)
* feature: network isolation mode in training * feature: network isolation mode in tar support training * change: documentation and check describe training job network isolation * doc update * doc update, remove inference section * sourcedir * type error fix
1 parent e2561d1 commit f1d34ad

File tree

6 files changed

+251
-71
lines changed

6 files changed

+251
-71
lines changed

doc/overview.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,33 @@ Likewise, when you create ``Transformer`` from the ``Estimator`` using ``transfo
559559
# Transform Job container instances will run in your VPC
560560
mxnet_vpc_transformer.transform('s3://my-bucket/batch-transform-input')
561561
562+
Secure Training with Network Isolation (Internet-Free) Mode
563+
-------------------------------------------------------------------------
564+
You can enable network isolation mode when running training and inference on Amazon SageMaker.
565+
566+
For more information about Amazon SageMaker network isolation mode, see the `SageMaker documentation on network isolation or internet-free mode <https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html>`__.
567+
568+
To train a model in network isolation mode, set the optional parameter ``enable_network_isolation`` to ``True`` in any network isolation supported Framework Estimator.
569+
570+
.. code:: python
571+
572+
# set the enable_network_isolation parameter to True
573+
sklearn_estimator = SKLearn('sklearn-train.py',
574+
train_instance_type='ml.m4.xlarge',
575+
framework_version='0.20.0',
576+
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1},
577+
enable_network_isolation=True)
578+
579+
# SageMaker Training Job will in the container without any inbound or outbound network calls during runtime
580+
sklearn_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
581+
'test': 's3://my-data-bucket/path/to/my/test/data'})
582+
583+
When this training job is created, the SageMaker Python SDK will upload the files in ``entry_point``, ``source_dir``, and ``dependencies`` to S3 as a compressed ``sourcedir.tar.gz`` file (``'s3://mybucket/sourcedir.tar.gz'``).
584+
585+
A new training job channel, named ``code``, will be added with that S3 URI. Before the training docker container is initialized, the ``sourcedir.tar.gz`` will be downloaded from S3 to the ML storage volume like any other offline input channel.
586+
587+
Once the training job begins, the training container will look at the offline input ``code`` channel to install dependencies and run the entry script. This isolates the training container, so no inbound or outbound network calls can be made.
588+
562589
563590
FAQ
564591
---

src/sagemaker/estimator.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
117117
self.metric_definitions = metric_definitions
118118
self.model_uri = model_uri
119119
self.model_channel_name = model_channel_name
120+
self.code_uri = None
121+
self.code_channel_name = 'code'
120122

121123
if self.train_instance_type in ('local', 'local_gpu'):
122124
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -773,9 +775,11 @@ class Framework(EstimatorBase):
773775
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
774776
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
775777
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
778+
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
776779

777780
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
778-
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
781+
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None,
782+
enable_network_isolation=False, **kwargs):
779783
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
780784
781785
Args:
@@ -784,6 +788,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
784788
source_dir (str): Path (absolute or relative) to a directory with any other training
785789
source code dependencies aside from the entry point file (default: None). Structure within this
786790
directory are preserved when training on Amazon SageMaker.
791+
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
792+
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
793+
For convenience, this accepts other types for keys and values, but ``str()`` will be called
794+
to convert them before training.
795+
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
796+
training jobs. This will be ignored for now and removed in a further release.
797+
container_log_level (int): Log level to use within the container (default: logging.INFO).
798+
Valid values are defined in the Python logging module.
799+
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
800+
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
801+
If not specified, the default code location is s3://default_bucket/job-name/. And code file
802+
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
803+
image_name (str): An alternate image name to use instead of the official Sagemaker image
804+
for the framework. This is useful to run one of the Sagemaker supported frameworks
805+
with an image containing custom dependencies.
787806
dependencies (list[str]): A list of paths to directories (absolute or relative) with
788807
any additional libraries that will be exported to the container (default: []).
789808
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -800,21 +819,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
800819
>>> |------ common
801820
>>> |------ virtual-env
802821
803-
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
804-
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
805-
For convenience, this accepts other types for keys and values, but ``str()`` will be called
806-
to convert them before training.
807-
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
808-
training jobs. This will be ignored for now and removed in a further release.
809-
container_log_level (int): Log level to use within the container (default: logging.INFO).
810-
Valid values are defined in the Python logging module.
811-
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
812-
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
813-
If not specified, the default code location is s3://default_bucket/job-name/. And code file
814-
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
815-
image_name (str): An alternate image name to use instead of the official Sagemaker image
816-
for the framework. This is useful to run one of the Sagemaker supported frameworks
817-
with an image containing custom dependencies.
822+
enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
823+
isolation mode restricts the container access to outside networks (such as the internet). The container
824+
does not make any inbound or outbound network calls. If True, a channel named "code" will be created
825+
for any user entry script for training. The user entry script, files in source_dir (if specified), and
826+
dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: `False`).
818827
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
819828
"""
820829
super(Framework, self).__init__(**kwargs)
@@ -830,9 +839,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
830839
self.container_log_level = container_log_level
831840
self.code_location = code_location
832841
self.image_name = image_name
842+
self._enable_network_isolation = enable_network_isolation
833843

834844
self._hyperparameters = hyperparameters or {}
835845

846+
def enable_network_isolation(self):
847+
"""Return True if this Estimator can use network isolation to run.
848+
849+
Returns:
850+
bool: Whether this Estimator can use network isolation or not.
851+
"""
852+
return self._enable_network_isolation
853+
836854
def _prepare_for_training(self, job_name=None):
837855
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
838856
@@ -858,6 +876,11 @@ def _prepare_for_training(self, job_name=None):
858876

859877
code_dir = 'file://' + self.source_dir
860878
script = self.entry_point
879+
elif self.enable_network_isolation() and self.entry_point:
880+
self.uploaded_code = self._stage_user_code_in_s3()
881+
code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
882+
script = self.uploaded_code.script_name
883+
self.code_uri = self.uploaded_code.s3_prefix
861884
else:
862885
self.uploaded_code = self._stage_user_code_in_s3()
863886
code_dir = self.uploaded_code.s3_prefix
@@ -881,12 +904,12 @@ def _stage_user_code_in_s3(self):
881904

882905
if self.code_location is None and local_mode:
883906
code_bucket = self.sagemaker_session.default_bucket()
884-
code_s3_prefix = '{}/source'.format(self._current_job_name)
907+
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
885908
kms_key = None
886909

887910
elif self.code_location is None:
888911
code_bucket, _ = parse_s3_url(self.output_path)
889-
code_s3_prefix = '{}/source'.format(self._current_job_name)
912+
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
890913
kms_key = self.output_kms_key
891914
else:
892915
code_bucket, key_prefix = parse_s3_url(self.code_location)

src/sagemaker/job.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,21 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6060
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
6161
vpc_config = estimator.get_vpc_config()
6262

63-
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
64-
validate_uri)
63+
model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name,
64+
validate_uri, content_type='application/x-sagemaker-model',
65+
input_mode='File')
6566
if model_channel:
6667
input_config = [] if input_config is None else input_config
6768
input_config.append(model_channel)
6869

70+
if estimator.enable_network_isolation():
71+
code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name,
72+
validate_uri)
73+
74+
if code_channel:
75+
input_config = [] if input_config is None else input_config
76+
input_config.append(code_channel)
77+
6978
return {'input_config': input_config,
7079
'role': role,
7180
'output_config': output_config,
@@ -110,16 +119,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
110119
return channel_config
111120

112121
@staticmethod
113-
def _format_string_uri_input(uri_input, validate_uri=True):
122+
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
114123
if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
115-
return s3_input(uri_input)
124+
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
116125
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
117126
return file_input(uri_input)
118127
elif isinstance(uri_input, str) and validate_uri:
119-
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
120-
'"file://"')
128+
raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
129+
'"file://"'.format(uri_input))
121130
elif isinstance(uri_input, str):
122-
return s3_input(uri_input)
131+
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
123132
elif isinstance(uri_input, s3_input):
124133
return uri_input
125134
elif isinstance(uri_input, file_input):
@@ -128,21 +137,22 @@ def _format_string_uri_input(uri_input, validate_uri=True):
128137
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))
129138

130139
@staticmethod
131-
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
132-
if not model_uri:
140+
def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None,
141+
input_mode=None):
142+
if not channel_uri:
133143
return
134-
elif not model_channel_name:
135-
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
144+
elif not channel_name:
145+
raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri))
136146

137147
if input_config:
138-
for channel in input_config:
139-
if channel['ChannelName'] == model_channel_name:
140-
raise ValueError('Duplicate channels not allowed.')
148+
for existing_channel in input_config:
149+
if existing_channel['ChannelName'] == channel_name:
150+
raise ValueError('Duplicate channel {} not allowed.'.format(channel_name))
141151

142-
model_input = _Job._format_model_uri_input(model_uri, validate_uri)
143-
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
152+
channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
153+
channel = _Job._convert_input_to_channel(channel_name, channel_input)
144154

145-
return model_channel
155+
return channel
146156

147157
@staticmethod
148158
def _format_model_uri_input(model_uri, validate_uri=True):

tests/integ/test_sklearn_train.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_ful
5555
return sklearn.latest_training_job.name
5656

5757

58+
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
59+
def test_training_with_network_isolation(sagemaker_session, sklearn_full_version):
60+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
61+
script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py')
62+
data_path = os.path.join(DATA_DIR, 'sklearn_mnist')
63+
64+
sklearn = SKLearn(entry_point=script_path,
65+
role='SageMakerRole',
66+
train_instance_type="ml.c4.xlarge",
67+
framework_version=sklearn_full_version,
68+
py_version=PYTHON_VERSION,
69+
sagemaker_session=sagemaker_session,
70+
hyperparameters={'epochs': 1},
71+
enable_network_isolation=True)
72+
73+
train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
74+
key_prefix='integ-test-data/sklearn_mnist/train')
75+
test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
76+
key_prefix='integ-test-data/sklearn_mnist/test')
77+
job_name = unique_name_from_base('test-sklearn-hp')
78+
79+
sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name)
80+
assert sagemaker_session.sagemaker_client \
81+
.describe_training_job(TrainingJobName=job_name)['EnableNetworkIsolation']
82+
return sklearn.latest_training_job.name
83+
84+
5885
@pytest.mark.canary_quick
5986
@pytest.mark.regional_testing
6087
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")

tests/unit/test_fw_utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,27 +175,26 @@ def test_unoptimized_gpu_family():
175175

176176

177177
def test_tar_and_upload_dir_s3(sagemaker_session):
178-
bucket = 'mybucker'
178+
bucket = 'mybucket'
179179
s3_key_prefix = 'something/source'
180180
script = 'mnist.py'
181181
directory = 's3://m'
182182
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
183+
183184
assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')
184185

185186

186187
@patch('sagemaker.utils')
187188
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
189+
bucket = 'mybucket'
190+
s3_key_prefix = 'something/source'
191+
script = 'mnist.py'
192+
kms_key = 'kms-key'
193+
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key)
188194

189-
result = fw_utils.tar_and_upload_dir(sagemaker_session,
190-
'mybucker',
191-
'something/source',
192-
'mnist.py',
193-
kms_key='kms-key')
194-
195-
assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
196-
'mnist.py')
195+
assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
197196

198-
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
197+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
199198
obj = sagemaker_session.resource('s3').Object('', '')
200199
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
201200

0 commit comments

Comments
 (0)