diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fd5d3a5b17..d5700733bc 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,7 @@ CHANGELOG 1.16.2.dev ========== +* feature: Add support for AugmentedManifestFile and ShuffleConfig * bug-fix: add version bound for requests module to avoid version conflicts between docker-compose and docker-py * bug-fix: Remove unnecessary dependency tensorflow * doc-fix: Change ``distribution`` to ``distributions`` diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f0fb48ce8d..badd69d6a2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1222,7 +1222,7 @@ class s3_input(object): def __init__(self, s3_data, distribution='FullyReplicated', compression=None, content_type=None, record_wrapping=None, s3_data_type='S3Prefix', - input_mode=None): + input_mode=None, attribute_names=None, shuffle_config=None): """Create a definition for input data used by an SageMaker training job. See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. @@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode. content_type (str): MIME type of the input data (default: None). record_wrapping (str): Valid values: 'RecordIO' (default: None). - s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines - a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will - be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing - each s3 object to train on. The Manifest file format is described in the SageMaker API documentation: - https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html + s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix', + ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with + ``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data`` + defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to + train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API + documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore that setting if this parameter is set. * None - Amazon SageMaker will use the input mode specified in the ``Estimator``. * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory. * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe. + attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified + AugmentedManifestFile. + shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on this channel. See the + SageMaker API documentation for more info: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ self.config = { 'DataSource': { @@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, self.config['RecordWrapperType'] = record_wrapping if input_mode is not None: self.config['InputMode'] = input_mode + if attribute_names is not None: + self.config['DataSource']['S3DataSource']['AttributeNames'] = attribute_names + if shuffle_config is not None: + self.config['ShuffleConfig'] = {'Seed': shuffle_config.seed} + + +class ShuffleConfig(object): + """ + Used to configure channel shuffling using a seed. See SageMaker + documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + """ + def __init__(self, seed): + """ + Create a ShuffleConfig. + Args: + seed (long): the long value used to seed the shuffled sequence. + """ + self.seed = seed class ModelContainer(object): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1916f9ca7c..c33cbc4ece 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -24,7 +24,7 @@ from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob from sagemaker.model import FrameworkModel from sagemaker.predictor import RealTimePredictor -from sagemaker.session import s3_input +from sagemaker.session import s3_input, ShuffleConfig from sagemaker.transformer import Transformer MODEL_DATA = "s3://bucket/model.tar.gz" @@ -277,6 +277,30 @@ def test_invalid_custom_code_bucket(sagemaker_session): assert "Expecting 's3' scheme" in str(error) +def test_augmented_manifest(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True) + fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile', + attribute_names=['foo', 'bar'])) + + _, _, train_kwargs = sagemaker_session.train.mock_calls[0] + s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource'] + assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest' + assert s3_data_source['S3DataType'] == 'AugmentedManifestFile' + assert s3_data_source['AttributeNames'] == ['foo', 'bar'] + + +def test_shuffle_config(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True) + fw.fit(inputs=s3_input('s3://mybucket/train_manifest', shuffle_config=ShuffleConfig(100))) + _, _, train_kwargs = sagemaker_session.train.mock_calls[0] + channel = train_kwargs['input_config'][0] + assert channel['ShuffleConfig']['Seed'] == 100 + + BASE_HP = { 'sagemaker_program': json.dumps(SCRIPT_NAME), 'sagemaker_submit_directory': json.dumps('s3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME)),