Add AugmentedManifestFile & ShuffleConfig support #528

Merged
4 commits merged on Dec 9, 2018
Changes from 2 commits
4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -2,6 +2,10 @@
CHANGELOG
=========

1.16.2
======
* feature: Add support for AugmentedManifestFile and ShuffleConfig

1.16.1.post1
============

36 changes: 30 additions & 6 deletions src/sagemaker/session.py
@@ -1222,7 +1222,7 @@ class s3_input(object):

def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
content_type=None, record_wrapping=None, s3_data_type='S3Prefix',
input_mode=None):
input_mode=None, attribute_names=None, shuffle_config=None):
"""Create a definition for input data used by an SageMaker training job.

See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters.
@@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode.
content_type (str): MIME type of the input data (default: None).
record_wrapping (str): Valid values: 'RecordIO' (default: None).
s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
each s3 object to train on. The Manifest file format is described in the SageMaker API documentation:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix',
``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with
``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data``
defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to
train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API
documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will
use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore
that setting if this parameter is set.
* None - Amazon SageMaker will use the input mode specified in the ``Estimator``.
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified
AugmentedManifestFile.
shuffle_config (ShuffleConfig): If specified, this configuration enables shuffling on this channel. See the
SageMaker API documentation for more info:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
"""
self.config = {
'DataSource': {
@@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
self.config['RecordWrapperType'] = record_wrapping
if input_mode is not None:
self.config['InputMode'] = input_mode
if attribute_names is not None:
self.config['DataSource']['S3DataSource']['AttributeNames'] = attribute_names
if shuffle_config is not None:
self.config['ShuffleConfig'] = {'Seed': shuffle_config.seed}


class ShuffleConfig(object):
"""
Used to configure channel shuffling using a seed. See SageMaker
documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
"""
def __init__(self, seed):
"""
Create a ShuffleConfig.
Args:
seed (long): the long value used to seed the shuffled sequence.
"""
self.seed = seed


class ModelContainer(object):
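
For reference, a minimal usage sketch of the two additions, assuming the sagemaker package at the version introduced by this change; the bucket, manifest key, attribute names, and seed below are placeholders:

from sagemaker.session import ShuffleConfig, s3_input

# Hypothetical manifest and attribute names, for illustration only.
# Augmented manifests are consumed in Pipe input mode.
train_channel = s3_input(
    's3://mybucket/train.manifest',
    s3_data_type='AugmentedManifestFile',
    attribute_names=['source-ref', 'class'],
    input_mode='Pipe',
    shuffle_config=ShuffleConfig(seed=1234),
)

# The channel definition built above (and sent to CreateTrainingJob) contains:
#
#   train_channel.config == {
#       'DataSource': {
#           'S3DataSource': {
#               'S3DataDistributionType': 'FullyReplicated',
#               'S3DataType': 'AugmentedManifestFile',
#               'S3Uri': 's3://mybucket/train.manifest',
#               'AttributeNames': ['source-ref', 'class'],
#           }
#       },
#       'InputMode': 'Pipe',
#       'ShuffleConfig': {'Seed': 1234},
#   }
#
# estimator.fit({'train': train_channel})  # usable with any Estimator's fit()
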
26 changes: 25 additions & 1 deletion tests/unit/test_estimator.py
@@ -24,7 +24,7 @@
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
from sagemaker.model import FrameworkModel
from sagemaker.predictor import RealTimePredictor
from sagemaker.session import s3_input
from sagemaker.session import s3_input, ShuffleConfig
from sagemaker.transformer import Transformer

MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -277,6 +277,30 @@ def test_invalid_custom_code_bucket(sagemaker_session):
assert "Expecting 's3' scheme" in str(error)


def test_augmented_manifest(sagemaker_session):
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
enable_cloudwatch_metrics=True)
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile',
attribute_names=['foo', 'bar']))

_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource']
assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest'
assert s3_data_source['S3DataType'] == 'AugmentedManifestFile'
assert s3_data_source['AttributeNames'] == ['foo', 'bar']


def test_shuffle_config(sagemaker_session):
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
enable_cloudwatch_metrics=True)
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', shuffle_config=ShuffleConfig(100)))
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
channel = train_kwargs['input_config'][0]
assert channel['ShuffleConfig']['Seed'] == 100


BASE_HP = {
'sagemaker_program': json.dumps(SCRIPT_NAME),
'sagemaker_submit_directory': json.dumps('s3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME)),