S3 Estimator and Image Classification #71
.gitignore
```diff
@@ -20,4 +20,4 @@ examples/tensorflow/distributed_mnist/data
 doc/_build
 **/.DS_Store
 venv/
-*~
+*.rec
```
src/sagemaker/amazon/amazon_estimator.py
```diff
@@ -110,6 +110,7 @@ def fit(self, records, mini_batch_size=None, **kwargs):
             records (:class:`~RecordSet`): The records to train this ``Estimator`` on
             mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
                 default value will be used.
+            distribution (s3 distribution type): S3 Distribution.
         """
         self.feature_dim = records.feature_dim
         self.mini_batch_size = mini_batch_size
```
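The two distribution values that appear in this PR are 'ShardedByS3Key' (the default of the new S3 `fit` below) and 'FullyReplicated' (the `S3Set` default). A minimal sketch of a call that passes one; the `estimator` and `records` objects here are placeholders, not part of this diff:

```python
# 'FullyReplicated' copies every S3 object to each training host;
# 'ShardedByS3Key' partitions the objects across hosts.
estimator.fit(records, mini_batch_size=32, distribution='ShardedByS3Key')
```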
```diff
@@ -152,9 +153,98 @@ def record_set(self, train, labels=None, channel="train"):
         return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel)
 
 
+class AmazonS3AlgorithmEstimatorBase(EstimatorBase):
+    """Base class for Amazon first-party Estimator implementations. This class isn't
+    intended to be instantiated directly. This is different from the base class
+    because this class handles S3 data"""
+
+    mini_batch_size = hp('mini_batch_size', (validation.isint, validation.gt(0)))
+
+    def __init__(self, role, train_instance_count, train_instance_type, algorithm, **kwargs):
+        """Initialize an AmazonAlgorithmEstimatorBase.
+
+        Args:
+            algortihm (str): Use one of the supported algorithms
```
Reviewer: Typo.
Author: Where is the typo? I don't see it.
Reviewer: algortihm
```diff
+        """
+        super(AmazonS3AlgorithmEstimatorBase, self).__init__(role, train_instance_count, train_instance_type,
+                                                             **kwargs)
+        self.algorithm = algorithm
+
+    def train_image(self):
+        return registry(self.sagemaker_session.boto_region_name, algorithm=self.algorithm) + "/" + type(self).repo
+
+    def hyperparameters(self):
+        return hp.serialize_all(self)
+
+    def fit(self, s3set, mini_batch_size=None, distribution='ShardedByS3Key', **kwargs):
+        """Fit this Estimator on serialized Record objects, stored in S3.
+
+        ``records`` should be a list of instances of :class:`~RecordSet`. This defines a collection of
+        s3 data files to train this ``Estimator`` on.
+
+        More information on the Amazon Record format is available at:
+        https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
+
+        See :meth:`~AmazonS3AlgorithmEstimatorBase.s3_record_set` to construct a ``RecordSet`` object
+        from :class:`~numpy.ndarray` arrays.
+
+        Args:
+            s3set (list): This is a list of :class:`~S3Set` items. The list of records to train
+                this ``Estimator`` will depend on each algorithm and type of input data.
+            distribution (str): The s3 distribution of data.
+            mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
+                default value will be used.
+        """
+        default_mini_batch_size = 32
```
Reviewer: Why don't you make 32 the default value for mini_batch_size in the method signature? Then you don't even have to do this whole thing, and you can just set it directly.
Author: Two reasons why: 1. It's the protocol used in the other algorithms. 2. We want to make this a must-supply parameter for the user. If I assume a default and it fails because of a memory error, it becomes a customer error, which is wrong.
```diff
+        self.mini_batch_size = mini_batch_size or default_mini_batch_size
+        data = {}
+        for item in s3set:
+            data[item.channel] = s3_input(item.s3_location, distribution=item.distribution,
+                                          content_type=item.content_type,
+                                          s3_data_type=item.s3_data_type)
+        super(AmazonS3AlgorithmEstimatorBase, self).fit(data, **kwargs)
```
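To make the new flow concrete, a minimal usage sketch (the bucket paths and the `estimator` instance are placeholders, and `S3Set` is defined further down in this diff): each `S3Set` supplies one named channel, and `fit` turns the list into the channel-to-`s3_input` mapping handed to the base class.

```python
# Hypothetical locations; each S3Set becomes one input channel for the job.
train = S3Set('s3://my-bucket/train/', content_type='application/x-recordio',
              channel='train')
validation = S3Set('s3://my-bucket/validation/', content_type='application/x-recordio',
                   channel='validation')

# fit() builds {'train': s3_input(...), 'validation': s3_input(...)}
# and delegates to EstimatorBase.fit with those channels.
estimator.fit([train, validation], mini_batch_size=32)
```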
```diff
+
+    def s3_record_set(self, s3_loc, content_type, channel="train"):
+        """Build a :class:`~RecordSet` from a S3 location with data in it.
+
+        Args:
+            s3_loc (str): A s3 bucket where data is located
+            channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to.
+            content_type (str): Content type of the data.
+        Returns:
+            RecordSet: A RecordSet referencing the encoded, uploaded training and label data.
+        """
+        return S3Set(s3_loc, content_type=content_type, channel=channel)
```
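A quick sketch of the helper in use (the estimator and bucket are placeholders); it simply wraps the location in an `S3Set` bound to a channel:

```python
s3set = estimator.s3_record_set('s3://my-bucket/images/', content_type='application/x-image')
# equivalent to S3Set('s3://my-bucket/images/', content_type='application/x-image', channel='train')
```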
```diff
+
+
+class S3Set(object):
+    def __init__(self, s3_location, content_type=None, s3_data_type='S3Prefix', distribution='FullyReplicated', channel='train'):
+        """A collection of Amazon :class:~`Record` objects serialized and stored in S3.
+
+        Args:
+            s3_location (str): The S3 location of the training data
+            distribution (str): The s3 distribution of data.
+            content_type (str): Mandatory content type of the data.
+            s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
+                a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
+                be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
+                each s3 object to train on.
+            channel (str): The SageMaker Training Job channel this RecordSet should be bound to
+        """
+        self.s3_location = s3_location
+        self.distribution = distribution
+        self.s3_data_type = s3_data_type
+        self.channel = channel
+        self.content_type = content_type
+
+    def __repr__(self):
+        """Return an unambiguous representation of this S3Set"""
+        return str((S3Set, self.__dict__))
+
+
 class RecordSet(object):
 
-    def __init__(self, s3_data, num_records, feature_dim, s3_data_type='ManifestFile', channel='train'):
+    def __init__(self, s3_data, num_records=None, feature_dim=None, s3_data_type='ManifestFile', channel='train'):
         """A collection of Amazon :class:~`Record` objects serialized and stored in S3.
 
         Args:
```
```diff
@@ -235,6 +325,13 @@ def registry(region_name, algorithm=None):
             "us-west-2": "174872318107",
             "eu-west-1": "438346466558"
         }[region_name]
+    elif algorithm in ['image_classification']:
+        account_id = {
+            "us-east-1": "811284229777",
+            "us-east-2": "825641698319",
+            "us-west-2": "433757028032",
+            "eu-west-1": "685385470294"
+        }[region_name]
     elif algorithm in ["lda"]:
         account_id = {
             "us-east-1": "766337827248",
```
```diff
@@ -244,4 +341,4 @@ def registry(region_name, algorithm=None):
         }[region_name]
     else:
         raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
-    return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
+    return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
```
src/sagemaker/amazon/common.py
```diff
@@ -16,7 +16,7 @@
 
 import numpy as np
 from scipy.sparse import issparse
 
+import json
```
Reviewer: Please maintain the import order: 1. Python built-in libraries …
Reviewer: This still needs to be fixed. import json should go before the numpy import. Also, please maintain the alphabetical order of the imports when you change it: import io …
```diff
 
 from sagemaker.amazon.record_pb2 import Record
```
```diff
@@ -35,6 +35,17 @@ def __call__(self, array):
         return buf
 
 
+class file_to_image_serializer(object):
```
Reviewer: Keep one naming convention: FileToImageSerializer.
Author: I am using this because the other classes are also in this convention; refer to numpy_to_record_serializer.
```diff
+
+    def __init__(self, content_type='application/x-image'):
+        self.content_type = content_type
+
+    def __call__(self, file):
+        with open(file, 'rb') as f:
+            payload = f.read()
+            payload = bytearray(payload)
+        return payload
```
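A short sketch of the serializer on its own (the file path is a placeholder); the instance is callable and returns the raw image bytes as a `bytearray`:

```python
serializer = file_to_image_serializer()   # content_type defaults to 'application/x-image'
payload = serializer('example.jpg')       # placeholder path
assert isinstance(payload, bytearray)
```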
```diff
 class record_deserializer(object):
```
Reviewer: Same here: RecordDeserializer.
Author: Again, I am maintaining this because of the other classes; refer …
```diff
 
     def __init__(self, accept='application/x-recordio-protobuf'):

@@ -47,6 +58,15 @@ def __call__(self, stream, content_type):
         stream.close()
 
 
+class response_deserializer(object):
```
Reviewer: ResponseDeserializer.
```diff
+
+    def __init__(self, accept='application/json'):
+        self.accept = accept
+
+    def __call__(self, stream, content_type=None):
+        return json.loads(stream)
+
+
 def _write_feature_tensor(resolved_type, record, vector):
     if resolved_type == "Int32":
         record.features["values"].int32_tensor.values.extend(vector)
```
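And a sketch of `response_deserializer` in isolation; as written it hands whatever it receives straight to `json.loads`, so a plain JSON string works (the sample body is made up):

```python
deserializer = response_deserializer()    # accept defaults to 'application/json'
result = deserializer('{"predictions": [0.1, 0.9]}', content_type='application/json')
print(result['predictions'])              # [0.1, 0.9]
```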
Reviewer: Please look at #54; we intentionally removed these type validations (isint(), isbool(), etc.) in favor of declaring a specific type for the hp. So this should be hp('mini_batch_size', validation.gt(0), data_type=int). This applies to every hp declaration in this PR.
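For reference, the reviewer's suggested pattern applied to the `mini_batch_size` declaration above (imports as used in amazon_estimator.py; this is the suggestion, not code from the diff):

```python
from sagemaker.amazon import validation
from sagemaker.amazon.hyperparameter import Hyperparameter as hp

# Declare the expected type on the hp instead of validating with isint().
mini_batch_size = hp('mini_batch_size', validation.gt(0), data_type=int)
```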