diff --git a/.gitignore b/.gitignore index 02fbd0283e..cb6d0d2664 100644 --- a/.gitignore +++ b/.gitignore @@ -20,5 +20,4 @@ examples/tensorflow/distributed_mnist/data doc/_build **/.DS_Store venv/ -*~ -.pytest_cache/ +*.rec \ No newline at end of file diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 93a62c2a72..e84cb12201 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -17,6 +17,8 @@ from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor from sagemaker.amazon.lda import LDA, LDAModel, LDAPredictor from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor +from sagemaker.amazon.image_classification import ImageClassification, ImageClassificationModel +from sagemaker.amazon.image_classification import ImageClassificationPredictor from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor @@ -34,5 +36,6 @@ LinearLearnerModel, LinearLearnerPredictor, LDA, LDAModel, LDAPredictor, FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor, + ImageClassification, ImageClassificationModel, ImageClassificationPredictor, Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session, container_def, s3_input, production_variant, get_execution_role] diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 23896276e7..9dd04d2340 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -110,6 +110,7 @@ def fit(self, records, mini_batch_size=None, **kwargs): records (:class:`~RecordSet`): The records to train this ``Estimator`` on mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a default value will be used. + distribution (str): The S3 data distribution type. """ self.feature_dim = records.feature_dim self.mini_batch_size = mini_batch_size @@ -152,9 +153,98 @@ def record_set(self, train, labels=None, channel="train"): return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) -class RecordSet(object): +class AmazonS3AlgorithmEstimatorBase(EstimatorBase): + """Base class for Amazon first-party Estimator implementations. This class isn't + intended to be instantiated directly. It differs from the base class + because it handles S3 data.""" + + mini_batch_size = hp('mini_batch_size', (validation, validation.gt(0))) + + def __init__(self, role, train_instance_count, train_instance_type, algorithm, **kwargs): + """Initialize an AmazonS3AlgorithmEstimatorBase. + + Args: + algorithm (str): Name of one of the supported algorithms + """ + super(AmazonS3AlgorithmEstimatorBase, self).__init__(role, train_instance_count, train_instance_type, + **kwargs) + self.algorithm = algorithm + + def train_image(self): + return registry(self.sagemaker_session.boto_region_name, algorithm=self.algorithm) + "/" + type(self).repo + + def hyperparameters(self): + return hp.serialize_all(self) + + def fit(self, s3set, mini_batch_size=None, distribution='ShardedByS3Key', **kwargs): + """Fit this Estimator on serialized Record objects, stored in S3. - def __init__(self, s3_data, num_records, feature_dim, s3_data_type='ManifestFile', channel='train'): + ``s3set`` should be a list of instances of :class:`~S3Set`.
This defines a collection of + s3 data files to train this ``Estimator`` on. + + More information on the Amazon Record format is available at: + https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html + + See :meth:`~AmazonS3AlgorithmEstimatorBase.s3_record_set` to construct an ``S3Set`` object + from an S3 location. + + Args: + s3set (list): A list of :class:`~S3Set` items. The channels used to train + this ``Estimator`` depend on the algorithm and the type of input data. + distribution (str): The S3 data distribution type. + mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a + default value will be used. + """ + default_mini_batch_size = 32 + self.mini_batch_size = mini_batch_size or default_mini_batch_size + data = {} + for item in s3set: + data[item.channel] = s3_input(item.s3_location, distribution=item.distribution, + content_type=item.content_type, + s3_data_type=item.s3_data_type) + super(AmazonS3AlgorithmEstimatorBase, self).fit(data, **kwargs) + + def s3_record_set(self, s3_loc, content_type, channel="train"): + """Build a :class:`~S3Set` from an S3 location containing training data. + + Args: + s3_loc (str): An S3 location (bucket or key prefix) where the data is located + channel (str): The SageMaker TrainingJob channel this S3Set should be assigned to. + content_type (str): Content type of the data. + Returns: + S3Set: An S3Set referencing the training data at the given S3 location. + """ + return S3Set(s3_loc, content_type=content_type, channel=channel) + + +class S3Set(object): + def __init__(self, s3_location, content_type=None, s3_data_type='S3Prefix', distribution='FullyReplicated', + channel='train'): + """A collection of training data files stored in S3. + + Args: + s3_location (str): The S3 location of the training data + distribution (str): The S3 data distribution type. + content_type (str): Mandatory content type of the data. + s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_location`` defines + a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_location`` will + be used to train. If 'ManifestFile', then ``s3_location`` defines a single s3 manifest file, listing + each s3 object to train on. + channel (str): The SageMaker Training Job channel this S3Set should be bound to + """ + self.s3_location = s3_location + self.distribution = distribution + self.s3_data_type = s3_data_type + self.channel = channel + self.content_type = content_type + + def __repr__(self): + """Return an unambiguous representation of this S3Set""" + return str((S3Set, self.__dict__)) + + +class RecordSet(object): + def __init__(self, s3_data, num_records=None, feature_dim=None, s3_data_type='ManifestFile', channel='train'): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. Args: @@ -235,6 +325,13 @@ def registry(region_name, algorithm=None): "us-west-2": "174872318107", "eu-west-1": "438346466558" }[region_name] + elif algorithm in ['image_classification']: + account_id = { + "us-east-1": "811284229777", + "us-east-2": "825641698319", + "us-west-2": "433757028032", + "eu-west-1": "685385470294" + }[region_name] elif algorithm in ["lda"]: account_id = { "us-east-1": "766337827248", diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 6b5dc0c68a..d9e35ee4ff 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -11,17 +11,16 @@ # ANY KIND, either express or implied.
See the License for the specific # language governing permissions and limitations under the License. import io +import json import struct import sys import numpy as np from scipy.sparse import issparse - from sagemaker.amazon.record_pb2 import Record class numpy_to_record_serializer(object): - def __init__(self, content_type='application/x-recordio-protobuf'): self.content_type = content_type @@ -35,8 +34,18 @@ def __call__(self, array): return buf -class record_deserializer(object): +class file_to_image_serializer(object): + def __init__(self, content_type='application/x-image'): + self.content_type = content_type + + def __call__(self, file): + with open(file, 'rb') as f: + payload = f.read() + payload = bytearray(payload) + return payload + +class record_deserializer(object): def __init__(self, accept='application/x-recordio-protobuf'): self.accept = accept @@ -47,6 +56,14 @@ def __call__(self, stream, content_type): stream.close() +class response_deserializer(object): + def __init__(self, accept='application/json'): + self.accept = accept + + def __call__(self, stream, content_type=None): + return json.loads(stream) + + def _write_feature_tensor(resolved_type, record, vector): if resolved_type == "Int32": record.features["values"].int32_tensor.values.extend(vector) @@ -94,7 +111,7 @@ def write_numpy_to_dense_tensor(file, array, labels=None): raise ValueError("Labels must be a Vector") if labels.shape[0] not in array.shape: raise ValueError("Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape)) + labels.shape, array.shape)) resolved_label_type = _resolve_type(labels.dtype) resolved_type = _resolve_type(array.dtype) @@ -122,7 +139,7 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None): raise ValueError("Labels must be a Vector") if labels.shape[0] not in array.shape: raise ValueError("Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape)) + labels.shape, array.shape)) resolved_label_type = _resolve_type(labels.dtype) resolved_type = _resolve_type(array.dtype) diff --git a/src/sagemaker/amazon/image_classification.py b/src/sagemaker/amazon/image_classification.py new file mode 100644 index 0000000000..0535bba11c --- /dev/null +++ b/src/sagemaker/amazon/image_classification.py @@ -0,0 +1,231 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from sagemaker.amazon.amazon_estimator import AmazonS3AlgorithmEstimatorBase, registry +from sagemaker.amazon.common import file_to_image_serializer, response_deserializer +from sagemaker.amazon.validation import gt, isin, ge, le +from sagemaker.amazon.hyperparameter import Hyperparameter as hp +from sagemaker.model import Model +from sagemaker.predictor import RealTimePredictor +from sagemaker.session import Session + + +class ImageClassification(AmazonS3AlgorithmEstimatorBase): + repo = 'image-classification:latest' + + num_classes = hp('num_classes', (gt(1)), 'num_classes should be an integer greater-than 1', int) + num_training_samples = hp('num_training_samples', (gt(1)), + 'num_training_samples should be an integer greater-than 1', int) + use_pretrained_model = hp('use_pretrained_model', (isin(0, 1),), + 'use_pretrained_model should be in the set, [0,1]', int) + checkpoint_frequency = hp('checkpoint_frequency', (ge(1),), + 'checkpoint_frequency should be an integer greater-than 1', int) + num_layers = hp('num_layers', (isin(18, 34, 50, 101, 152, 200, 20, 32, 44, 56, 110),), + 'num_layers should be in the set [18, 34, 50, 101, 152, 200, 20, 32, 44, 56, 110]', int) + resize = hp('resize', (gt(1)), 'resize should be an integer greater-than 1', int) + epochs = hp('epochs', (ge(1)), 'epochs should be an integer greater-than 1', int) + learning_rate = hp('learning_rate', (gt(0)), 'learning_rate should be a floating point greater than 0', float) + lr_scheduler_factor = hp('lr_scheduler_factor', (gt(0)), + 'lr_schedule_factor should be a floating point greater than 0', float) + lr_scheduler_step = hp('lr_scheduler_step', (), 'lr_scheduler_step should be a string input.', str) + optimizer = hp('optimizer', (isin('sgd', 'adam', 'rmsprop', 'nag')), + 'Should be one optimizer among the list sgd, adam, rmsprop, or nag.', str) + momentum = hp('momentum', (ge(0), le(1)), 'momentum is expected in the range 0, 1', float) + weight_decay = hp('weight_decay', (ge(0), le(1)), 'weight_decay in range 0 , 1 ', float) + beta_1 = hp('beta_1', (ge(0), le(1)), 'beta_1 should be in range 0, 1', float) + beta_2 = hp('beta_2', (ge(0), le(1)), 'beta_2 should be in the range 0, 1', float) + eps = hp('eps', (gt(0), le(1)), 'eps should be in the range 0, 1', float) + gamma = hp('gamma', (ge(0), le(1)), 'gamma should be in the range 0, 1', float) + mini_batch_size = hp('mini_batch_size', (gt(0)), 'mini_batch_size should be an integer greater than 0', int) + image_shape = hp('image_shape', (), 'image_shape is expected to be a string', str) + augmentation_type = hp('augmentation_type', (isin('crop', 'crop_color', 'crop_color_transform')), + 'augmentation type must be from one option offered', str) + top_k = hp('top_k', (ge(1)), 'top_k should be greater than or equal to 1', int) + kv_store = hp('kv_store', (isin('dist_sync', 'dist_async')), 'Can be dist_sync or dist_async', str) + + def __init__(self, role, train_instance_count, train_instance_type, num_classes, num_training_samples, resize=None, + lr_scheduler_step=None, use_pretrained_model=0, checkpoint_frequency=1, num_layers=18, + epochs=30, learning_rate=0.1, + lr_schedule_factor=0.1, optimizer='sgd', momentum=0., weight_decay=0.0001, beta_1=0.9, + beta_2=0.999, eps=1e-8, gamma=0.9, mini_batch_size=32, image_shape='3,224,224', + augmentation_type=None, top_k=None, kv_store=None, **kwargs): + """ + An Image classification algorithm :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. 
+ + This Estimator may be fit via calls to + :meth:`~sagemaker.amazon.amazon_estimator.AmazonS3AlgorithmEstimatorBase.fit` + + After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker + Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint, + ``deploy`` returns a :class:`~sagemaker.amazon.image_classification.ImageClassificationPredictor` object that can be used to + perform label assignment, using the trained model hosted in the SageMaker Endpoint. + + ImageClassification Estimators can be configured by setting hyperparameters. The available hyperparameters for + ImageClassification are documented below. For further information on the AWS ImageClassification algorithm, + please consult AWS technical documentation: + https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and + APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. After the endpoint is created, + the inference code might use the IAM role, if accessing AWS resources. + train_instance_count (int): Number of Amazon EC2 instances to use for training. + train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. + num_classes (int): Number of output classes. This parameter defines the dimensions of the network output + and is typically set to the number of classes in the dataset. + num_training_samples (int): Number of training examples in the input dataset. If there is a + mismatch between this value and the number of samples in the training + set, then the behavior of the lr_scheduler_step parameter is undefined + and distributed training accuracy might be affected. + use_pretrained_model (int): Flag to indicate whether to use a pre-trained model for training. + If set to `1`, then the pretrained model with the corresponding number + of layers is loaded and used for training. Only the top fully connected (FC) layer is + reinitialized with random weights. Otherwise, the network is trained from scratch. + Default value: 0 + checkpoint_frequency (int): Period to store model parameters (in number of epochs). Default value: 1 + num_layers (int): Number of layers for the network. For data with large image size (for example, 224x224 - + like ImageNet), we suggest selecting the number of layers from the set [18, 34, 50, 101, + 152, 200]. For data with small image size (for example, 28x28 - like CIFAR), we suggest + selecting the number of layers from the set [20, 32, 44, 56, 110]. The number of layers + in each set is based on the ResNet paper. For transfer learning, the number of layers + defines the architecture of the base network and hence can only be selected from the set + [18, 34, 50, 101, 152, 200]. Default value: 152 + resize (int): Resize the image before using it for training. The images are resized so that the shortest + side has the length given by this parameter. If the parameter is not set, then the training data is used as such + without resizing. + Note: This option is available only for inputs specified as application/x-image content-type + in training and validation channels. + epochs (int): Number of training epochs. Default value: 30 + learning_rate (float): Initial learning rate. Float. Range in [0, 1].
Default value: 0.1 + lr_scheduler_factor (float): The ratio by which to reduce the learning rate, used in conjunction with the + `lr_scheduler_step` parameter, defined as `lr_new=lr_old * lr_scheduler_factor`. + Valid values: Float. Range in [0, 1]. Default value: 0.1 + lr_scheduler_step (str): The epochs at which to reduce the learning rate. As explained in the + ``lr_scheduler_factor`` parameter, the learning rate is reduced by + ``lr_scheduler_factor`` at these epochs. For example, if the value is set + to "10, 20", then the learning rate is reduced by ``lr_scheduler_factor`` after the 10th + epoch and again by ``lr_scheduler_factor`` after the 20th epoch. The epochs are delimited + by ",". + optimizer (str): The optimizer type. For more details of the parameters for the optimizers, please refer to + MXNet's API. Valid values: One of sgd, adam, rmsprop, or nag. Default value: `sgd`. + momentum (float): The momentum for sgd and nag, ignored for other optimizers. Valid values: Float. Range in + [0, 1]. Default value: 0 + weight_decay (float): The weight decay coefficient for sgd and nag, ignored for other optimizers. + Range in [0, 1]. Default value: 0.0001 + beta_1 (float): The beta1 for adam, in other words, the exponential decay rate for the first moment estimates. + Range in [0, 1]. Default value: 0.9 + beta_2 (float): The beta2 for adam, in other words, the exponential decay rate for the second moment estimates. + Range in [0, 1]. Default value: 0.999 + eps (float): The epsilon for adam and rmsprop. It is usually set to a small value to avoid division by 0. + Range in [0, 1]. Default value: 1e-8 + gamma (float): The gamma for rmsprop. A decay factor for the moving average of the squared gradient. + Range in [0, 1]. Default value: 0.9 + mini_batch_size (int): The batch size for training. In a single-machine multi-GPU setting, each GPU handles + mini_batch_size/num_gpu training samples. For multi-machine training in + dist_sync mode, the actual batch size is mini_batch_size*number of machines. + See MXNet docs for more details. Default value: 32 + image_shape (str): The input image dimensions, which are the same size as the input layer of the network. + The format is defined as 'num_channels, height, width'. The image dimension can take on + any value as the network can handle varied dimensions of the input. However, there may + be memory constraints if a larger image dimension is used. Typical image dimensions for + image classification are '3, 224, 224'. This is similar to the ImageNet dataset. + Default value: '3, 224, 224' + augmentation_type (str): Data augmentation type. The input images can be augmented in multiple ways as + specified below. + 'crop' - Randomly crop the image and flip the image horizontally + 'crop_color' - In addition to 'crop', three random values in the range [-36, 36], + [-50, 50], and [-50, 50] + are added to the corresponding Hue-Saturation-Lightness channels respectively. + 'crop_color_transform' - In addition to 'crop_color', random transformations, including + rotation, shear, and aspect ratio variations, are applied to the image. + The maximum angle of rotation is 10 degrees, the maximum shear ratio is 0.1, + and the maximum aspect changing ratio is 0.25. + top_k (int): Report the top-k accuracy during training. This parameter has to be greater than 1, + since the top-1 training accuracy is the same as the regular training accuracy that has + already been reported. + kv_store (str): Weight update synchronization mode during distributed training.
The weights can be + updated either synchronously or asynchronously across machines. Synchronous updates + typically provide better accuracy than asynchronous updates but can be slower. + See distributed training in MXNet for more details. This parameter is not applicable + to single machine training. + 'dist_sync' - The gradients are synchronized after every batch with all the workers. + With dist_sync, + batch-size now means the batch size used on each machine. So if there are n + machines and we use + batch size b, then dist_sync behaves like local with batch size n*b. + 'dist_async' - Performs asynchronous updates. The weights are updated whenever gradients + are received from any machine and the weight updates are atomic. However, the + order is not guaranteed. + **kwargs: base class keyword argument values. + """ + super(ImageClassification, self).__init__(role, train_instance_count, train_instance_type, + algorithm='image_classification', **kwargs) + self.num_classes = num_classes + self.num_training_samples = num_training_samples + self.resize = resize + self.lr_scheduler_step = lr_scheduler_step + self.use_pretrained_model = use_pretrained_model + self.checkpoint_frequency = checkpoint_frequency + self.num_layers = num_layers + self.epochs = epochs + self.learning_rate = learning_rate + self.lr_schedule_factor = lr_schedule_factor + self.optimizer = optimizer + self.momentum = momentum + self.weight_decay = weight_decay + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.eps = eps + self.gamma = gamma + self.mini_batch_size = mini_batch_size + self.image_shape = image_shape + self.augmentation_type = augmentation_type + self.top_k = top_k + self.kv_store = kv_store + + def create_model(self): + """Return a :class:`~sagemaker.amazon.image_classification.ImageClassificationModel` referencing the latest + s3 model data produced by this Estimator.""" + return ImageClassificationModel(self.model_data, self.role, self.sagemaker_session) + + def hyperparameters(self): + """Return the SageMaker hyperparameters for training this ImageClassification Estimator""" + hp = dict() + hp.update(super(ImageClassification, self).hyperparameters()) + return hp + + +class ImageClassificationPredictor(RealTimePredictor): + """Classifies input images using a trained ImageClassification model. + + The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this + `RealTimePredictor` requires image data (content type `application/x-image`) as input.""" + + def __init__(self, endpoint, sagemaker_session=None): + super(ImageClassificationPredictor, self).__init__(endpoint, sagemaker_session, + serializer=file_to_image_serializer(), + deserializer=response_deserializer(), + content_type='application/x-image') + + +class ImageClassificationModel(Model): + """Reference ImageClassification s3 model data.
Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and + returns a Predictor that performs classification of input images.""" + + def __init__(self, model_data, role, sagemaker_session=None): + sagemaker_session = sagemaker_session or Session() + image = registry(sagemaker_session.boto_session.region_name, + 'image_classification') + "/" + ImageClassification.repo + super(ImageClassificationModel, self).__init__(model_data, image, role, + predictor_cls=ImageClassificationPredictor, + sagemaker_session=sagemaker_session) diff --git a/src/sagemaker/amazon/validation.py b/src/sagemaker/amazon/validation.py index 7c7fa4f2a0..ed9c722291 100644 --- a/src/sagemaker/amazon/validation.py +++ b/src/sagemaker/amazon/validation.py @@ -15,34 +15,40 @@ def gt(minimum): def validate(value): return value > minimum + return validate def ge(minimum): def validate(value): return value >= minimum + return validate def lt(maximum): def validate(value): return value < maximum + return validate def le(maximum): def validate(value): return value <= maximum + return validate def isin(*expected): def validate(value): return value in expected + return validate def istype(expected): def validate(value): return isinstance(value, expected) + return validate diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 2ec9669c20..ff78a9fa62 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -13,3 +13,5 @@ CONTENT_TYPE_JSON = 'application/json' CONTENT_TYPE_CSV = 'text/csv' CONTENT_TYPE_OCTET_STREAM = 'application/octet-stream' +CONTENT_TYPE_IMAGES = 'application/x-image' +CONTENT_TYPE_RECORDIO = 'application/x-recordio' diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c672703315..10109cf62c 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -309,7 +309,6 @@ def start_new(cls, estimator, inputs): Returns: sagemaker.estimator.Framework: Constructed object that captures all information about the started job. """ - input_config = _TrainingJob._format_inputs_to_input_config(inputs) role = estimator.sagemaker_session.expand_role(estimator.role) output_config = _TrainingJob._prepare_output_config(estimator.output_path, estimator.output_kms_key) diff --git a/tests/integ/test_image_classification.py b/tests/integ/test_image_classification.py new file mode 100644 index 0000000000..b71f8b74a3 --- /dev/null +++ b/tests/integ/test_image_classification.py @@ -0,0 +1,65 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+ +import boto3 +import os + +import sagemaker +from sagemaker import ImageClassification, ImageClassificationModel +from sagemaker.utils import name_from_base +from tests.integ import REGION +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name +from six.moves.urllib.request import urlretrieve + + +def download(url): + filename = url.split("/")[-1] + if not os.path.exists(filename): + urlretrieve(url, filename) + + +def upload_to_s3(channel, file, bucket): + s3 = boto3.resource('s3') + data = open(file, "rb") + key = channel + '/' + file + s3.Bucket(bucket).put_object(Key=key, Body=data) + + +def test_image_classification(): + with timeout(minutes=45): + sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION)) + + # caltech-256 + download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec') + upload_to_s3('train', 'caltech-256-60-train.rec', sagemaker_session.default_bucket()) + download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec') + upload_to_s3('validation', 'caltech-256-60-val.rec', sagemaker_session.default_bucket()) + ic = ImageClassification(role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.p3.2xlarge', num_layers=18, + num_classes=257, num_training_samples=15420, epochs=1, image_shape='3,32,32', + sagemaker_session=sagemaker_session, base_job_name='test-ic') + + ic.epochs = 1 + data_location = 's3://' + sagemaker_session.default_bucket() + s3set = list() + s3set.append(ic.s3_record_set(data_location + '/validation/', channel='validation', + content_type='application/x-recordio')) + s3set.append(ic.s3_record_set(data_location + '/train/', channel='train', + content_type='application/x-recordio')) + ic.fit(s3set) + + endpoint_name = name_from_base('ic') + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + model = ImageClassificationModel(ic.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) + predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name) + assert predictor is not None diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index caf18b01e7..6058cf030c 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -16,6 +16,7 @@ # Use PCA as a test implementation of AmazonAlgorithmEstimator from sagemaker.amazon.pca import PCA +from sagemaker.amazon.image_classification import ImageClassification from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'} @@ -63,6 +64,13 @@ def test_init(sagemaker_session): assert pca.num_components == 55 +def test_s3_init(sagemaker_session): + ic = ImageClassification(epochs=12, num_classes=2, num_training_samples=2, + sagemaker_session=sagemaker_session, **COMMON_ARGS) + assert ic.epochs == 12 + assert ic.num_classes == 2 + + def test_init_all_pca_hyperparameters(sagemaker_session): pca = PCA(num_components=55, algorithm_mode='randomized', subtract_mean=True, extra_components=33, sagemaker_session=sagemaker_session, @@ -72,6 +80,16 @@ def test_init_all_pca_hyperparameters(sagemaker_session): assert pca.extra_components == 33 +def test_init_all_ic_hyperparameters(sagemaker_session): + ic = ImageClassification( + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert ic.num_classes 
== 257 + assert ic.num_training_samples == 15420 + assert ic.image_shape == '3,32,32' + + def test_init_estimator_args(sagemaker_session): pca = PCA(num_components=1, train_max_run=1234, sagemaker_session=sagemaker_session, data_location='s3://some-bucket/some-key/', **COMMON_ARGS) @@ -82,6 +100,16 @@ def test_init_estimator_args(sagemaker_session): assert pca.data_location == 's3://some-bucket/some-key/' +def test_init_s3estimator_args(sagemaker_session): + ic = ImageClassification( + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert ic.train_instance_type == COMMON_ARGS['train_instance_type'] + assert ic.train_instance_count == COMMON_ARGS['train_instance_count'] + assert ic.role == COMMON_ARGS['role'] + + def test_data_location_validation(sagemaker_session): pca = PCA(num_components=2, sagemaker_session=sagemaker_session, **COMMON_ARGS) with pytest.raises(ValueError): @@ -106,9 +134,22 @@ def test_pca_hyperparameters(sagemaker_session): algorithm_mode='randomized') +def test_ic_hyperparameters(sagemaker_session): + ic = ImageClassification( + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert isinstance(ic.hyperparameters(), dict) + + def test_image(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) assert pca.train_image() == registry('us-west-2') + '/pca:1' + ic = ImageClassification( + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert ic.train_image() == registry('us-west-2', 'image_classification') + '/image-classification:latest' @patch('time.strftime', return_value=TIMESTAMP)
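Usage sketch (not part of the diff): the snippet below illustrates how the new ``ImageClassification`` estimator, ``S3Set`` inputs, and ``ImageClassificationModel`` introduced by this change are intended to be used end to end, mirroring tests/integ/test_image_classification.py. The IAM role name, the bucket contents (RecordIO files under train/ and validation/), and the local image file are placeholders and must exist in the target account for the sketch to run.

    import sagemaker
    from sagemaker import ImageClassification, ImageClassificationModel

    session = sagemaker.Session()
    bucket = session.default_bucket()  # assumes a default SageMaker bucket is available

    # Configure the estimator; hyperparameters match the integration test values.
    ic = ImageClassification(role='SageMakerRole',  # placeholder IAM role
                             train_instance_count=1,
                             train_instance_type='ml.p3.2xlarge',
                             num_classes=257, num_training_samples=15420,
                             num_layers=18, epochs=1, image_shape='3,32,32',
                             sagemaker_session=session)

    # Each channel is described by an S3Set built via s3_record_set().
    s3set = [
        ic.s3_record_set('s3://{}/train/'.format(bucket),
                         content_type='application/x-recordio', channel='train'),
        ic.s3_record_set('s3://{}/validation/'.format(bucket),
                         content_type='application/x-recordio', channel='validation'),
    ]
    ic.fit(s3set)

    # Deploy the trained model and classify a local image; the predictor serializes
    # the file as application/x-image and deserializes the JSON response.
    model = ImageClassificationModel(ic.model_data, role='SageMakerRole',
                                     sagemaker_session=session)
    predictor = model.deploy(1, 'ml.c4.xlarge')
    probabilities = predictor.predict('test_image.jpg')  # placeholder local file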