|
| 1 | +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +import pytest |
| 14 | +from mock import Mock |
| 15 | +from sagemaker.tensorflow import TensorFlow |
| 16 | + |
| 17 | + |
| 18 | +SCRIPT = 'resnet_cifar_10.py' |
| 19 | +TIMESTAMP = '2017-11-06-14:14:15.673' |
| 20 | +TIME = 1510006209.073025 |
| 21 | +BUCKET_NAME = 'mybucket' |
| 22 | +INSTANCE_COUNT = 1 |
| 23 | +INSTANCE_TYPE_GPU = 'ml.p2.xlarge' |
| 24 | +INSTANCE_TYPE_CPU = 'ml.m4.xlarge' |
| 25 | +CPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-cpu' |
| 26 | +GPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-gpu' |
| 27 | +REGION = 'us-west-2' |
| 28 | +IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" |
| 29 | +REGION = 'us-west-2' |
| 30 | +ROLE = 'SagemakerRole' |
| 31 | +SOURCE_DIR = 's3://fefergerger' |
| 32 | + |
| 33 | + |
| 34 | +@pytest.fixture() |
| 35 | +def sagemaker_session(): |
| 36 | + boto_mock = Mock(name='boto_session', region_name=REGION) |
| 37 | + ims = Mock(name='sagemaker_session', boto_session=boto_mock) |
| 38 | + ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) |
| 39 | + ims.expand_role = Mock(name="expand_role", return_value=ROLE) |
| 40 | + ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts': |
| 41 | + {'S3ModelArtifacts': 's3://m/m.tar.gz'}}) |
| 42 | + return ims |
| 43 | + |
| 44 | + |
| 45 | +# Test that we pass all necessary fields from estimator to the session when we call deploy |
| 46 | +def test_deploy(sagemaker_session, tf_version): |
| 47 | + estimator = TensorFlow(entry_point=SCRIPT, source_dir=SOURCE_DIR, role=ROLE, |
| 48 | + framework_version=tf_version, |
| 49 | + train_instance_count=2, train_instance_type=INSTANCE_TYPE_CPU, |
| 50 | + sagemaker_session=sagemaker_session, |
| 51 | + base_job_name='test-cifar') |
| 52 | + |
| 53 | + estimator.fit('s3://mybucket/train') |
| 54 | + print('job succeeded: {}'.format(estimator.latest_training_job.name)) |
| 55 | + |
| 56 | + estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE_CPU) |
| 57 | + image = IMAGE_URI_FORMAT_STRING.format(REGION, GPU_IMAGE_NAME, tf_version, 'cpu', 'py2') |
| 58 | + sagemaker_session.create_model.assert_called_with( |
| 59 | + estimator._current_job_name, |
| 60 | + ROLE, |
| 61 | + {'Environment': |
| 62 | + {'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', |
| 63 | + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', |
| 64 | + 'SAGEMAKER_SUBMIT_DIRECTORY': SOURCE_DIR, |
| 65 | + 'SAGEMAKER_REGION': REGION, |
| 66 | + 'SAGEMAKER_PROGRAM': SCRIPT}, |
| 67 | + 'Image': image, |
| 68 | + 'ModelDataUrl': 's3://m/m.tar.gz'}) |
0 commit comments