diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index a1d3933180..3d518aa299 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -448,21 +448,27 @@ def _upload_code(self, key_prefix, repack=False):
         local_code = utils.get_config_value('local.local_code', self.sagemaker_session.config)
         if self.sagemaker_session.local_mode and local_code:
             self.uploaded_code = None
-        else:
-            if not repack:
-                bucket = self.bucket or self.sagemaker_session.default_bucket()
-                self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session,
-                                                                 bucket=bucket,
-                                                                 s3_key_prefix=key_prefix,
-                                                                 script=self.entry_point,
-                                                                 directory=self.source_dir,
-                                                                 dependencies=self.dependencies)
+        elif not repack:
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session,
+                                                             bucket=bucket,
+                                                             s3_key_prefix=key_prefix,
+                                                             script=self.entry_point,
+                                                             directory=self.source_dir,
+                                                             dependencies=self.dependencies)
 
         if repack:
-            self.repacked_model_data = utils.repack_model(inference_script=self.entry_point,
-                                                          source_directory=self.source_dir,
-                                                          model_uri=self.model_data,
-                                                          sagemaker_session=self.sagemaker_session)
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz')
+
+            utils.repack_model(inference_script=self.entry_point,
+                               source_directory=self.source_dir,
+                               dependencies=self.dependencies,
+                               model_uri=self.model_data,
+                               repacked_model_uri=repacked_model_data,
+                               sagemaker_session=self.sagemaker_session)
+
+            self.repacked_model_data = repacked_model_data
 
             self.uploaded_code = UploadedCode(s3_prefix=self.repacked_model_data,
                                               script_name=os.path.basename(self.entry_point))
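
# Illustrative note (not part of the patch): the repack branch now derives the
# output location itself instead of letting repack_model() pick a timestamped
# name. For a hypothetical bucket 'my-bucket' and key_prefix 'my-model':
#
#     's3://' + os.path.join('my-bucket', 'my-model', 'model.tar.gz')
#     # -> 's3://my-bucket/my-model/model.tar.gz'
#
# repack_model() then writes the repacked artifact to that URI rather than
# returning a generated 'model-<timestamp>.tar.gz' path of its own.
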
""" - mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION) + is_mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION) deploy_image = self.image if not deploy_image: region_name = self.sagemaker_session.boto_session.region_name framework_name = self.__framework_name__ - if mms_version: + if is_mms_version: framework_name += '-serving' deploy_image = create_image_uri(region_name, framework_name, instance_type, self.framework_version, self.py_version, accelerator_type=accelerator_type) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) - self._upload_code(deploy_key_prefix, mms_version) + self._upload_code(deploy_key_prefix, is_mms_version) deploy_env = dict(self.env) deploy_env.update(self._framework_env_vars()) diff --git a/src/sagemaker/tensorflow/serving.py b/src/sagemaker/tensorflow/serving.py index a680f2df30..7a37318d10 100644 --- a/src/sagemaker/tensorflow/serving.py +++ b/src/sagemaker/tensorflow/serving.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import logging +import os import sagemaker from sagemaker.content_types import CONTENT_TYPE_JSON @@ -128,10 +129,17 @@ def prepare_container_def(self, instance_type, accelerator_type=None): env = self._get_container_env() if self.entry_point: - model_data = sagemaker.utils.repack_model(self.entry_point, - self.source_dir, - self.model_data, - self.sagemaker_session) + key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image) + + bucket = self.bucket or self.sagemaker_session.default_bucket() + model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') + + sagemaker.utils.repack_model(self.entry_point, + self.source_dir, + self.dependencies, + self.model_data, + model_data, + self.sagemaker_session) else: model_data = self.model_data diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d20f3194e0..9d1d139cb3 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -29,8 +29,6 @@ import six -import sagemaker - ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' @@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'): shutil.rmtree(tmp) -def repack_model(inference_script, source_directory, model_uri, sagemaker_session): +def repack_model(inference_script, + source_directory, + dependencies, + model_uri, + repacked_model_uri, + sagemaker_session): """Unpack model tarball and creates a new model tarball with the provided code script. This function does the following: @@ -311,60 +314,91 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio Args: inference_script (str): path or basename of the inference script that will be packed into the model source_directory (str): path including all the files that will be packed into the model + dependencies (list[str]): A list of paths to directories (absolute or relative) with + any additional libraries that will be exported to the container (default: []). + The library folders will be copied to SageMaker in the same folder where the entrypoint is copied. 
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index d20f3194e0..9d1d139cb3 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,8 +29,6 @@
 
 import six
 
-import sagemaker
-
 ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'
 
 
@@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'):
         shutil.rmtree(tmp)
 
 
-def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
+def repack_model(inference_script,
+                 source_directory,
+                 dependencies,
+                 model_uri,
+                 repacked_model_uri,
+                 sagemaker_session):
     """Unpack model tarball and creates a new model tarball with the provided code script.
 
     This function does the following:
@@ -311,60 +314,91 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
     Args:
         inference_script (str): path or basename of the inference script that will be packed into the model
         source_directory (str): path including all the files that will be packed into the model
+        dependencies (list[str]): A list of paths to directories (absolute or relative) with
+            any additional libraries that will be exported to the container (default: []).
+            The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
+
+            Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>>     |------ train.py
+                >>>     |------ common
+                >>>     |------ virtual-env
+
+        repacked_model_uri (str): path or file system location where the new model will be saved
         model_uri (str): S3 or file system location of the original model tar
         sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.
 
     Returns:
         str: path to the new packed model
     """
-    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
+    dependencies = dependencies or []
 
     with _tmpdir() as tmp:
-        tmp_model_dir = os.path.join(tmp, 'model')
-        os.mkdir(tmp_model_dir)
+        model_dir = _extract_model(model_uri, sagemaker_session, tmp)
 
-        model_from_s3 = model_uri.lower().startswith('s3://')
-        if model_from_s3:
-            local_model_path = os.path.join(tmp, 'tar_file')
-            download_file_from_url(model_uri, local_model_path, sagemaker_session)
+        _create_or_update_code_dir(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp)
 
-            new_model_path = os.path.join(tmp, new_model_name)
-        else:
-            local_model_path = model_uri.replace('file://', '')
-            new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
+        tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz')
+        with tarfile.open(tmp_model_path, mode='w:gz') as t:
+            t.add(model_dir, arcname=os.path.sep)
 
-        with tarfile.open(name=local_model_path, mode='r:gz') as t:
-            t.extractall(path=tmp_model_dir)
+        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session)
 
-        code_dir = os.path.join(tmp_model_dir, 'code')
-        if os.path.exists(code_dir):
-            shutil.rmtree(code_dir, ignore_errors=True)
 
-        if source_directory and source_directory.lower().startswith('s3://'):
-            local_code_path = os.path.join(tmp, 'local_code.tar.gz')
-            download_file_from_url(source_directory, local_code_path, sagemaker_session)
+def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
+    if repacked_model_uri.lower().startswith('s3://'):
+        url = parse.urlparse(repacked_model_uri)
+        bucket, key = url.netloc, url.path.lstrip('/')
+        new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
 
-            with tarfile.open(name=local_code_path, mode='r:gz') as t:
-                t.extractall(path=code_dir)
+        sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(
+            tmp_model_path)
+    else:
+        shutil.move(tmp_model_path, repacked_model_uri.replace('file://', ''))
 
-        elif source_directory:
-            shutil.copytree(source_directory, code_dir)
-        else:
-            os.mkdir(code_dir)
-            shutil.copy2(inference_script, code_dir)
 
-        with tarfile.open(new_model_path, mode='w:gz') as t:
-            t.add(tmp_model_dir, arcname=os.path.sep)
+def _create_or_update_code_dir(model_dir, inference_script, source_directory,
+                               dependencies, sagemaker_session, tmp):
+    code_dir = os.path.join(model_dir, 'code')
+    if os.path.exists(code_dir):
+        shutil.rmtree(code_dir, ignore_errors=True)
+    if source_directory and source_directory.lower().startswith('s3://'):
+        local_code_path = os.path.join(tmp, 'local_code.tar.gz')
+        download_file_from_url(source_directory, local_code_path, sagemaker_session)
+
+        with tarfile.open(name=local_code_path, mode='r:gz') as t:
+            t.extractall(path=code_dir)
 
-        if model_from_s3:
-            url = parse.urlparse(model_uri)
-            bucket, key = url.netloc, url.path.lstrip('/')
-            new_key = key.replace(os.path.basename(key), new_model_name)
+    elif source_directory:
+        shutil.copytree(source_directory, code_dir)
+    else:
+        os.mkdir(code_dir)
+        shutil.copy2(inference_script, code_dir)
 
-            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path)
-            return 's3://%s/%s' % (bucket, new_key)
+    for dependency in dependencies:
+        if os.path.isdir(dependency):
+            shutil.copytree(dependency, code_dir)
         else:
-            return 'file://%s' % new_model_path
+            shutil.copy2(dependency, code_dir)
+
+
+def _extract_model(model_uri, sagemaker_session, tmp):
+    tmp_model_dir = os.path.join(tmp, 'model')
+    os.mkdir(tmp_model_dir)
+    if model_uri.lower().startswith('s3://'):
+        local_model_path = os.path.join(tmp, 'tar_file')
+        download_file_from_url(model_uri, local_model_path, sagemaker_session)
+    else:
+        local_model_path = model_uri.replace('file://', '')
+    with tarfile.open(name=local_model_path, mode='r:gz') as t:
+        t.extractall(path=tmp_model_dir)
+    return tmp_model_dir
 
 
 def download_file_from_url(url, dst, sagemaker_session):
diff --git a/tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py
new file mode 100644
index 0000000000..c60b935b80
--- /dev/null
+++ b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py
@@ -0,0 +1,12 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py
new file mode 100644
index 0000000000..2fe2eb3327
--- /dev/null
+++ b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py
@@ -0,0 +1,27 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import json
+
+import dependency
+
+def input_handler(data, context):
+    data = json.loads(data.read().decode('utf-8'))
+    new_values = [x + 1 for x in data['instances']]
+    dumps = json.dumps({'instances': new_values})
+    return dumps
+
+
+def output_handler(data, context):
+    response_content_type = context.accept_header
+    prediction = data.content
+    return prediction, response_content_type
diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py
index 507d0c44f3..2f691fea1d 100644
--- a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py
+++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py
@@ -1,4 +1,4 @@
-# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 import json
 
-
 def input_handler(data, context):
     data = json.loads(data.read().decode('utf-8'))
     new_values = [x + 1 for x in data['instances']]
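
# Call sketch for the new repack_model signature (argument values are
# illustrative):
#
#     repack_model(inference_script='inference.py',
#                  source_directory='src',                    # may be None
#                  dependencies=['lib/common'],               # new argument
#                  model_uri='s3://bucket/model.tar.gz',
#                  repacked_model_uri='s3://bucket/prefix/model.tar.gz',  # new argument
#                  sagemaker_session=session)
#
# The repacked tarball keeps the original model contents at the root and gains
# a code/ directory holding the inference script, the source directory files,
# and a copy of each dependency; the function no longer returns a generated URI.
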
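
# Why the new test data matters (illustrative): inference.py does
# 'import dependency', and both files are repacked into the same code/
# directory, so the helper module sits right next to the entry point in the
# deployed archive, roughly:
#
#     model.tar.gz
#     |-- <original model files>
#     |-- code/
#         |-- inference.py
#         |-- dependency.py
#
# which is what the integration test below relies on when serving predictions
# through an entry point that imports its own helper module.
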
diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py
index 05e0725d5c..ab43b1368c 100644
--- a/tests/integ/test_tfs.py
+++ b/tests/integ/test_tfs.py
@@ -84,8 +84,8 @@ def tfs_predictor_with_model_and_entry_point_same_tar(instance_type,
 
 
 @pytest.fixture(scope='module')
-def tfs_predictor_with_model_and_entry_point_separated(instance_type,
-                                                       sagemaker_session, tf_full_version):
+def tfs_predictor_with_model_and_entry_point_and_dependencies(instance_type,
+                                                              sagemaker_session, tf_full_version):
     endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
 
     model_data = sagemaker_session.upload_data(
@@ -96,10 +96,14 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type,
     with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                  sagemaker_session):
         entry_point = os.path.join(tests.integ.DATA_DIR,
-                                   'tfs/tfs-test-model-with-inference/code/inference.py')
+                                   'tfs/tfs-test-entrypoint-and-dependencies/inference.py')
+        dependencies = [os.path.join(tests.integ.DATA_DIR,
+                                     'tfs/tfs-test-entrypoint-and-dependencies/dependency.py')]
+
         model = Model(entry_point=entry_point,
                       model_data=model_data,
                       role='SageMakerRole',
+                      dependencies=dependencies,
                       framework_version=tf_full_version,
                       sagemaker_session=sagemaker_session)
 
         predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
@@ -152,12 +156,12 @@ def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_
     assert expected_result == result
 
 
-def test_predict_with_model_and_entry_point_separated(
-        tfs_predictor_with_model_and_entry_point_separated):
+def test_predict_with_model_and_entry_point_and_dependencies_separated(
+        tfs_predictor_with_model_and_entry_point_and_dependencies):
     input_data = {'instances': [1.0, 2.0, 5.0]}
     expected_result = {'predictions': [4.0, 4.5, 6.0]}
 
-    result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data)
+    result = tfs_predictor_with_model_and_entry_point_and_dependencies.predict(input_data)
     assert expected_result == result
diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py
index dfb298d47b..0f45afe51f 100644
--- a/tests/unit/test_mxnet.py
+++ b/tests/unit/test_mxnet.py
@@ -29,7 +29,6 @@
 DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
 SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py')
 MODEL_DATA = 's3://mybucket/model'
-REPACKED_MODEL_DATA = 's3://mybucket/repacked/model'
 TIMESTAMP = '2017-11-06-14:14:15.672'
 TIME = 1507167947
 BUCKET_NAME = 'mybucket'
@@ -280,7 +279,7 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version, skip_if_mms_version):
     assert isinstance(predictor, MXNetPredictor)
 
 
-@patch('sagemaker.utils.repack_model', return_value=REPACKED_MODEL_DATA)
+@patch('sagemaker.utils.repack_model')
 @patch('time.strftime', return_value=TIMESTAMP)
 def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version):
     mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
@@ -307,11 +306,12 @@ def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_vers
     expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, 'gpu')
     environment = {
         'Environment': {
-            'SAGEMAKER_SUBMIT_DIRECTORY': REPACKED_MODEL_DATA,
+            'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz',
             'SAGEMAKER_PROGRAM': 'dummy_script.py',
             'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
             'SAGEMAKER_REGION': 'us-west-2',
             'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
         },
-        'Image': expected_image_base.format(mxnet_version), 'ModelDataUrl': REPACKED_MODEL_DATA
+        'Image': expected_image_base.format(mxnet_version),
+        'ModelDataUrl': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz'
     }
     assert environment == model.prepare_container_def(GPU)
 
@@ -366,21 +366,23 @@ def test_model(sagemaker_session):
     assert isinstance(predictor, MXNetPredictor)
 
 
-@patch('sagemaker.utils.repack_model', return_value=REPACKED_MODEL_DATA)
+@patch('sagemaker.utils.repack_model')
 def test_model_mms_version(repack_model, sagemaker_session):
     model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH,
                        framework_version=MXNetModel._LOWEST_MMS_VERSION,
-                       sagemaker_session=sagemaker_session)
+                       sagemaker_session=sagemaker_session, name='test-mxnet-model')
     predictor = model.deploy(1, GPU)
 
     repack_model.assert_called_once_with(inference_script=SCRIPT_PATH,
                                          source_directory=None,
+                                         dependencies=[],
                                          model_uri=MODEL_DATA,
+                                         repacked_model_uri='s3://mybucket/test-mxnet-model/model.tar.gz',
                                          sagemaker_session=sagemaker_session)
     assert model.model_data == MODEL_DATA
-    assert model.repacked_model_data == REPACKED_MODEL_DATA
-    assert model.uploaded_code == UploadedCode(s3_prefix=REPACKED_MODEL_DATA,
+    assert model.repacked_model_data == 's3://mybucket/test-mxnet-model/model.tar.gz'
+    assert model.uploaded_code == UploadedCode(s3_prefix='s3://mybucket/test-mxnet-model/model.tar.gz',
                                                script_name=os.path.basename(SCRIPT_PATH))
     assert isinstance(predictor, MXNetPredictor)
diff --git a/tests/unit/test_tfs.py b/tests/unit/test_tfs.py
index 5bcdbfba8b..d2d59e0c2d 100644
--- a/tests/unit/test_tfs.py
+++ b/tests/unit/test_tfs.py
@@ -15,6 +15,8 @@
 import io
 import json
 import logging
+
+import mock
 import pytest
 from mock import Mock
 from sagemaker.tensorflow import TensorFlow
@@ -102,6 +104,60 @@ def test_tfs_model_with_custom_image(sagemaker_session, tf_version):
     assert cdef['Image'] == 'my-image'
 
 
+@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix')
+@mock.patch('sagemaker.utils.repack_model')
+def test_tfs_model_with_entry_point(repack_model, model_code_key_prefix, sagemaker_session,
+                                    tf_version):
+    model = Model("s3://some/data.tar.gz",
+                  entry_point='train.py',
+                  role=ROLE, framework_version=tf_version,
+                  image='my-image', sagemaker_session=sagemaker_session)
+
+    model.prepare_container_def(INSTANCE_TYPE)
+
+    model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image)
+
+    repack_model.assert_called_with('train.py', None, [], 's3://some/data.tar.gz',
+                                    's3://my_bucket/key-prefix/model.tar.gz',
+                                    sagemaker_session)
+
+
+@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix')
+@mock.patch('sagemaker.utils.repack_model')
+def test_tfs_model_with_source(repack_model, model_code_key_prefix, sagemaker_session, tf_version):
+    model = Model("s3://some/data.tar.gz",
+                  entry_point='train.py',
+                  source_dir='src',
+                  role=ROLE, framework_version=tf_version,
+                  image='my-image', sagemaker_session=sagemaker_session)
+
+    model.prepare_container_def(INSTANCE_TYPE)
+
+    model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image)
+
+    repack_model.assert_called_with('train.py', 'src', [], 's3://some/data.tar.gz',
+                                    's3://my_bucket/key-prefix/model.tar.gz',
+                                    sagemaker_session)
+
+
+@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix')
+@mock.patch('sagemaker.utils.repack_model')
+def test_tfs_model_with_dependencies(repack_model, model_code_key_prefix, sagemaker_session, tf_version):
+    model = Model("s3://some/data.tar.gz",
+                  entry_point='train.py',
+                  dependencies=['src', 'lib'],
+                  role=ROLE, framework_version=tf_version,
+                  image='my-image', sagemaker_session=sagemaker_session)
+
+    model.prepare_container_def(INSTANCE_TYPE)
+
+    model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image)
+
+    repack_model.assert_called_with('train.py', None, ['src', 'lib'], 's3://some/data.tar.gz',
+                                    's3://my_bucket/key-prefix/model.tar.gz',
+                                    sagemaker_session)
+
+
 def test_estimator_deploy(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 94511939e3..2f2f706ff0 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -21,11 +21,13 @@
 import re
 import time
 
+from boto3 import exceptions
 import pytest
 from mock import call, patch, Mock, MagicMock
 
 import sagemaker
 
+BUCKET_WITHOUT_WRITING_PERMISSION = 's3://bucket-without-writing-permission'
 
 NAME = 'base_name'
 BUCKET_NAME = 'some_bucket'
@@ -300,207 +302,201 @@ def test_create_tar_file_with_auto_generated_path(open):
     assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']]
 
 
-def write_file(path, content):
-    with open(path, 'a') as f:
-        f.write(content)
+def create_file_tree(root, tree):
+    for file in tree:
+        try:
+            os.makedirs(os.path.join(root, os.path.dirname(file)))
+        except:  # noqa: E722 Using bare except because p2/3 incompatibility issues.
+            pass
+        with open(os.path.join(root, file), 'a') as f:
+            f.write(file)
 
 
-def test_repack_model_without_source_dir(tmpdir):
+@pytest.fixture()
+def tmp(tmpdir):
+    yield str(tmpdir)
 
-    tmp = str(tmpdir)
 
-    model_path = os.path.join(tmp, 'model')
-    write_file(model_path, 'model data')
+def test_repack_model_without_source_dir(tmp, fake_s3):
 
-    source_dir = os.path.join(tmp, 'source-dir')
-    os.mkdir(source_dir)
-    script_path = os.path.join(source_dir, 'inference.py')
-    write_file(script_path, 'inference script')
+    create_file_tree(tmp, ['model-dir/model',
+                           'dependencies/a',
+                           'dependencies/b',
+                           'source-dir/inference.py',
+                           'source-dir/this-file-should-not-be-included.py'])
 
-    script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py')
-    write_file(script_path, 'This file should not be included')
+    fake_s3.tar_and_upload('model-dir', 's3://fake/location')
 
-    contents = [model_path]
+    sagemaker.utils.repack_model(inference_script=os.path.join(tmp, 'source-dir/inference.py'),
+                                 source_directory=None,
+                                 dependencies=[os.path.join(tmp, 'dependencies/a'),
+                                               os.path.join(tmp, 'dependencies/b')],
+                                 model_uri='s3://fake/location',
+                                 repacked_model_uri='s3://destination-bucket/model.tar.gz',
+                                 sagemaker_session=fake_s3.sagemaker_session)
 
-    sagemaker_session = MagicMock()
-    mock_s3_model_tar(contents, sagemaker_session, tmp)
-    fake_upload_path = mock_s3_upload(sagemaker_session, tmp)
-
-    model_uri = 's3://fake/location'
-
-    new_model_uri = sagemaker.utils.repack_model(os.path.join(source_dir, 'inference.py'),
-                                                 None,
-                                                 model_uri,
-                                                 sagemaker_session)
-
-    assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'}
-    assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
-
-
-def test_repack_model_with_entry_point_without_path_without_source_dir(tmpdir):
+    assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/model', '/code/a',
+                                                             '/code/b', '/code/inference.py'}
 
-    tmp = str(tmpdir)
 
-    model_path = os.path.join(tmp, 'model')
-    write_file(model_path, 'model data')
+def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3):
 
-    source_dir = os.path.join(tmp, 'source-dir')
-    os.mkdir(source_dir)
-    script_path = os.path.join(source_dir, 'inference.py')
-    write_file(script_path, 'inference script')
+    create_file_tree(tmp, ['model-dir/model',
+                           'source-dir/inference.py',
+                           'source-dir/this-file-should-not-be-included.py'])
 
-    script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py')
-    write_file(script_path, 'This file should not be included')
-
-    contents = [model_path]
-
-    sagemaker_session = MagicMock()
-    mock_s3_model_tar(contents, sagemaker_session, tmp)
-    fake_upload_path = mock_s3_upload(sagemaker_session, tmp)
-
-    model_uri = 's3://fake/location'
+    fake_s3.tar_and_upload('model-dir', 's3://fake/location')
 
     cwd = os.getcwd()
     try:
-        os.chdir(source_dir)
-
-        new_model_uri = sagemaker.utils.repack_model('inference.py',
-                                                     None,
-                                                     model_uri,
-                                                     sagemaker_session)
+        os.chdir(os.path.join(tmp, 'source-dir'))
+
+        sagemaker.utils.repack_model('inference.py',
+                                     None,
+                                     None,
+                                     's3://fake/location',
+                                     's3://destination-bucket/model.tar.gz',
+                                     fake_s3.sagemaker_session)
     finally:
         os.chdir(cwd)
 
-    assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'}
-    assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
-
+    assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/inference.py', '/model'}
 
-def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
 
-    tmp = str(tmpdir)
+def test_repack_model_from_s3_to_s3(tmp, fake_s3):
 
-    model_path = os.path.join(tmp, 'model')
-    write_file(model_path, 'model data')
+    create_file_tree(tmp, ['model-dir/model',
+                           'source-dir/inference.py',
+                           'source-dir/this-file-should-be-included.py'])
 
-    source_dir = os.path.join(tmp, 'source-dir')
-    os.mkdir(source_dir)
-    script_path = os.path.join(source_dir, 'inference.py')
-    write_file(script_path, 'inference script')
+    fake_s3.tar_and_upload('model-dir', 's3://fake/location')
 
-    script_path = os.path.join(source_dir, 'this-file-should-be-included.py')
-    write_file(script_path, 'This file should be included')
+    sagemaker.utils.repack_model('inference.py',
+                                 os.path.join(tmp, 'source-dir'),
+                                 None,
+                                 's3://fake/location',
+                                 's3://destination-bucket/model.tar.gz',
+                                 fake_s3.sagemaker_session)
 
-    contents = [model_path]
+    assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/this-file-should-be-included.py',
+                                                             '/code/inference.py',
+                                                             '/model'}
 
-    sagemaker_session = MagicMock()
-    mock_s3_model_tar(contents, sagemaker_session, tmp)
-    fake_upload_path = mock_s3_upload(sagemaker_session, tmp)
-
-    model_uri = 's3://fake/location'
-    new_model_uri = sagemaker.utils.repack_model('inference.py',
-                                                 source_dir,
-                                                 model_uri,
-                                                 sagemaker_session)
+def test_repack_model_from_file_to_file(tmp):
+    create_file_tree(tmp, ['model',
+                           'dependencies/a',
+                           'source-dir/inference.py'])
 
-    assert list_tar_files(fake_upload_path, tmpdir) == {'/code/this-file-should-be-included.py',
-                                                        '/code/inference.py',
-                                                        '/model'}
-    assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
+    model_tar_path = os.path.join(tmp, 'model.tar.gz')
+    sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path)
 
+    sagemaker_session = MagicMock()
 
-def test_repack_model_from_file_saves_model_to_file(tmpdir):
+    file_mode_path = 'file://%s' % model_tar_path
+    destination_path = 'file://%s' % os.path.join(tmp, 'repacked-model.tar.gz')
 
-    tmp = str(tmpdir)
+    sagemaker.utils.repack_model('inference.py',
+                                 os.path.join(tmp, 'source-dir'),
+                                 [os.path.join(tmp, 'dependencies/a')],
+                                 file_mode_path,
+                                 destination_path,
+                                 sagemaker_session)
 
-    model_path = os.path.join(tmp, 'model')
-    write_file(model_path, 'model data')
+    assert list_tar_files(destination_path, tmp) == {'/code/a', '/code/inference.py', '/model'}
 
-    source_dir = os.path.join(tmp, 'source-dir')
-    os.mkdir(source_dir)
-    script_path = os.path.join(source_dir, 'inference.py')
-    write_file(script_path, 'inference script')
 
-    model_tar_path = os.path.join(tmp, 'model.tar.gz')
-    sagemaker.utils.create_tar_file([model_path], model_tar_path)
+def test_repack_model_with_inference_code_should_replace_the_code(tmp, fake_s3):
+    create_file_tree(tmp, ['model-dir/model',
+                           'source-dir/new-inference.py',
+                           'model-dir/code/old-inference.py'])
 
-    sagemaker_session = MagicMock()
+    fake_s3.tar_and_upload('model-dir', 's3://fake/location')
 
-    file_mode_path = 'file://%s' % model_tar_path
-    new_model_uri = sagemaker.utils.repack_model('inference.py',
-                                                 source_dir,
-                                                 file_mode_path,
-                                                 sagemaker_session)
+    sagemaker.utils.repack_model('inference.py',
+                                 os.path.join(tmp, 'source-dir'),
+                                 None,
+                                 's3://fake/location',
+                                 's3://destination-bucket/repacked-model',
+                                 fake_s3.sagemaker_session)
 
-    assert os.path.dirname(new_model_uri) == os.path.dirname(file_mode_path)
-    assert list_tar_files(new_model_uri, tmpdir) == {'/code/inference.py', '/model'}
+    assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/new-inference.py', '/model'}
 
 
-def test_repack_model_with_inference_code_should_replace_the_code(tmpdir):
+def test_repack_model_from_file_to_folder(tmp):
+    create_file_tree(tmp, ['model',
+                           'source-dir/inference.py'])
 
-    tmp = str(tmpdir)
+    model_tar_path = os.path.join(tmp, 'model.tar.gz')
+    sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path)
 
-    model_path = os.path.join(tmp, 'model')
-    write_file(model_path, 'model data')
+    file_mode_path = 'file://%s' % model_tar_path
 
-    source_dir = os.path.join(tmp, 'source-dir')
-    os.mkdir(source_dir)
-    script_path = os.path.join(source_dir, 'new-inference.py')
-    write_file(script_path, 'inference script')
+    sagemaker.utils.repack_model('inference.py',
+                                 os.path.join(tmp, 'source-dir'),
+                                 [],
+                                 file_mode_path,
+                                 'file://%s/repacked-model.tar.gz' % tmp,
+                                 MagicMock())
 
-    old_code_path = os.path.join(tmp, 'code')
-    os.mkdir(old_code_path)
-    old_script_path = os.path.join(old_code_path, 'old-inference.py')
-    write_file(old_script_path, 'old inference script')
-    contents = [model_path, old_code_path]
+    assert list_tar_files('file://%s/repacked-model.tar.gz' % tmp, tmp) == {'/code/inference.py', '/model'}
 
-    sagemaker_session = MagicMock()
-    mock_s3_model_tar(contents, sagemaker_session, tmp)
-    fake_upload_path = mock_s3_upload(sagemaker_session, tmp)
 
-    model_uri = 's3://fake/location'
+class FakeS3(object):
 
-    new_model_uri = sagemaker.utils.repack_model('inference.py',
-                                                 source_dir,
-                                                 model_uri,
-                                                 sagemaker_session)
+    def __init__(self, tmp):
+        self.tmp = tmp
+        self.sagemaker_session = MagicMock()
+        self.location_map = {}
+        self.current_bucket = None
 
-    assert list_tar_files(fake_upload_path, tmpdir) == {'/code/new-inference.py', '/model'}
-    assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
+        self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = self.download_file
+        self.sagemaker_session.boto_session.resource().Bucket.side_effect = self.bucket
+        self.fake_upload_path = self.mock_s3_upload()
 
+    def bucket(self, name):
+        self.current_bucket = name
+        return self
 
-def mock_s3_model_tar(contents, sagemaker_session, tmp):
-    model_tar_path = os.path.join(tmp, 'model.tar.gz')
-    sagemaker.utils.create_tar_file(contents, model_tar_path)
-    mock_s3_download(sagemaker_session, model_tar_path)
+    def download_file(self, path, target):
+        key = '%s/%s' % (self.current_bucket, path)
+        shutil.copy2(self.location_map[key], target)
 
+    def tar_and_upload(self, path, fake_location):
+        tar_location = os.path.join(self.tmp, 'model-%s.tar.gz' % time.time())
+        with tarfile.open(tar_location, mode='w:gz') as t:
+            t.add(os.path.join(self.tmp, path), arcname=os.path.sep)
 
-def mock_s3_download(sagemaker_session, model_tar_path):
-    def download_file(_, target):
-        shutil.copy2(model_tar_path, target)
+        self.location_map[fake_location.replace('s3://', '')] = tar_location
+        return tar_location
 
-    sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = download_file
+    def mock_s3_upload(self):
+        dst = os.path.join(self.tmp, 'dst')
 
+        class MockS3Object(object):
 
-def mock_s3_upload(sagemaker_session, tmp):
-    dst = os.path.join(tmp, 'dst')
+            def __init__(self, bucket, key):
+                self.bucket = bucket
+                self.key = key
 
-    class MockS3Object(object):
+            def upload_file(self, target):
+                if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION:
+                    raise exceptions.S3UploadFailedError()
+                shutil.copy2(target, dst)
 
-        def __init__(self, bucket, key):
-            self.bucket = bucket
-            self.key = key
+        self.sagemaker_session.boto_session.resource().Object = MockS3Object
+        return dst
 
-        def upload_file(self, target):
-            shutil.copy2(target, dst)
 
-    sagemaker_session.boto_session.resource().Object = MockS3Object
-    return dst
+@pytest.fixture()
+def fake_s3(tmp):
+    return FakeS3(tmp)
 
 
-def list_tar_files(tar_ball, tmpdir):
+def list_tar_files(tar_ball, tmp):
     tar_ball = tar_ball.replace('file://', '')
-    startpath = str(tmpdir.ensure('tmp', dir=True))
+    startpath = os.path.join(tmp, 'startpath')
+    os.mkdir(startpath)
 
     with tarfile.open(name=tar_ball, mode='r:gz') as t:
         t.extractall(path=startpath)
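
# Local-mode sketch (illustrative): repack_model() also accepts file:// URIs
# on both ends, in which case _save_model() moves the temporary tarball into
# place instead of uploading it to S3:
#
#     repack_model('inference.py',
#                  'source-dir',
#                  [],
#                  'file:///tmp/model.tar.gz',
#                  'file:///tmp/repacked-model.tar.gz',
#                  sagemaker_session)
#
# test_repack_model_from_file_to_file above exercises this code path.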