Skip to content

feature: network isolation mode in training #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 21, 2019
27 changes: 27 additions & 0 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,33 @@ Likewise, when you create ``Transformer`` from the ``Estimator`` using ``transfo
# Transform Job container instances will run in your VPC
mxnet_vpc_transformer.transform('s3://my-bucket/batch-transform-input')
Secure Training with Network Isolation (Internet-Free) Mode
-------------------------------------------------------------------------
You can enable network isolation mode when running training and inference on Amazon SageMaker.
For more information about Amazon SageMaker network isolation mode, see the `SageMaker documentation on network isolation or internet-free mode <https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html>`__.
To train a model in network isolation mode, set the optional parameter ``enable_network_isolation`` to ``True`` in any Framework Estimator that supports network isolation.
.. code:: python
# set the enable_network_isolation parameter to True
sklearn_estimator = SKLearn('sklearn-train.py',
train_instance_type='ml.m4.xlarge',
framework_version='0.20.0',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1},
enable_network_isolation=True)
# SageMaker Training Job will run in the container without any inbound or outbound network calls during runtime
sklearn_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
When this training job is created, the SageMaker Python SDK will upload the files in ``entry_point``, ``source_dir``, and ``dependencies`` to S3 as a compressed ``sourcedir.tar.gz`` file (``'s3://mybucket/sourcedir.tar.gz'``).
A new training job channel, named ``code``, will be added with that S3 URI. Before the training docker container is initialized, the ``sourcedir.tar.gz`` will be downloaded from S3 to the ML storage volume like any other offline input channel.
Once the training job begins, the training container will look at the offline input ``code`` channel to install dependencies and run the entry script. This isolates the training container, so no inbound or outbound network calls can be made.
FAQ
---
Expand Down
59 changes: 41 additions & 18 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.metric_definitions = metric_definitions
self.model_uri = model_uri
self.model_channel_name = model_channel_name
self.code_uri = None
self.code_channel_name = 'code'

if self.train_instance_type in ('local', 'local_gpu'):
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
Expand Down Expand Up @@ -773,9 +775,11 @@ class Framework(EstimatorBase):
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'

