diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 516043774b..dd29ed26fa 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -2,6 +2,12 @@
 CHANGELOG
 =========
 
+1.13.0
+======
+
+* feature: Estimator: add input mode to training channels
+* feature: Estimator: add model_uri and model_channel_name parameters
+
 1.12.0
 ======
 
diff --git a/README.rst b/README.rst
index 86688df414..a9cceaafc1 100644
--- a/README.rst
+++ b/README.rst
@@ -263,6 +263,66 @@ A few important notes:
 
 - Local Mode requires Docker Compose and `nvidia-docker2 `__ for ``local_gpu``.
 - Distributed training is not yet supported for ``local_gpu``.
 
+Incremental Training
+~~~~~~~~~~~~~~~~~~~~
+
+Incremental training allows you to bring a pre-trained model into a SageMaker training job and use it as a starting point for a new model.
+There are several situations where you might want to do this:
+
+- You want to perform additional training on a model to improve its fit on your data set.
+- You want to import a pre-trained model and fit it to your data.
+- You want to resume a training job that you previously stopped.
+
+To use incremental training with SageMaker algorithms, you need model artifacts compressed into a ``tar.gz`` file. These
+artifacts are passed to a training job via an input channel configured with the pre-defined settings Amazon SageMaker algorithms require.
+
+To use model files with a SageMaker estimator, you can use the following parameters:
+
+* ``model_uri``: points to the location of a model tarball, either in S3 or locally. Specifying a local path only works in local mode.
+* ``model_channel_name``: name of the channel SageMaker will use to download the tarball specified in ``model_uri``. Defaults to 'model'.
+
+This is converted into an input channel with the specifications mentioned above once you call ``fit()`` on the estimator.
+In bring-your-own cases, ``model_channel_name`` can be overridden if you need to change the name of the channel while using
+the same settings.
+
+If your bring-your-own case requires different settings, you can create your own ``s3_input`` object with the settings you require.
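+
+For illustration, here is one way to build such a channel by hand and pass it to ``fit()``. This is a sketch rather than
+a prescribed recipe: the ``'pretrained'`` channel name and the S3 paths are placeholders, and the keyword arguments
+simply spell out the pre-defined settings described above so you can see which ones to change:
+
+.. code:: python
+
+    # Recreate the default model channel settings explicitly. Any of them can be
+    # adjusted; input_mode overrides the estimator-level input mode for this
+    # channel only.
+    custom_model_input = sagemaker.session.s3_input('s3://my_bucket/my_model/model.tar.gz',
+                                                    distribution='FullyReplicated',
+                                                    content_type='application/x-sagemaker-model',
+                                                    input_mode='File')
+
+    # Pass the custom channel to fit() alongside the training data.
+    estimator.fit({'training': 's3://my_bucket/my_training_data/',
+                   'pretrained': custom_model_input})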
+
+Here's a complete example of how to use incremental training:
+
+.. code:: python
+
+    # Configure an estimator
+    estimator = sagemaker.estimator.Estimator(training_image,
+                                              role,
+                                              train_instance_count=1,
+                                              train_instance_type='ml.p2.xlarge',
+                                              train_volume_size=50,
+                                              train_max_run=360000,
+                                              input_mode='File',
+                                              output_path=s3_output_location)
+
+    # Start a SageMaker training job and wait until completion.
+    estimator.fit('s3://my_bucket/my_training_data/')
+
+    # Create a new estimator using the previous estimator's model artifacts
+    incr_estimator = sagemaker.estimator.Estimator(training_image,
+                                                   role,
+                                                   train_instance_count=1,
+                                                   train_instance_type='ml.p2.xlarge',
+                                                   train_volume_size=50,
+                                                   train_max_run=360000,
+                                                   input_mode='File',
+                                                   output_path=s3_output_location,
+                                                   model_uri=estimator.model_data)
+
+    # Start a SageMaker training job using the original model for incremental training
+    incr_estimator.fit('s3://my_bucket/my_training_data/')
+
+Currently, the following algorithms support incremental training:
+
+- Image Classification
+- Object Detection
+- Semantic Segmentation
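+
+Incremental training also composes with ``attach()``. The following is a sketch: the job name and channel name are
+placeholders, and it assumes the original job was started with ``model_channel_name='pretrained'``. If the attached job
+has a channel matching ``model_channel_name``, the re-created estimator picks up that channel's ``model_uri``
+automatically; otherwise the parameter is ignored:
+
+.. code:: python
+
+    # Rebuild an estimator from an existing training job, including its
+    # pre-trained model channel, then launch a new job that trains on top of it.
+    previous_estimator = sagemaker.estimator.Estimator.attach('my-training-job',
+                                                              model_channel_name='pretrained')
+    previous_estimator.fit('s3://my_bucket/my_training_data/')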
+
 MXNet SageMaker Estimators
 --------------------------
diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py
index b44f9d8834..dcb841261a 100644
--- a/src/sagemaker/__init__.py
+++ b/src/sagemaker/__init__.py
@@ -35,4 +35,4 @@
 from sagemaker.session import s3_input  # noqa: F401
 from sagemaker.session import get_execution_role  # noqa: F401
 
-__version__ = '1.12.0'
+__version__ = '1.13.0'
diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index 5c41bcc66b..bbbb2beab6 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -70,17 +70,19 @@ def data_location(self, data_location):
         self._data_location = data_location
 
     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details):
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
         """Convert the job description to init params that can be handled by the class constructor
 
         Args:
             job_details: the returned job details from a describe_training_job API call.
+            model_channel_name (str): Name of the channel where pre-trained model data will be downloaded.
 
         Returns:
             dictionary: The transformed init_params
 
         """
-        init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)
+        init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(
+            job_details, model_channel_name)
 
         # The hyperparam names may not be the same as the class attribute that holds them,
         # for instance: local_lloyd_init_method is called local_init_method. We need to map these
diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py
index 47fa8abcd8..acb381ba45 100644
--- a/src/sagemaker/chainer/estimator.py
+++ b/src/sagemaker/chainer/estimator.py
@@ -134,17 +134,18 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                             vpc_config=self.get_vpc_config(vpc_config_override))
 
     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details):
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
         """Convert the job description to init params that can be handled by the class constructor
 
         Args:
             job_details: the returned job details from a describe_training_job API call.
+            model_channel_name (str): Name of the channel where pre-trained model data will be downloaded.
 
         Returns:
             dictionary: The transformed init_params
 
         """
-        init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details)
+        init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
 
         for argument in [Chainer._use_mpi, Chainer._num_processes, Chainer._process_slots_per_host,
                          Chainer._additional_mpi_options]:
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 7325129a14..6e087abe8c 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -19,6 +19,7 @@
 from abc import ABCMeta
 from abc import abstractmethod
 from six import with_metaclass
+from six import string_types
 
 from sagemaker.analytics import TrainingJobAnalytics
 from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
@@ -49,7 +50,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
     def __init__(self, role, train_instance_count, train_instance_type,
                  train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File',
                  output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
-                 subnets=None, security_group_ids=None):
+                 subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model'):
         """Initialize an ``EstimatorBase`` instance.
 
         Args:
@@ -69,6 +70,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
             input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
                 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
                 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
+                This argument can be overridden on a per-channel basis using ``sagemaker.session.s3_input.input_mode``.
             output_path (str): S3 location for saving the trainig result (model artifacts and output files).
                 If not specified, results are stored to a default bucket. If the bucket with the specific name
                 does not exist, the estimator creates the bucket during the
@@ -85,6 +87,16 @@ def __init__(self, role, train_instance_count, train_instance_type,
             subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
             security_group_ids (list[str]): List of security group ids. If not specified training job will be created
                 without VPC config.
+            model_uri (str): URI where a pre-trained model is stored, either locally or in S3 (default: None). If
+                specified, the estimator will create a channel pointing to the model so the training job can download
+                it. This model can be a 'model.tar.gz' from a previous training job, or other artifacts coming from a
+                different source.
+
+                In local mode, this should point to the path in which the model is located and not the file itself, as
+                local Docker containers will try to mount the URI as a volume.
+
+                More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
+            model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
""" self.role = role self.train_instance_count = train_instance_count @@ -94,6 +106,8 @@ def __init__(self, role, train_instance_count, train_instance_type, self.train_max_run = train_max_run self.input_mode = input_mode self.tags = tags + self.model_uri = model_uri + self.model_channel_name = model_channel_name if self.train_instance_type in ('local', 'local_gpu'): if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1: @@ -209,7 +223,7 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi raise NotImplementedError() @classmethod - def attach(cls, training_job_name, sagemaker_session=None): + def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='model'): """Attach to an existing training job. Create an Estimator bound to an existing training job, each subclass is responsible to implement @@ -225,6 +239,8 @@ def attach(cls, training_job_name, sagemaker_session=None): sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded (default: + 'model'). If no channel with the same name exists in the training job, this option will be ignored. Examples: >>> my_estimator.fit(wait=False) @@ -239,7 +255,7 @@ def attach(cls, training_job_name, sagemaker_session=None): sagemaker_session = sagemaker_session or Session() job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) - init_params = cls._prepare_init_params_from_job_description(job_details) + init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) estimator = cls(sagemaker_session=sagemaker_session, **init_params) estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session, @@ -294,11 +310,12 @@ def create_model(self, **kwargs): pass @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded. Returns: dictionary: The transformed init_params @@ -325,6 +342,13 @@ def _prepare_init_params_from_job_description(cls, job_details): if security_group_ids: init_params['security_group_ids'] = security_group_ids + if 'InputDataConfig' in job_details and model_channel_name: + for channel in job_details['InputDataConfig']: + if channel['ChannelName'] == model_channel_name: + init_params['model_channel_name'] = model_channel_name + init_params['model_uri'] = channel['DataSource']['S3DataSource']['S3Uri'] + break + return init_params def delete_endpoint(self): @@ -415,9 +439,10 @@ def start_new(cls, estimator, inputs): """ local_mode = estimator.sagemaker_session.local_mode + model_uri = estimator.model_uri # Allow file:// input only in local mode - if isinstance(inputs, str) and inputs.startswith('file://'): + if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri): if not local_mode: raise ValueError('File URIs are supported in local mode only. 
Please use a S3 URI instead.')
@@ -435,6 +460,10 @@ def start_new(cls, estimator, inputs):
 
         return cls(estimator.sagemaker_session, estimator._current_job_name)
 
+    @classmethod
+    def _is_local_channel(cls, input_uri):
+        return isinstance(input_uri, string_types) and input_uri.startswith('file://')
+
     def wait(self, logs=True):
         if logs:
             self.sagemaker_session.logs_for_job(self.job_name, wait=True)
@@ -451,7 +480,8 @@ class Estimator(EstimatorBase):
     def __init__(self, image_name, role, train_instance_count, train_instance_type,
                  train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60,
                  input_mode='File', output_path=None, output_kms_key=None, base_job_name=None,
-                 sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None):
+                 sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None,
+                 model_uri=None, model_channel_name='model'):
         """Initialize an ``Estimator`` instance.
 
         Args:
@@ -474,6 +504,7 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
                 * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
                 * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
+                This argument can be overridden on a per-channel basis using ``sagemaker.session.s3_input.input_mode``.
             output_path (str): S3 location for saving the trainig result (model artifacts and output files).
                 If not specified, results are stored to a default bucket. If the bucket with the specific name
                 does not exist, the estimator creates the bucket during the
@@ -491,13 +522,24 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
             subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
             security_group_ids (list[str]): List of security group ids. If not specified training job will be created
                 without VPC config.
+            model_uri (str): URI where a pre-trained model is stored, either locally or in S3 (default: None). If
+                specified, the estimator will create a channel pointing to the model so the training job can download
+                it. This model can be a 'model.tar.gz' from a previous training job, or other artifacts coming from a
+                different source.
+
+                In local mode, this should point to the path in which the model is located and not the file itself,
+                as local Docker containers will try to mount the URI as a volume.
+
+                More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
+            model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
""" self.image_name = image_name self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} super(Estimator, self).__init__(role, train_instance_count, train_instance_type, train_volume_size, train_volume_kms_key, train_max_run, input_mode, output_path, output_kms_key, base_job_name, sagemaker_session, - tags, subnets, security_group_ids) + tags, subnets, security_group_ids, model_uri=model_uri, + model_channel_name=model_channel_name) def train_image(self): """ @@ -558,17 +600,18 @@ def predict_wrapper(endpoint, session): sagemaker_session=self.sagemaker_session, predictor_cls=predictor_cls, **kwargs) @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded Returns: dictionary: The transformed init_params """ - init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details) + init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) init_params['image_name'] = init_params.pop('image') return init_params @@ -695,17 +738,18 @@ def hyperparameters(self): return self._json_encode_hyperparameters(self._hyperparameters) @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded Returns: dictionary: The transformed init_params """ - init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details) + init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME)) init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME)) @@ -744,7 +788,7 @@ def train_image(self): self.train_instance_type, self.framework_version, py_version=self.py_version) @classmethod - def attach(cls, training_job_name, sagemaker_session=None): + def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='model'): """Attach to an existing training job. Create an Estimator bound to an existing training job, each subclass is responsible to implement @@ -760,6 +804,8 @@ def attach(cls, training_job_name, sagemaker_session=None): sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded (default: + 'model'). If no channel with the same name exists in the training job, this option will be ignored. Examples: >>> my_estimator.fit(wait=False) @@ -771,7 +817,7 @@ def attach(cls, training_job_name, sagemaker_session=None): Returns: Instance of the calling ``Estimator`` Class with the attached training job. 
""" - estimator = super(Framework, cls).attach(training_job_name, sagemaker_session) + estimator = super(Framework, cls).attach(training_job_name, sagemaker_session, model_channel_name) estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point) return estimator diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 4cfb0597e2..393010e289 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -62,6 +62,10 @@ def _load_config(inputs, estimator): stop_condition = _Job._prepare_stop_condition(estimator.train_max_run) vpc_config = estimator.get_vpc_config() + model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name) + if model_channel: + input_config.append(model_channel) + return {'input_config': input_config, 'role': role, 'output_config': output_config, @@ -81,7 +85,7 @@ def _format_inputs_to_input_config(inputs): input_dict['training'] = _Job._format_string_uri_input(inputs) elif isinstance(inputs, s3_input): input_dict['training'] = inputs - elif isinstance(input, file_input): + elif isinstance(inputs, file_input): input_dict['training'] = inputs elif isinstance(inputs, dict): for k, v in inputs.items(): @@ -92,13 +96,16 @@ def _format_inputs_to_input_config(inputs): raise ValueError( 'Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs)) - channels = [] - for channel_name, channel_s3_input in input_dict.items(): - channel_config = channel_s3_input.config.copy() - channel_config['ChannelName'] = channel_name - channels.append(channel_config) + channels = [_Job._convert_input_to_channel(name, input) for name, input in input_dict.items()] + return channels + @staticmethod + def _convert_input_to_channel(channel_name, channel_s3_input): + channel_config = channel_s3_input.config.copy() + channel_config['ChannelName'] = channel_name + return channel_config + @staticmethod def _format_string_uri_input(input): if isinstance(input, str): @@ -116,6 +123,36 @@ def _format_string_uri_input(input): else: raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input)) + @staticmethod + def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None): + if not model_uri: + return + elif not model_channel_name: + raise ValueError('Expected a pre-trained model channel name if a model URL is specified.') + + for channel in input_config: + if channel['ChannelName'] == model_channel_name: + raise ValueError('Duplicate channels not allowed.') + + model_input = _Job._format_model_uri_input(model_uri) + model_channel = _Job._convert_input_to_channel(model_channel_name, model_input) + + return model_channel + + @staticmethod + def _format_model_uri_input(model_uri): + if isinstance(model_uri, string_types): + if model_uri.startswith('s3://'): + return s3_input(model_uri, input_mode='File', distribution='FullyReplicated', + content_type='application/x-sagemaker-model') + elif model_uri.startswith('file://'): + return file_input(model_uri) + else: + raise ValueError('Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' + '"file://') + else: + raise ValueError('Cannot format model URI {}. 
Expecting str'.format(model_uri)) + @staticmethod def _format_record_set_list_input(inputs): # Deferred import due to circular dependency diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 2097b97fb6..cc385e166a 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -100,17 +100,18 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override vpc_config=self.get_vpc_config(vpc_config_override)) @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded. Returns: dictionary: The transformed init_params """ - init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details) + init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) image_name = init_params.pop('image') framework, py_version, tag = framework_name_from_image(image_name) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 05bffbd41d..6f2b6ce7c6 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -99,17 +99,18 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override vpc_config=self.get_vpc_config(vpc_config_override)) @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. + model_channel_name (str): Name of the channel where pre-trained model data will be downloaded. Returns: dictionary: The transformed init_params """ - init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details) + init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) image_name = init_params.pop('image') framework, py_version, tag = framework_name_from_image(image_name) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 66e3d8ec2f..c97450cb18 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1009,7 +1009,8 @@ class s3_input(object): """ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, - content_type=None, record_wrapping=None, s3_data_type='S3Prefix'): + content_type=None, record_wrapping=None, s3_data_type='S3Prefix', + input_mode=None): """Create a definition for input data used by an SageMaker training job. See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. @@ -1026,6 +1027,12 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing each s3 object to train on. The Manifest file format is described in the SageMaker API documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html + input_mode (str): Optional override for this channel's input mode (default: None). 
By default, channels will + use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore + that setting if this parameter is set. + * None - Amazon SageMaker will use the input mode specified in the ``Estimator``. + * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory. + * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe. """ self.config = { 'DataSource': { @@ -1043,6 +1050,8 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, self.config['ContentType'] = content_type if record_wrapping is not None: self.config['RecordWrapperType'] = record_wrapping + if input_mode is not None: + self.config['InputMode'] = input_mode def _deployment_entity_exists(describe_fn): diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 630b99d71e..ebaa1f9088 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -253,7 +253,7 @@ def fit_super(): fit_super() @classmethod - def _prepare_init_params_from_job_description(cls, job_details): + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): """Convert the job description to init params that can be handled by the class constructor Args: @@ -263,7 +263,7 @@ def _prepare_init_params_from_job_description(cls, job_details): dictionary: The transformed init_params """ - init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details) + init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) # Move some of the tensorflow specific init params from hyperparameters into the main init params. 
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 40eacfb6d3..73c6fd2241 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -50,6 +50,41 @@ } } +RETURNED_JOB_DESCRIPTION = { + 'AlgorithmSpecification': { + 'TrainingInputMode': 'File', + 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4' + }, + 'HyperParameters': { + 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', + 'checkpoint_path': '"s3://other/1508872349"', + 'sagemaker_program': '"iris-dnn-classifier.py"', + 'sagemaker_enable_cloudwatch_metrics': 'false', + 'sagemaker_container_log_level': '"logging.INFO"', + 'sagemaker_job_name': '"neo"', + 'training_steps': '100', + }, + + 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', + 'ResourceConfig': { + 'VolumeSizeInGB': 30, + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge' + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 24 * 60 * 60 + }, + 'TrainingJobName': 'neo', + 'TrainingJobStatus': 'Completed', + 'OutputDataConfig': { + 'KmsKeyId': '', + 'S3OutputPath': 's3://place/output/neo' + }, + 'TrainingJobOutput': { + 'S3TrainingJobOutput': 's3://here/output.tar.gz' + } +} + MODEL_CONTAINER_DEF = { 'Environment': { 'SAGEMAKER_PROGRAM': ENTRY_POINT, @@ -73,8 +108,9 @@ def create_model(self, role=None, model_server_workers=None): return DummyFrameworkModel(self.sagemaker_session, vpc_config=self.get_vpc_config()) @classmethod - def _prepare_init_params_from_job_description(cls, job_details): - init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description(job_details) + def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): + init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name) init_params.pop("image", None) return init_params @@ -133,6 +169,56 @@ def test_sagemaker_s3_uri_invalid(sagemaker_session): assert 'must be a valid S3 or FILE URI' in str(error) +def test_sagemaker_model_s3_uri_invalid(sagemaker_session): + with pytest.raises(ValueError) as error: + t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + model_uri='thisdoesntstartwiths3either.tar.gz') + t.fit('s3://mydata') + assert 'must be a valid S3 or FILE URI' in str(error) + + +def test_sagemaker_model_file_uri_invalid(sagemaker_session): + with pytest.raises(ValueError) as error: + t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + model_uri='file://notins3.tar.gz') + t.fit('s3://mydata') + assert 'File URIs are supported in local mode only' in str(error) + + +def test_sagemaker_model_default_channel_name(sagemaker_session): + f = DummyFramework(entry_point='my_script.py', role='DummyRole', train_instance_count=3, + train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session, + model_uri='s3://model-bucket/prefix/model.tar.gz') + _TrainingJob.start_new(f, {}) + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args['input_config'] == [{'ChannelName': 'model', + 'InputMode': 'File', + 'ContentType': 'application/x-sagemaker-model', + 'DataSource': { + 'S3DataSource': {'S3DataType': 'S3Prefix', + 'S3DataDistributionType': 'FullyReplicated', + 
'S3Uri': 's3://model-bucket/prefix/model.tar.gz'}}}] + + +def test_sagemaker_model_custom_channel_name(sagemaker_session): + f = DummyFramework(entry_point='my_script.py', role='DummyRole', train_instance_count=3, + train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session, + model_uri='s3://model-bucket/prefix/model.tar.gz', model_channel_name='testModelChannel') + _TrainingJob.start_new(f, {}) + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args['input_config'] == [{'ChannelName': 'testModelChannel', + 'InputMode': 'File', + 'ContentType': 'application/x-sagemaker-model', + 'DataSource': { + 'S3DataSource': {'S3DataType': 'S3Prefix', + 'S3DataDistributionType': 'FullyReplicated', + 'S3Uri': 's3://model-bucket/prefix/model.tar.gz'}}}] + + @patch('time.strftime', return_value=TIMESTAMP) def test_custom_code_bucket(time, sagemaker_session): code_bucket = 'codebucket' @@ -282,38 +368,10 @@ def test_enable_cloudwatch_metrics(sagemaker_session): def test_attach_framework(sagemaker_session): - returned_job_description = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4', - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge', - }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo', - }, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}, - 'VpcConfig': { - 'Subnets': ['foo'], - 'SecurityGroupIds': ['bar'] - } + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + returned_job_description['VpcConfig'] = { + 'Subnets': ['foo'], + 'SecurityGroupIds': ['bar'] } sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=returned_job_description) @@ -335,41 +393,8 @@ def test_attach_framework(sagemaker_session): def test_attach_framework_with_tuning(sagemaker_session): - returned_job_description = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4' - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - '_tuning_objective_metric': 'Validation-accuracy', - }, - - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo' - }, - 'TrainingJobOutput': { - 
'S3TrainingJobOutput': 's3://here/output.tar.gz' - } - } + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + returned_job_description['HyperParameters']['_tuning_objective_metric'] = 'Validation-accuracy' mock_describe_training_job = Mock(name='describe_training_job', return_value=returned_job_description) @@ -392,6 +417,28 @@ def test_attach_framework_with_tuning(sagemaker_session): assert framework_estimator.entry_point == 'iris-dnn-classifier.py' +def test_attach_framework_with_model_channel(sagemaker_session): + s3_uri = 's3://some/s3/path/model.tar.gz' + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + returned_job_description['InputDataConfig'] = [ + { + 'ChannelName': 'model', + 'InputMode': 'File', + 'DataSource': { + 'S3DataSource': { + 'S3Uri': s3_uri + } + } + } + ] + + sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', + return_value=returned_job_description) + + framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + assert framework_estimator.model_uri is s3_uri + + @patch('time.strftime', return_value=TIMESTAMP) def test_fit_verify_job_name(strftime, sagemaker_session): fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 158e7229c4..c2b3a8aca5 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -31,6 +31,9 @@ IMAGE_NAME = 'fakeimage' JOB_NAME = 'fakejob' VOLUME_KMS_KEY = 'volkmskey' +CHANNEL_NAME = 'testChannel' +MODEL_URI = 's3://bucket/prefix/model.tar.gz' +LOCAL_MODEL_NAME = 'file://local/file.tar.gz' @pytest.fixture() @@ -63,6 +66,26 @@ def test_load_config(estimator): assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME +def test_load_config_with_model_channel(estimator): + inputs = s3_input(BUCKET_NAME) + + estimator.model_uri = MODEL_URI + estimator.model_channel_name = CHANNEL_NAME + + config = _Job._load_config(inputs, estimator) + + assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME + assert config['input_config'][1]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI + assert config['input_config'][1]['ChannelName'] == CHANNEL_NAME + assert config['role'] == ROLE + assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH + assert 'KmsKeyId' not in config['output_config'] + assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT + assert config['resource_config']['InstanceType'] == INSTANCE_TYPE + assert config['resource_config']['VolumeSizeInGB'] == VOLUME_SIZE + assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME + + def test_format_inputs_to_input_config_string(): inputs = BUCKET_NAME @@ -107,6 +130,49 @@ def test_format_inputs_to_input_config_list(): assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == records.s3_data_type +def test_prepare_model_channel(): + model_channel = _Job._prepare_model_channel([], MODEL_URI, CHANNEL_NAME) + + # The model channel should use all the defaults except InputMode + assert model_channel['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI + assert model_channel['DataSource']['S3DataSource']['S3DataDistributionType'] == 'FullyReplicated' + assert model_channel['DataSource']['S3DataSource']['S3DataType'] == 'S3Prefix' + assert model_channel['InputMode'] == 'File' + assert model_channel['ChannelName'] == CHANNEL_NAME + assert 'CompressionType' not in model_channel + assert 
model_channel['ContentType'] == 'application/x-sagemaker-model' + assert 'RecordWrapperType' not in model_channel + + +def test_prepare_model_channel_duplicate(): + channels = [{ + 'ChannelName': CHANNEL_NAME, + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://blah/blah' + } + } + }] + + with pytest.raises(ValueError) as error: + _Job._prepare_model_channel(channels, MODEL_URI, CHANNEL_NAME) + + assert 'Duplicate channels not allowed.' in str(error) + + +def test_prepare_model_channel_with_missing_name(): + with pytest.raises(ValueError) as ex: + _Job._prepare_model_channel([], model_uri=MODEL_URI, model_channel_name=None) + + assert 'Expected a pre-trained model channel name if a model URL is specified.' in str(ex) + + +def test_prepare_model_channel_with_missing_uri(): + assert _Job._prepare_model_channel([], model_uri=None, model_channel_name=None) is None + + def test_format_inputs_to_input_config_list_not_all_records(): records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1) inputs = [records, 'mock'] @@ -263,6 +329,27 @@ def test_format_string_uri_input_exception(): _Job._format_string_uri_input(inputs) +def test_format_model_uri_input_string(): + model_uri = MODEL_URI + + model_uri_input = _Job._format_model_uri_input(model_uri) + + assert model_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == model_uri + + +def test_format_model_uri_input_local_file(): + model_uri_input = _Job._format_model_uri_input(LOCAL_MODEL_NAME) + + assert model_uri_input.config['DataSource']['FileDataSource']['FileUri'] == LOCAL_MODEL_NAME + + +def test_format_model_uri_input_exception(): + model_uri = 1 + + with pytest.raises(ValueError): + _Job._format_model_uri_input(model_uri) + + def test_prepare_output_config(): kms_key_id = 'kms_key' diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 53fa414307..d47ced1b64 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -147,7 +147,8 @@ def test_s3_input_all_arguments(): content_type = 'text/csv' record_wrapping = 'RecordIO' s3_data_type = 'Manifestfile' - result = s3_input(s3_data=prefix, distribution=distribution, compression=compression, + input_mode = 'Pipe' + result = s3_input(s3_data=prefix, distribution=distribution, compression=compression, input_mode=input_mode, content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type) expected = \ {'DataSource': { @@ -159,7 +160,8 @@ def test_s3_input_all_arguments(): }, 'CompressionType': compression, 'ContentType': content_type, - 'RecordWrapperType': record_wrapping + 'RecordWrapperType': record_wrapping, + 'InputMode': input_mode } assert result.config == expected