Add Local Mode Batch Inference support. #414

Merged: 18 commits, merged on Oct 11, 2018. Showing changes from 9 commits.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -6,6 +6,7 @@ CHANGELOG
=========

* enhancement: Local Mode: add training environment variables for AWS region and job name
* feature: Add support for Local Mode Batch Inference
Contributor comment (suggested wording for the changelog entry):

feature: Local Mode: add support for Batch Transform


1.11.0
======
21 changes: 21 additions & 0 deletions README.rst
@@ -223,6 +223,27 @@ Here is an end-to-end example:
predictor.delete_endpoint()


If you don't want to deploy your model locally, you can also run a Local Transform Job. This is useful
if you want to test your container before creating a SageMaker Transform Job. Note that performance will
not match Transform Jobs hosted on SageMaker, but it is still a useful way to verify that your container
works as expected, or to process modest amounts of data.

Here is an end-to-end example:

.. code:: python

from sagemaker.mxnet import MXNet

mxnet_estimator = MXNet('train.py',
train_instance_type='local',
train_instance_count=1)

mxnet_estimator.fit('file:///tmp/my_training_data')
transformer = mxnet_estimator.transformer(1, 'local', assemble_with='Line', max_payload=1)
    transformer.transform('s3://my/transform/data', content_type='text/csv', split_type='Line')
transformer.wait()
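
Once the job finishes, the transform output is written to the transformer's output location. As a small,
hedged follow-up (``output_path`` is generated for you unless you pass it explicitly when creating the
transformer):

.. code:: python

    # Inspect where the batch results were written.
    print(transformer.output_path)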


For detailed examples of running Docker in local mode, see:

- `TensorFlow local mode example notebook <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_distributed_mnist/tensorflow_local_mode_mnist.ipynb>`__.
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/common.py
@@ -153,7 +153,7 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None):
def read_records(file):
"""Eagerly read a collection of amazon Record protobuf objects from file."""
records = []
-    for record_data in _read_recordio(file):
+    for record_data in read_recordio(file):
record = Record()
record.ParseFromString(record_data)
records.append(record)
@@ -183,7 +183,7 @@ def _write_recordio(f, data):
f.write(padding[pad])


-def _read_recordio(f):
+def read_recordio(f):
while(True):
try:
read_kmagic, = struct.unpack('I', f.read(4))
249 changes: 249 additions & 0 deletions src/sagemaker/local/data.py
@@ -0,0 +1,249 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import sys
import tempfile
from six.moves.urllib.parse import urlparse

from sagemaker.amazon.common import read_recordio
from sagemaker.local.utils import download_folder
from sagemaker.utils import get_config_value


class DataSourceFactory(object):

@staticmethod
def get_instance(data_source, sagemaker_session):
parsed_uri = urlparse(data_source)
if parsed_uri.scheme == 'file':
return LocalFileDataSource(parsed_uri.path)
        elif parsed_uri.scheme == 's3':
            return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
        else:
            raise ValueError('Invalid data source: %s' % data_source)


class DataSource(object):

def get_file_list(self):
pass

def get_root_dir(self):
pass


class LocalFileDataSource(DataSource):
"""
Represents a data source within the local filesystem.
"""

def __init__(self, root_path):
self.root_path = os.path.abspath(root_path)
if not os.path.exists(self.root_path):
            raise RuntimeError('Invalid data source: %s. Does not exist.' % self.root_path)

def get_file_list(self):
"""Retrieve the list of absolute paths to all the files in this data source.

Returns:
List[string] List of absolute paths.
Contributor comment: add a colon after "[string]"

"""
if os.path.isdir(self.root_path):
files = [os.path.join(self.root_path, f) for f in os.listdir(self.root_path)
if os.path.isfile(os.path.join(self.root_path, f))]
else:
files = [self.root_path]

return files
Contributor comment: you could just put the return in each branch of the if statement

Contributor comment: bump (especially because you do it in get_root_dir)

Contributor comment: you can remove the return files line now :)


def get_root_dir(self):
"""Retrieve the absolute path to the root directory of this data source.

Returns:
string: absolute path to the root directory of this data source.
"""
if os.path.isdir(self.root_path):
return self.root_path
else:
return os.path.dirname(self.root_path)


class S3DataSource(DataSource):
"""Defines a data source given by a bucket and s3 prefix. The contents will be downloaded
and then processed as local data.
"""

def __init__(self, bucket, prefix, sagemaker_session):
"""Create an S3DataSource instance

Args:
bucket (str): s3 bucket name
prefix (str): s3 prefix path to the data
sagemaker_session (sagemaker.Session): a sagemaker_session with the desired settings to talk to s3

"""

# Create a temporary dir to store the S3 contents
root_dir = get_config_value('local.container_root', sagemaker_session.config)
if root_dir:
root_dir = os.path.abspath(root_dir)

working_dir = tempfile.mkdtemp(dir=root_dir)
download_folder(bucket, prefix, working_dir, sagemaker_session)
self.files = LocalFileDataSource(working_dir)

def get_file_list(self):
return self.files.get_file_list()

def get_root_dir(self):
return self.files.get_root_dir()
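
# A minimal usage sketch (hypothetical paths), assuming this module is
# importable as sagemaker.local.data. Both source types resolve behind the
# same DataSource interface:
#
#     data_source = DataSourceFactory.get_instance('file:///tmp/my_data', sagemaker_session)
#     data_source.get_root_dir()   # directory that holds the data
#     data_source.get_file_list()  # absolute path of every file in it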


class SplitterFactory(object):

@staticmethod
def get_instance(split_type):
if split_type is None:
return NoneSplitter()
elif split_type == 'Line':
return LineSplitter()
elif split_type == 'RecordIO':
return RecordIOSplitter()
else:
raise ValueError('Invalid Split Type: %s' % split_type)


class Splitter(object):

def split(self, file):
pass


class NoneSplitter(Splitter):
"""Does not split records, essentially reads the whole file.
"""

def split(self, file):
with open(file, 'r') as f:
yield f.read()


class LineSplitter(Splitter):
"""Split records by new line.

"""

def split(self, file):
with open(file, 'r') as f:
for line in f:
yield line


class RecordIOSplitter(Splitter):
    """Split using Amazon RecordIO.

Not useful for string content.

"""
def split(self, file):
with open(file, 'rb') as f:
for record in read_recordio(f):
yield record
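
# A minimal sketch of selecting and using a splitter (hypothetical path). A
# None split_type reads each file as a single record, 'Line' yields one record
# per line, and 'RecordIO' yields one encoded record at a time:
#
#     splitter = SplitterFactory.get_instance('Line')
#     for record in splitter.split('/tmp/data.csv'):
#         process(record)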


class BatchStrategyFactory(object):

@staticmethod
def get_instance(strategy, splitter):
if strategy == 'SingleRecord':
return SingleRecordStrategy(splitter)
elif strategy == 'MultiRecord':
return MultiRecordStrategy(splitter)
else:
            raise ValueError('Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"' % strategy)


class BatchStrategy(object):

def pad(self, file, size):
pass


class MultiRecordStrategy(BatchStrategy):
"""Feed multiple records at a time for batch inference.

Will group up as many records as possible within the payload specified.

"""
def __init__(self, splitter):
self.splitter = splitter

def pad(self, file, size=6):
buffer = ''
for element in self.splitter.split(file):
if _payload_size_within_limit(buffer + element, size):
buffer += element
else:
tmp = buffer
buffer = element
yield tmp
if _validate_payload_size(buffer, size):
yield buffer


class SingleRecordStrategy(BatchStrategy):
"""Feed a single record at a time for batch inference.

If a single record does not fit within the payload specified it will throw a Runtime error.
"""
def __init__(self, splitter):
self.splitter = splitter

def pad(self, file, size=6):
for element in self.splitter.split(file):
if _validate_payload_size(element, size):
yield element
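
# A minimal sketch of packing records into payloads (hypothetical file).
# 'MultiRecord' groups as many records as fit under the size limit, while
# 'SingleRecord' yields one record per payload and raises if a single record
# exceeds the limit:
#
#     strategy = BatchStrategyFactory.get_instance('SingleRecord', LineSplitter())
#     for payload in strategy.pad('/tmp/data.csv', size=6):
#         send(payload)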


def _payload_size_within_limit(payload, size):
"""

Args:
payload:
size:

Returns:

"""
size_in_bytes = size * 1024 * 1024
if size == 0:
return True
else:
return sys.getsizeof(payload) < size_in_bytes


def _validate_payload_size(payload, size):
"""Check if a payload is within the size in MB threshold. Raise an exception otherwise.

Args:
payload: data that will be checked
size (int): max size in MB

    Returns (bool): True if within bounds; if size=0 it always returns True.
    Raises:
        RuntimeError: If the payload is larger than the size threshold.
"""

if not _payload_size_within_limit(payload, size):
raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size)
return True
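
Taken together, these pieces form the local batch-transform data pipeline: a data source enumerates the
input files, a splitter turns each file into records, and a batch strategy packs records into payloads no
larger than the requested size in MB. The following is a minimal sketch, assuming the module is importable
as ``sagemaker.local.data`` and that the input path exists:

.. code:: python

    from sagemaker.local.data import (BatchStrategyFactory, DataSourceFactory,
                                      SplitterFactory)

    # Resolve the input into local files ('s3://' URIs are downloaded first;
    # the session argument is only needed for S3 sources).
    data_source = DataSourceFactory.get_instance('file:///tmp/my_transform_data', None)

    # 'Line' splitting with 'MultiRecord' packing mirrors the README example.
    splitter = SplitterFactory.get_instance('Line')
    strategy = BatchStrategyFactory.get_instance('MultiRecord', splitter)

    for input_file in data_source.get_file_list():
        # size=1 caps each payload at 1 MB (1 * 1024 * 1024 bytes).
        for payload in strategy.pad(input_file, size=1):
            pass  # each payload would be POSTed to the serving container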