def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None,
enable_network_isolation=False, **kwargs):
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
Args:
Expand All @@ -784,6 +788,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from the entry point file (default: None). Structure within this
directory is preserved when training on Amazon SageMaker.
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
For convenience, this accepts other types for keys and values, but ``str()`` will be called
to convert them before training.
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
training jobs. This will be ignored for now and removed in a further release.
container_log_level (int): Log level to use within the container (default: logging.INFO).
Valid values are defined in the Python logging module.
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
If not specified, the default code location is s3://default_bucket/job-name/. And code file
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
image_name (str): An alternate image name to use instead of the official Sagemaker image
for the framework. This is useful to run one of the Sagemaker supported frameworks
with an image containing custom dependencies.
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container (default: []).
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
Expand All @@ -800,21 +819,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
>>> |------ common
>>> |------ virtual-env
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
For convenience, this accepts other types for keys and values, but ``str()`` will be called
to convert them before training.
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
training jobs. This will be ignored for now and removed in a further release.
container_log_level (int): Log level to use within the container (default: logging.INFO).
Valid values are defined in the Python logging module.
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
If not specified, the default code location is s3://default_bucket/job-name/. And code file
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
image_name (str): An alternate image name to use instead of the official Sagemaker image
for the framework. This is useful to run one of the Sagemaker supported frameworks
with an image containing custom dependencies.
enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
isolation mode restricts the container access to outside networks (such as the internet). The container
does not make any inbound or outbound network calls. If True, a channel named "code" will be created
for any user entry script for training. The user entry script, files in source_dir (if specified), and
dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: `False`).
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
"""
super(Framework, self).__init__(**kwargs)
Expand All @@ -830,9 +839,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
self.container_log_level = container_log_level
self.code_location = code_location
self.image_name = image_name
self._enable_network_isolation = enable_network_isolation

self._hyperparameters = hyperparameters or {}

def enable_network_isolation(self):
    """Report whether network isolation mode is enabled for this Estimator.

    Returns:
        bool: The value of ``enable_network_isolation`` that was supplied
            to the constructor and stored on the instance.
    """
    isolation_flag = self._enable_network_isolation
    return isolation_flag

def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
Expand All @@ -858,6 +876,11 @@ def _prepare_for_training(self, job_name=None):

code_dir = 'file://' + self.source_dir
script = self.entry_point
elif self.enable_network_isolation() and self.entry_point:
self.uploaded_code = self._stage_user_code_in_s3()
code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
script = self.uploaded_code.script_name
self.code_uri = self.uploaded_code.s3_prefix
else:
self.uploaded_code = self._stage_user_code_in_s3()
code_dir = self.uploaded_code.s3_prefix
Expand All @@ -881,12 +904,12 @@ def _stage_user_code_in_s3(self):

if self.code_location is None and local_mode:
code_bucket = self.sagemaker_session.default_bucket()
code_s3_prefix = '{}/source'.format(self._current_job_name)
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
kms_key = None

elif self.code_location is None:
code_bucket, _ = parse_s3_url(self.output_path)
code_s3_prefix = '{}/source'.format(self._current_job_name)
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
kms_key = self.output_kms_key
else:
code_bucket, key_prefix = parse_s3_url(self.code_location)
Expand Down
44 changes: 27 additions & 17 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,21 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
vpc_config = estimator.get_vpc_config()

model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
validate_uri)
model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name,
validate_uri, content_type='application/x-sagemaker-model',
input_mode='File')
if model_channel:
input_config = [] if input_config is None else input_config
input_config.append(model_channel)

if estimator.enable_network_isolation():
code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name,
validate_uri)

if code_channel:
input_config = [] if input_config is None else input_config
input_config.append(code_channel)

return {'input_config': input_config,
'role': role,
'output_config': output_config,
Expand Down Expand Up @@ -110,16 +119,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
return channel_config

@staticmethod
def _format_string_uri_input(uri_input, validate_uri=True):
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
return s3_input(uri_input)
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
return file_input(uri_input)
elif isinstance(uri_input, str) and validate_uri:
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"')
raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"'.format(uri_input))
elif isinstance(uri_input, str):
return s3_input(uri_input)
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, s3_input):
return uri_input
elif isinstance(uri_input, file_input):
Expand All @@ -128,21 +137,22 @@ def _format_string_uri_input(uri_input, validate_uri=True):
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))

@staticmethod
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
if not model_uri:
def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None,
input_mode=None):
if not channel_uri:
return
elif not model_channel_name:
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
elif not channel_name:
raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri))

if input_config:
for channel in input_config:
if channel['ChannelName'] == model_channel_name:
raise ValueError('Duplicate channels not allowed.')
for existing_channel in input_config:
if existing_channel['ChannelName'] == channel_name:
raise ValueError('Duplicate channel {} not allowed.'.format(channel_name))

model_input = _Job._format_model_uri_input(model_uri, validate_uri)
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
channel = _Job._convert_input_to_channel(channel_name, channel_input)

return model_channel
return channel

@staticmethod
def _format_model_uri_input(model_uri, validate_uri=True):
Expand Down
27 changes: 27 additions & 0 deletions tests/integ/test_sklearn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,33 @@ def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_ful
return sklearn.latest_training_job.name


# Integration test: fits an SKLearn estimator constructed with
# enable_network_isolation=True and checks that the resulting training job
# reports EnableNetworkIsolation when described.
# NOTE(review): leading indentation is missing in this view, so the nesting of
# the statements below under ``with timeout(...)`` cannot be confirmed here.
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
def test_training_with_network_isolation(sagemaker_session, sklearn_full_version):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
# Both the entry script and the train/test data come from the sklearn_mnist directory under DATA_DIR.
script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py')
data_path = os.path.join(DATA_DIR, 'sklearn_mnist')

# enable_network_isolation=True is the setting under test.
sklearn = SKLearn(entry_point=script_path,
role='SageMakerRole',
train_instance_type="ml.c4.xlarge",
framework_version=sklearn_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
hyperparameters={'epochs': 1},
enable_network_isolation=True)

# Upload the local train/test directories; the returned values are passed to fit() as the
# 'train' and 'test' inputs (presumably S3 URIs -- confirm against upload_data).
train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
key_prefix='integ-test-data/sklearn_mnist/train')
test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
key_prefix='integ-test-data/sklearn_mnist/test')
# NOTE(review): the 'test-sklearn-hp' base name looks carried over from the hyperparameters
# test; consider a network-isolation-specific base name.
job_name = unique_name_from_base('test-sklearn-hp')

sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name)
# The described training job must carry a truthy EnableNetworkIsolation field.
assert sagemaker_session.sagemaker_client \
.describe_training_job(TrainingJobName=job_name)['EnableNetworkIsolation']
return sklearn.latest_training_job.name


@pytest.mark.canary_quick
@pytest.mark.regional_testing
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
Expand Down
19 changes: 9 additions & 10 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,27 +175,26 @@ def test_unoptimized_gpu_family():


def test_tar_and_upload_dir_s3(sagemaker_session):
bucket = 'mybucker'
bucket = 'mybucket'
s3_key_prefix = 'something/source'
script = 'mnist.py'
directory = 's3://m'
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)

assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')


@patch('sagemaker.utils')
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
bucket = 'mybucket'
s3_key_prefix = 'something/source'
script = 'mnist.py'
kms_key = 'kms-key'
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key)

result = fw_utils.tar_and_upload_dir(sagemaker_session,
'mybucker',
'something/source',
'mnist.py',
kms_key='kms-key')

assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
'mnist.py')
assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)

extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
obj = sagemaker_session.resource('s3').Object('', '')
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)

Expand Down
Loading