S3 Estimator and Image Classification #71
@@ -20,5 +20,4 @@ examples/tensorflow/distributed_mnist/data
doc/_build
**/.DS_Store
venv/
*~
.pytest_cache/
*.rec
@@ -11,17 +11,16 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import io
import json
import struct
import sys

import numpy as np
from scipy.sparse import issparse

from sagemaker.amazon.record_pb2 import Record


class numpy_to_record_serializer(object):

    def __init__(self, content_type='application/x-recordio-protobuf'):
        self.content_type = content_type
@@ -35,8 +34,18 @@ def __call__(self, array):
        return buf


class file_to_image_serializer(object):

    def __init__(self, content_type='application/x-image'):
        self.content_type = content_type

    def __call__(self, file):
        with open(file, 'rb') as f:
            payload = f.read()
            payload = bytearray(payload)
        return payload


class record_deserializer(object):

    def __init__(self, accept='application/x-recordio-protobuf'):
        self.accept = accept

Review comment on class file_to_image_serializer: Keep one naming convention: FileToImageSerializer.

Author reply: I am using this convention because the other classes are also named this way; refer to numpy_to_record_serializer.

Review comment on class record_deserializer: Same here: RecordDeserializer.

Author reply: Again, I am maintaining this because of the other methods... refer
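As a quick sanity check of the new serializer, here is a minimal self-contained sketch. The class body is copied from the hunk above; the temporary file and the fake JPEG bytes are invented for illustration and are not part of the PR:

```python
import os
import tempfile


class file_to_image_serializer(object):
    """Reads an image file from disk into a bytearray payload
    (copied from the diff above so this example is self-contained)."""

    def __init__(self, content_type='application/x-image'):
        self.content_type = content_type

    def __call__(self, file):
        with open(file, 'rb') as f:
            payload = f.read()
            payload = bytearray(payload)
        return payload


# Write a few fake JPEG bytes (JPEG SOI marker plus filler) to a temp file.
fd, path = tempfile.mkstemp(suffix='.jpg')
with os.fdopen(fd, 'wb') as f:
    f.write(b'\xff\xd8\xff\xe0 fake image data')

serializer = file_to_image_serializer()
payload = serializer(path)
print(type(payload).__name__)   # bytearray
print(len(payload))             # 20
os.remove(path)
```

Returning a bytearray (rather than a file handle) means the whole image is buffered in memory before being sent as the request body.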
@@ -47,6 +56,14 @@ def __call__(self, stream, content_type):
        stream.close()


class response_deserializer(object):

    def __init__(self, accept='application/json'):
        self.accept = accept

    def __call__(self, stream, content_type=None):
        return json.loads(stream)


def _write_feature_tensor(resolved_type, record, vector):
    if resolved_type == "Int32":
        record.features["values"].int32_tensor.values.extend(vector)

Review comment on class response_deserializer: ResponseDeserializer
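A minimal sketch of how the new response_deserializer behaves, assuming a JSON response body. The class body is copied from the hunk above; the sample payload is invented:

```python
import json


class response_deserializer(object):
    """Parses a JSON response body into Python objects
    (copied from the diff above so this example is self-contained)."""

    def __init__(self, accept='application/json'):
        self.accept = accept

    def __call__(self, stream, content_type=None):
        return json.loads(stream)


# A hypothetical image-classification response: one probability per class.
body = '{"prediction": [0.1, 0.7, 0.2]}'
deserializer = response_deserializer()
result = deserializer(body)
print(result['prediction'])   # [0.1, 0.7, 0.2]
```

Note that, unlike record_deserializer above, this class hands its argument straight to json.loads, so it expects the already-read response body (a str, or bytes on newer Pythons) rather than an open stream, and it never closes anything.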
@@ -94,7 +111,7 @@ def write_numpy_to_dense_tensor(file, array, labels=None):
        raise ValueError("Labels must be a Vector")
    if labels.shape[0] not in array.shape:
        raise ValueError("Label shape {} not compatible with array shape {}".format(
            labels.shape, array.shape))
    resolved_label_type = _resolve_type(labels.dtype)
    resolved_type = _resolve_type(array.dtype)

@@ -122,7 +139,7 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None):
        raise ValueError("Labels must be a Vector")
    if labels.shape[0] not in array.shape:
        raise ValueError("Label shape {} not compatible with array shape {}".format(
            labels.shape, array.shape))

Review comment on the re-indented continuation line: align this as it was before. Also applies to the one above.

    resolved_label_type = _resolve_type(labels.dtype)
    resolved_type = _resolve_type(array.dtype)
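The two hunks above only re-indent the continuation line of the label-shape check; the guard itself can be illustrated with a small standalone helper. check_labels is hypothetical (not a function in the SDK), just a re-creation of the check shown in the diff context:

```python
import numpy as np


def check_labels(array, labels):
    """Re-creation of the guard from the diff above: labels must be a
    1-D vector whose length matches one of the array's dimensions."""
    if labels.ndim != 1:
        raise ValueError("Labels must be a Vector")
    if labels.shape[0] not in array.shape:
        raise ValueError("Label shape {} not compatible with array shape {}".format(
            labels.shape, array.shape))


array = np.zeros((10, 4))
check_labels(array, np.zeros(10))     # OK: 10 matches array.shape[0]

try:
    check_labels(array, np.zeros(7))  # 7 matches neither 10 nor 4
except ValueError as e:
    print(e)
```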
Review comment:

Why don't you make 32 the default value for mini_batch_size in the method signature?

`def fit(self, s3set, mini_batch_size=32, distribution='ShardedByS3Key', **kwargs):`

Then you don't even have to do this whole thing, and you can just set it as `self.mini_batch_size = mini_batch_size`.

Author reply:

Two reasons why: 1. It's a protocol used in the other algorithms. 2. We want to make this a must-supply parameter for the user. If I assume a default and it fails because of a memory error, it becomes a customer error, which is wrong.
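The trade-off in this thread (required argument versus a default of 32) can be sketched as follows. ImageClassificationEstimator, its validation, and the dict returned by fit are all hypothetical; only the shape of the signature comes from the comment above:

```python
class ImageClassificationEstimator(object):   # hypothetical name, for illustration
    def fit(self, s3set, mini_batch_size, distribution='ShardedByS3Key', **kwargs):
        # No default: a batch size that is too large for the instance's
        # memory is an explicit caller decision, never a hidden default.
        if mini_batch_size < 1:
            raise ValueError("mini_batch_size must be a positive integer")
        self.mini_batch_size = mini_batch_size
        return {'records': s3set,
                'mini_batch_size': mini_batch_size,
                'distribution': distribution}


est = ImageClassificationEstimator()
job = est.fit(['s3://bucket/train'], mini_batch_size=32)
print(job['mini_batch_size'])   # 32
```

Omitting mini_batch_size raises a TypeError at the call site, which surfaces the decision to the user instead of silently training with an assumed value.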