diff --git a/doc/overview.rst b/doc/overview.rst
index 5cd440d762..8a2f789252 100644
--- a/doc/overview.rst
+++ b/doc/overview.rst
@@ -559,6 +559,33 @@ Likewise, when you create ``Transformer`` from the ``Estimator`` using ``transfo
     # Transform Job container instances will run in your VPC
     mxnet_vpc_transformer.transform('s3://my-bucket/batch-transform-input')
 
+Secure Training with Network Isolation (Internet-Free) Mode
+-----------------------------------------------------------
+You can enable network isolation mode when running training and inference on Amazon SageMaker.
+
+For more information about Amazon SageMaker network isolation mode, see the `SageMaker documentation on network isolation or internet-free mode `__.
+
+To train a model in network isolation mode, set the optional parameter ``enable_network_isolation`` to ``True`` in any Framework Estimator that supports network isolation.
+
+.. code:: python
+
+    # set the enable_network_isolation parameter to True
+    sklearn_estimator = SKLearn('sklearn-train.py',
+                                train_instance_type='ml.m4.xlarge',
+                                framework_version='0.20.0',
+                                hyperparameters={'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1},
+                                enable_network_isolation=True)
+
+    # the SageMaker training job will run in the container without any inbound or outbound network calls during runtime
+    sklearn_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
+                           'test': 's3://my-data-bucket/path/to/my/test/data'})
+
+When this training job is created, the SageMaker Python SDK uploads the files in ``entry_point``, ``source_dir``, and ``dependencies`` to S3 as a compressed ``sourcedir.tar.gz`` file (e.g. ``'s3://mybucket/sourcedir.tar.gz'``).
+
+A new training job channel, named ``code``, is added with that S3 URI. Before the training docker container is initialized, ``sourcedir.tar.gz`` is downloaded from S3 to the ML storage volume, like any other offline input channel.
+
+Once the training job begins, the training container looks at the offline input ``code`` channel to install dependencies and run the entry script. This isolates the training container, so no inbound or outbound network calls can be made.
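+
+After the job completes, you can verify that it actually ran in network isolation mode by describing the training job. The snippet below is a minimal sketch; it assumes the ``sklearn_estimator`` from the example above has finished ``fit``.
+
+.. code:: python
+
+    # describe the completed training job through the session's SageMaker client
+    job_name = sklearn_estimator.latest_training_job.name
+    description = sklearn_estimator.sagemaker_session.sagemaker_client \
+        .describe_training_job(TrainingJobName=job_name)
+
+    # EnableNetworkIsolation is True for jobs trained in network isolation mode
+    assert description['EnableNetworkIsolation']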
+
 FAQ
 ---
 
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 54a6e9db86..2bc770f4c7 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -117,6 +117,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
         self.metric_definitions = metric_definitions
         self.model_uri = model_uri
         self.model_channel_name = model_channel_name
+        self.code_uri = None
+        self.code_channel_name = 'code'
 
         if self.train_instance_type in ('local', 'local_gpu'):
             if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -773,9 +775,11 @@ class Framework(EstimatorBase):
     LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
     MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
     MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
+    CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
 
     def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
-                 container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
+                 container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None,
+                 enable_network_isolation=False, **kwargs):
         """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
 
         Args:
@@ -784,6 +788,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
             source_dir (str): Path (absolute or relative) to a directory with any other training
                 source code dependencies aside from the entry point file (default: None). Structure within this
                 directory are preserved when training on Amazon SageMaker.
+            hyperparameters (dict): Hyperparameters that will be used for training (default: None).
+                The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
+                For convenience, this accepts other types for keys and values, but ``str()`` will be called
+                to convert them before training.
+            enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
+                training jobs. This will be ignored for now and removed in a further release.
+            container_log_level (int): Log level to use within the container (default: logging.INFO).
+                Valid values are defined in the Python logging module.
+            code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
+                The code file uploaded to S3 is 'code_location/source/sourcedir.tar.gz'.
+                If not specified, the default code location is s3://default_bucket/job-name/, and the code file
+                is uploaded to s3://default_bucket/job-name/source/sourcedir.tar.gz.
+            image_name (str): An alternate image name to use instead of the official SageMaker image
+                for the framework. This is useful to run one of the SageMaker supported frameworks
+                with an image containing custom dependencies.
             dependencies (list[str]): A list of paths to directories (absolute or relative) with
                 any additional libraries that will be exported to the container (default: []).
                 The library folders will be copied to SageMaker in the same folder where the entrypoint is
                 copied.
@@ -800,21 +819,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
 
             >>> |------ common
             >>> |------ virtual-env
 
-            hyperparameters (dict): Hyperparameters that will be used for training (default: None).
-                The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
-                For convenience, this accepts other types for keys and values, but ``str()`` will be called
-                to convert them before training.
-            enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
-                training jobs. This will be ignored for now and removed in a further release.
-            container_log_level (int): Log level to use within the container (default: logging.INFO).
-                Valid values are defined in the Python logging module.
-            code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
-                The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
-                If not specified, the default code location is s3://default_bucket/job-name/. And code file
-                uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
-            image_name (str): An alternate image name to use instead of the official Sagemaker image
-                for the framework. This is useful to run one of the Sagemaker supported frameworks
-                with an image containing custom dependencies.
+            enable_network_isolation (bool): Specifies whether the container will run in network isolation mode.
+                Network isolation mode restricts the container's access to outside networks (such as the internet).
+                The container does not make any inbound or outbound network calls. If True, a channel named "code"
+                will be created for any user entry script for training. The user entry script, files in
+                source_dir (if specified), and dependencies will be uploaded in a tar to S3. Also known as
+                internet-free mode (default: ``False``).
             **kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
         """
         super(Framework, self).__init__(**kwargs)
@@ -830,9 +839,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
         self.container_log_level = container_log_level
         self.code_location = code_location
         self.image_name = image_name
+        self._enable_network_isolation = enable_network_isolation
 
         self._hyperparameters = hyperparameters or {}
 
+    def enable_network_isolation(self):
+        """Return True if this Estimator can use network isolation to run.
+
+        Returns:
+            bool: Whether this Estimator can use network isolation or not.
+        """
+        return self._enable_network_isolation
+
     def _prepare_for_training(self, job_name=None):
         """Set hyperparameters needed for training. This method will also validate ``source_dir``.
 
@@ -858,6 +876,11 @@ def _prepare_for_training(self, job_name=None):
                 code_dir = 'file://' + self.source_dir
                 script = self.entry_point
+            elif self.enable_network_isolation() and self.entry_point:
+                self.uploaded_code = self._stage_user_code_in_s3()
+                code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
+                script = self.uploaded_code.script_name
+                self.code_uri = self.uploaded_code.s3_prefix
             else:
                 self.uploaded_code = self._stage_user_code_in_s3()
                 code_dir = self.uploaded_code.s3_prefix
@@ -881,12 +904,12 @@ def _stage_user_code_in_s3(self):
         if self.code_location is None and local_mode:
             code_bucket = self.sagemaker_session.default_bucket()
-            code_s3_prefix = '{}/source'.format(self._current_job_name)
+            code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
             kms_key = None
 
         elif self.code_location is None:
             code_bucket, _ = parse_s3_url(self.output_path)
-            code_s3_prefix = '{}/source'.format(self._current_job_name)
+            code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
             kms_key = self.output_kms_key
         else:
             code_bucket, key_prefix = parse_s3_url(self.code_location)
diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py
index 0e3c2f4034..bac1b78b3a 100644
--- a/src/sagemaker/job.py
+++ b/src/sagemaker/job.py
@@ -60,12 +60,21 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
         stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
         vpc_config = estimator.get_vpc_config()
 
-        model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
-                                                    validate_uri)
+        model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name,
+                                              validate_uri, content_type='application/x-sagemaker-model',
+                                              input_mode='File')
         if model_channel:
             input_config = [] if input_config is None else input_config
             input_config.append(model_channel)
 
+        if estimator.enable_network_isolation():
+            code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name,
+                                                 validate_uri)
+
+            if code_channel:
+                input_config = [] if input_config is None else input_config
+                input_config.append(code_channel)
+
         return {'input_config': input_config,
                 'role': role,
                 'output_config': output_config,
@@ -110,16 +119,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
         return channel_config
 
     @staticmethod
-    def _format_string_uri_input(uri_input, validate_uri=True):
+    def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
         if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
-            return s3_input(uri_input)
+            return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
         elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
             return file_input(uri_input)
         elif isinstance(uri_input, str) and validate_uri:
-            raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
-                             '"file://"')
+            raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
+                             '"file://"'.format(uri_input))
         elif isinstance(uri_input, str):
-            return s3_input(uri_input)
+            return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
         elif isinstance(uri_input, s3_input):
             return uri_input
         elif isinstance(uri_input, file_input):
@@ -128,21 +137,22 @@ def _format_string_uri_input(uri_input, validate_uri=True):
             raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))
 
     @staticmethod
-    def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
-        if not model_uri:
+    def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None,
+                         input_mode=None):
+        if not channel_uri:
             return
-        elif not model_channel_name:
-            raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
+        elif not channel_name:
+            raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri))
 
         if input_config:
-            for channel in input_config:
-                if channel['ChannelName'] == model_channel_name:
-                    raise ValueError('Duplicate channels not allowed.')
+            for existing_channel in input_config:
+                if existing_channel['ChannelName'] == channel_name:
+                    raise ValueError('Duplicate channel {} not allowed.'.format(channel_name))
 
-        model_input = _Job._format_model_uri_input(model_uri, validate_uri)
-        model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
+        channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
+        channel = _Job._convert_input_to_channel(channel_name, channel_input)
 
-        return model_channel
+        return channel
 
     @staticmethod
     def _format_model_uri_input(model_uri, validate_uri=True):
diff --git a/tests/integ/test_sklearn_train.py b/tests/integ/test_sklearn_train.py
index 7a32239bee..c332dcc73a 100644
--- a/tests/integ/test_sklearn_train.py
+++ b/tests/integ/test_sklearn_train.py
@@ -55,6 +55,33 @@ def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_ful
         return sklearn.latest_training_job.name
 
 
+@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
+def test_training_with_network_isolation(sagemaker_session, sklearn_full_version):
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py')
+        data_path = os.path.join(DATA_DIR, 'sklearn_mnist')
+
+        sklearn = SKLearn(entry_point=script_path,
+                          role='SageMakerRole',
+                          train_instance_type="ml.c4.xlarge",
+                          framework_version=sklearn_full_version,
+                          py_version=PYTHON_VERSION,
+                          sagemaker_session=sagemaker_session,
+                          hyperparameters={'epochs': 1},
+                          enable_network_isolation=True)
+
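+        # stage the MNIST data in S3; in network isolation mode, input channels
+        # (including the generated 'code' channel) are downloaded to the ML storage
+        # volume before the container starts, so training needs no network access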
+        train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
+                                                            key_prefix='integ-test-data/sklearn_mnist/train')
+        test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
+                                                           key_prefix='integ-test-data/sklearn_mnist/test')
+        job_name = unique_name_from_base('test-sklearn-hp')
+
+        sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name)
+        assert sagemaker_session.sagemaker_client \
+            .describe_training_job(TrainingJobName=job_name)['EnableNetworkIsolation']
+        return sklearn.latest_training_job.name
+
+
 @pytest.mark.canary_quick
 @pytest.mark.regional_testing
 @pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index ed5e97b631..0a7e82058a 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -175,27 +175,26 @@ def test_unoptimized_gpu_family():
 
 def test_tar_and_upload_dir_s3(sagemaker_session):
-    bucket = 'mybucker'
+    bucket = 'mybucket'
     s3_key_prefix = 'something/source'
     script = 'mnist.py'
     directory = 's3://m'
     result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
+    assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')
 
 
 @patch('sagemaker.utils')
 def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
+    bucket = 'mybucket'
+    s3_key_prefix = 'something/source'
+    script = 'mnist.py'
+    kms_key = 'kms-key'
+    result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key)
 
-    result = fw_utils.tar_and_upload_dir(sagemaker_session,
-                                         'mybucker',
-                                         'something/source',
-                                         'mnist.py',
-                                         kms_key='kms-key')
-
-    assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
-                                           'mnist.py')
+    assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
 
-    extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
+    extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
     obj = sagemaker_session.resource('s3').Object('', '')
     obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py
index fa2cace4b4..a3669a2524 100644
--- a/tests/unit/test_job.py
+++ b/tests/unit/test_job.py
@@ -13,11 +13,13 @@
 from __future__ import absolute_import
 
 import pytest
+import os
 from mock import Mock
 
 from sagemaker.amazon.amazon_estimator import RecordSet
-from sagemaker.estimator import Estimator
+from sagemaker.estimator import Estimator, Framework
 from sagemaker.job import _Job
+from sagemaker.model import FrameworkModel
 from sagemaker.session import s3_input
 
 BUCKET_NAME = 's3://mybucket/train'
@@ -28,12 +30,29 @@
 VOLUME_SIZE = 1
 MAX_RUNTIME = 1
 ROLE = 'DummyRole'
+REGION = 'us-west-2'
 IMAGE_NAME = 'fakeimage'
+SCRIPT_NAME = 'script.py'
 JOB_NAME = 'fakejob'
 VOLUME_KMS_KEY = 'volkmskey'
-CHANNEL_NAME = 'testChannel'
+MODEL_CHANNEL_NAME = 'testModelChannel'
 MODEL_URI = 's3://bucket/prefix/model.tar.gz'
 LOCAL_MODEL_NAME = 'file://local/file.tar.gz'
+CODE_CHANNEL_NAME = 'testCodeChannel'
+CODE_URI = 's3://bucket/prefix/code.py'
+DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
+SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME)
+MODEL_CONTAINER_DEF = {
+    'Environment': {
+        'SAGEMAKER_PROGRAM': SCRIPT_NAME,
+        'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz',
+        'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
+        'SAGEMAKER_REGION': REGION,
+        'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false'
+    },
+    'Image': IMAGE_NAME,
+    'ModelDataUrl': MODEL_URI,
+}
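+
+# Illustrative sketch of the channel dict that _Job._prepare_channel builds for
+# the code channel; it mirrors the assertions in test_prepare_channel below.
+# CODE_CHANNEL_EXPECTED is a hypothetical reference constant, not used by the tests.
+CODE_CHANNEL_EXPECTED = {
+    'ChannelName': CODE_CHANNEL_NAME,
+    'DataSource': {
+        'S3DataSource': {
+            'S3Uri': CODE_URI,
+            'S3DataType': 'S3Prefix',
+            'S3DataDistributionType': 'FullyReplicated',
+        }
+    },
+}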
 
 
 @pytest.fixture()
@@ -51,6 +70,39 @@ def sagemaker_session():
     return mock_session
 
 
+class DummyFramework(Framework):
+    __framework_name__ = 'dummy'
+
+    def train_image(self):
+        return IMAGE_NAME
+
+    def create_model(self, role=None, model_server_workers=None):
+        return DummyFrameworkModel(self.sagemaker_session, vpc_config=self.get_vpc_config())
+
+    @classmethod
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
+        init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description(
+            job_details, model_channel_name)
+        init_params.pop("image", None)
+        return init_params
+
+
+class DummyFrameworkModel(FrameworkModel):
+    def __init__(self, sagemaker_session, **kwargs):
+        super(DummyFrameworkModel, self).__init__(MODEL_URI, IMAGE_NAME, ROLE, SCRIPT_PATH,
+                                                  sagemaker_session=sagemaker_session, **kwargs)
+
+    def prepare_container_def(self, instance_type, accelerator_type=None):
+        return MODEL_CONTAINER_DEF
+
+
+@pytest.fixture()
+def framework(sagemaker_session):
+    return DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
+                          output_path=S3_OUTPUT_PATH, train_instance_count=INSTANCE_COUNT,
+                          train_instance_type=INSTANCE_TYPE)
+
+
 def test_load_config(estimator):
     inputs = s3_input(BUCKET_NAME)
 
@@ -70,13 +122,13 @@ def test_load_config_with_model_channel(estimator):
     inputs = s3_input(BUCKET_NAME)
 
     estimator.model_uri = MODEL_URI
-    estimator.model_channel_name = CHANNEL_NAME
+    estimator.model_channel_name = MODEL_CHANNEL_NAME
 
     config = _Job._load_config(inputs, estimator)
 
     assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
     assert config['input_config'][1]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
-    assert config['input_config'][1]['ChannelName'] == CHANNEL_NAME
+    assert config['input_config'][1]['ChannelName'] == MODEL_CHANNEL_NAME
     assert config['role'] == ROLE
     assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
     assert 'KmsKeyId' not in config['output_config']
@@ -88,12 +140,12 @@ def test_load_config_with_model_channel(estimator):
 
 def test_load_config_with_model_channel_no_inputs(estimator):
     estimator.model_uri = MODEL_URI
-    estimator.model_channel_name = CHANNEL_NAME
+    estimator.model_channel_name = MODEL_CHANNEL_NAME
 
     config = _Job._load_config(inputs=None, estimator=estimator)
 
     assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
-    assert config['input_config'][0]['ChannelName'] == CHANNEL_NAME
+    assert config['input_config'][0]['ChannelName'] == MODEL_CHANNEL_NAME
     assert config['role'] == ROLE
     assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
     assert 'KmsKeyId' not in config['output_config']
@@ -103,6 +155,43 @@ def test_load_config_with_model_channel_no_inputs(estimator):
     assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
 
 
+def test_load_config_with_code_channel(framework):
+    inputs = s3_input(BUCKET_NAME)
+
+    framework.model_uri = MODEL_URI
+    framework.model_channel_name = MODEL_CHANNEL_NAME
+    framework.code_uri = CODE_URI
+    framework._enable_network_isolation = True
+    config = _Job._load_config(inputs, framework)
+
+    assert len(config['input_config']) == 3
+    assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
+    assert config['input_config'][2]['DataSource']['S3DataSource']['S3Uri'] == CODE_URI
+    assert config['input_config'][2]['ChannelName'] == framework.code_channel_name
+    assert config['role'] == ROLE
+    assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
+    assert 'KmsKeyId' not in config['output_config']
+    assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
+    assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
+
+
+def test_load_config_with_code_channel_no_code_uri(framework):
+    inputs = s3_input(BUCKET_NAME)
+
+    framework.model_uri = MODEL_URI
+    framework.model_channel_name = MODEL_CHANNEL_NAME
+    framework._enable_network_isolation = True
+    config = _Job._load_config(inputs, framework)
+
+    assert len(config['input_config']) == 2
+    assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
+    assert config['role'] == ROLE
+    assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
+    assert 'KmsKeyId' not in config['output_config']
+    assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
+    assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
+
+
 def test_format_inputs_none():
     channels = _Job._format_inputs_to_input_config(inputs=None)
 
@@ -153,23 +242,28 @@ def test_format_inputs_to_input_config_list():
     assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == records.s3_data_type
 
 
-def test_prepare_model_channel():
-    model_channel = _Job._prepare_model_channel([], MODEL_URI, CHANNEL_NAME)
+@pytest.mark.parametrize('channel_uri, channel_name, content_type, input_mode',
+                         [[MODEL_URI, MODEL_CHANNEL_NAME, 'application/x-sagemaker-model', 'File'],
+                          [CODE_URI, CODE_CHANNEL_NAME, None, None]])
+def test_prepare_channel(channel_uri, channel_name, content_type, input_mode):
+    channel = _Job._prepare_channel([], channel_uri, channel_name, content_type=content_type, input_mode=input_mode)
+
+    assert channel['DataSource']['S3DataSource']['S3Uri'] == channel_uri
+    assert channel['DataSource']['S3DataSource']['S3DataDistributionType'] == 'FullyReplicated'
+    assert channel['DataSource']['S3DataSource']['S3DataType'] == 'S3Prefix'
+    assert channel['ChannelName'] == channel_name
+    assert 'CompressionType' not in channel
+    assert 'RecordWrapperType' not in channel
 
-    # The model channel should use all the defaults except InputMode
-    assert model_channel['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
-    assert model_channel['DataSource']['S3DataSource']['S3DataDistributionType'] == 'FullyReplicated'
-    assert model_channel['DataSource']['S3DataSource']['S3DataType'] == 'S3Prefix'
-    assert model_channel['InputMode'] == 'File'
-    assert model_channel['ChannelName'] == CHANNEL_NAME
-    assert 'CompressionType' not in model_channel
-    assert model_channel['ContentType'] == 'application/x-sagemaker-model'
-    assert 'RecordWrapperType' not in model_channel
+    # The model channel should use all the defaults except InputMode and ContentType
+    if channel_name == MODEL_CHANNEL_NAME:
+        assert channel['ContentType'] == 'application/x-sagemaker-model'
+        assert channel['InputMode'] == 'File'
 
 
-def test_prepare_model_channel_duplicate():
+def test_prepare_channel_duplicate():
     channels = [{
-        'ChannelName': CHANNEL_NAME,
+        'ChannelName': MODEL_CHANNEL_NAME,
         'DataSource': {
             'S3DataSource': {
                 'S3DataDistributionType': 'FullyReplicated',
@@ -180,20 +274,20 @@
     }]
 
     with pytest.raises(ValueError) as error:
-        _Job._prepare_model_channel(channels, MODEL_URI, CHANNEL_NAME)
+        _Job._prepare_channel(channels, MODEL_URI, MODEL_CHANNEL_NAME)
 
-    assert 'Duplicate channels not allowed.' in str(error)
+    assert 'Duplicate channel {} not allowed.'.format(MODEL_CHANNEL_NAME) in str(error)
 
 
-def test_prepare_model_channel_with_missing_name():
+def test_prepare_channel_with_missing_name():
     with pytest.raises(ValueError) as ex:
-        _Job._prepare_model_channel([], model_uri=MODEL_URI, model_channel_name=None)
+        _Job._prepare_channel([], channel_uri=MODEL_URI, channel_name=None)
 
-    assert 'Expected a pre-trained model channel name if a model URL is specified.' in str(ex)
+    assert 'Expected a channel name if a channel URI {} is specified'.format(MODEL_URI) in str(ex)
 
 
-def test_prepare_model_channel_with_missing_uri():
-    assert _Job._prepare_model_channel([], model_uri=None, model_channel_name=None) is None
+def test_prepare_channel_with_missing_uri():
+    assert _Job._prepare_channel([], channel_uri=None, channel_name=None) is None
 
 
 def test_format_inputs_to_input_config_list_not_all_records():