
Commit 1ff2231

mvsusppengk19 authored and committed
feature: repack_model support dependencies and code location (aws#821)
1 parent 646293c commit 1ff2231

File tree: 11 files changed, +359 -215 lines

src/sagemaker/model.py  (+19, -13)
@@ -448,21 +448,27 @@ def _upload_code(self, key_prefix, repack=False):
         local_code = utils.get_config_value('local.local_code', self.sagemaker_session.config)
         if self.sagemaker_session.local_mode and local_code:
             self.uploaded_code = None
-        else:
-            if not repack:
-                bucket = self.bucket or self.sagemaker_session.default_bucket()
-                self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session,
-                                                                 bucket=bucket,
-                                                                 s3_key_prefix=key_prefix,
-                                                                 script=self.entry_point,
-                                                                 directory=self.source_dir,
-                                                                 dependencies=self.dependencies)
+        elif not repack:
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session,
+                                                             bucket=bucket,
+                                                             s3_key_prefix=key_prefix,
+                                                             script=self.entry_point,
+                                                             directory=self.source_dir,
+                                                             dependencies=self.dependencies)

         if repack:
-            self.repacked_model_data = utils.repack_model(inference_script=self.entry_point,
-                                                          source_directory=self.source_dir,
-                                                          model_uri=self.model_data,
-                                                          sagemaker_session=self.sagemaker_session)
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz')
+
+            utils.repack_model(inference_script=self.entry_point,
+                               source_directory=self.source_dir,
+                               dependencies=self.dependencies,
+                               model_uri=self.model_data,
+                               repacked_model_uri=repacked_model_data,
+                               sagemaker_session=self.sagemaker_session)
+
+            self.repacked_model_data = repacked_model_data

             self.uploaded_code = UploadedCode(s3_prefix=self.repacked_model_data,
                                               script_name=os.path.basename(self.entry_point))
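Note: the repack branch now chooses the destination of the repacked artifact itself, keeping it under the same key_prefix as the rest of the uploaded code, instead of letting repack_model pick a timestamped name. A minimal sketch of the URI composition (bucket and prefix below are placeholders, not values from this commit):

    import os

    # Placeholder values for illustration only.
    bucket = 'my-sagemaker-bucket'
    key_prefix = 'sagemaker-mxnet-serving-2019-04-29-12-00-00-000/my-model'

    repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz')
    print(repacked_model_data)
    # s3://my-sagemaker-bucket/sagemaker-mxnet-serving-2019-04-29-12-00-00-000/my-model/model.tar.gz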

src/sagemaker/mxnet/model.py  (+3, -3)
@@ -92,21 +92,21 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
         Returns:
             dict[str, str]: A container definition object usable with the CreateModel API.
         """
-        mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION)
+        is_mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION)

         deploy_image = self.image
         if not deploy_image:
             region_name = self.sagemaker_session.boto_session.region_name

             framework_name = self.__framework_name__
-            if mms_version:
+            if is_mms_version:
                 framework_name += '-serving'

             deploy_image = create_image_uri(region_name, framework_name, instance_type,
                                             self.framework_version, self.py_version, accelerator_type=accelerator_type)

         deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
-        self._upload_code(deploy_key_prefix, mms_version)
+        self._upload_code(deploy_key_prefix, is_mms_version)
         deploy_env = dict(self.env)
         deploy_env.update(self._framework_env_vars())
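Note: the rename makes it clearer that the boolean gates MMS-based serving, which is also what triggers repacking the entry point into the model tarball. A small sketch of the version check, assuming _LOWEST_MMS_VERSION is '1.4' (that constant is not shown in this diff):

    from pkg_resources import parse_version

    _LOWEST_MMS_VERSION = '1.4'  # assumed value; not part of this diff

    for framework_version in ('1.3.0', '1.4.0', '1.4.1'):
        is_mms_version = parse_version(framework_version) >= parse_version(_LOWEST_MMS_VERSION)
        print(framework_version, is_mms_version)
    # 1.3.0 False
    # 1.4.0 True
    # 1.4.1 True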

src/sagemaker/tensorflow/serving.py  (+12, -4)
@@ -13,6 +13,7 @@
 from __future__ import absolute_import

 import logging
+import os

 import sagemaker
 from sagemaker.content_types import CONTENT_TYPE_JSON
@@ -128,10 +129,17 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
         env = self._get_container_env()

         if self.entry_point:
-            model_data = sagemaker.utils.repack_model(self.entry_point,
-                                                      self.source_dir,
-                                                      self.model_data,
-                                                      self.sagemaker_session)
+            key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image)
+
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz')
+
+            sagemaker.utils.repack_model(self.entry_point,
+                                         self.source_dir,
+                                         self.dependencies,
+                                         self.model_data,
+                                         model_data,
+                                         self.sagemaker_session)
         else:
             model_data = self.model_data
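Note: with dependencies now passed through to repack_model, deploying a TensorFlow Serving model with an entry point and extra code files looks roughly like the sketch below (bucket, framework version and instance type are placeholders; the integration test further down exercises the same path):

    from sagemaker.tensorflow.serving import Model

    model = Model(model_data='s3://my-bucket/tfs-model/model.tar.gz',  # placeholder URI
                  role='SageMakerRole',
                  entry_point='inference.py',
                  dependencies=['dependency.py'],
                  framework_version='1.12')

    # prepare_container_def() repacks model_data together with inference.py and
    # dependency.py into s3://<bucket>/<key_prefix>/model.tar.gz before deployment.
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')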

src/sagemaker/utils.py  (+72, -38)
@@ -29,8 +29,6 @@

 import six

-import sagemaker
-
 ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'


@@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'):
     shutil.rmtree(tmp)


-def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
+def repack_model(inference_script,
+                 source_directory,
+                 dependencies,
+                 model_uri,
+                 repacked_model_uri,
+                 sagemaker_session):
     """Unpack model tarball and creates a new model tarball with the provided code script.

     This function does the following:
@@ -311,60 +314,91 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
     Args:
         inference_script (str): path or basename of the inference script that will be packed into the model
         source_directory (str): path including all the files that will be packed into the model
+        dependencies (list[str]): A list of paths to directories (absolute or relative) with
+            any additional libraries that will be exported to the container (default: []).
+            The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
+            Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>> |------ train.py
+                >>> |------ common
+                >>> |------ virtual-env
+
+        repacked_model_uri (str): path or file system location where the new model will be saved
     model_uri (str): S3 or file system location of the original model tar
     sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.

     Returns:
         str: path to the new packed model
     """
-    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
+    dependencies = dependencies or []

     with _tmpdir() as tmp:
-        tmp_model_dir = os.path.join(tmp, 'model')
-        os.mkdir(tmp_model_dir)
+        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

-        model_from_s3 = model_uri.lower().startswith('s3://')
-        if model_from_s3:
-            local_model_path = os.path.join(tmp, 'tar_file')
-            download_file_from_url(model_uri, local_model_path, sagemaker_session)
+        _create_or_update_code_dir(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp)

-            new_model_path = os.path.join(tmp, new_model_name)
-        else:
-            local_model_path = model_uri.replace('file://', '')
-            new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
+        tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz')
+        with tarfile.open(tmp_model_path, mode='w:gz') as t:
+            t.add(model_dir, arcname=os.path.sep)

-        with tarfile.open(name=local_model_path, mode='r:gz') as t:
-            t.extractall(path=tmp_model_dir)
+        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session)

-        code_dir = os.path.join(tmp_model_dir, 'code')
-        if os.path.exists(code_dir):
-            shutil.rmtree(code_dir, ignore_errors=True)

-        if source_directory and source_directory.lower().startswith('s3://'):
-            local_code_path = os.path.join(tmp, 'local_code.tar.gz')
-            download_file_from_url(source_directory, local_code_path, sagemaker_session)
+def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
+    if repacked_model_uri.lower().startswith('s3://'):
+        url = parse.urlparse(repacked_model_uri)
+        bucket, key = url.netloc, url.path.lstrip('/')
+        new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

-            with tarfile.open(name=local_code_path, mode='r:gz') as t:
-                t.extractall(path=code_dir)
+        sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(
+            tmp_model_path)
+    else:
+        shutil.move(tmp_model_path, repacked_model_uri.replace('file://', ''))

-        elif source_directory:
-            shutil.copytree(source_directory, code_dir)
-        else:
-            os.mkdir(code_dir)
-            shutil.copy2(inference_script, code_dir)

-        with tarfile.open(new_model_path, mode='w:gz') as t:
-            t.add(tmp_model_dir, arcname=os.path.sep)
+def _create_or_update_code_dir(model_dir, inference_script, source_directory,
+                               dependencies, sagemaker_session, tmp):
+    code_dir = os.path.join(model_dir, 'code')
+    if os.path.exists(code_dir):
+        shutil.rmtree(code_dir, ignore_errors=True)
+    if source_directory and source_directory.lower().startswith('s3://'):
+        local_code_path = os.path.join(tmp, 'local_code.tar.gz')
+        download_file_from_url(source_directory, local_code_path, sagemaker_session)
+
+        with tarfile.open(name=local_code_path, mode='r:gz') as t:
+            t.extractall(path=code_dir)

-        if model_from_s3:
-            url = parse.urlparse(model_uri)
-            bucket, key = url.netloc, url.path.lstrip('/')
-            new_key = key.replace(os.path.basename(key), new_model_name)
+    elif source_directory:
+        shutil.copytree(source_directory, code_dir)
+    else:
+        os.mkdir(code_dir)
+        shutil.copy2(inference_script, code_dir)

-            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path)
-            return 's3://%s/%s' % (bucket, new_key)
+    for dependency in dependencies:
+        if os.path.isdir(dependency):
+            shutil.copytree(dependency, code_dir)
         else:
-            return 'file://%s' % new_model_path
+            shutil.copy2(dependency, code_dir)
+
+
+def _extract_model(model_uri, sagemaker_session, tmp):
+    tmp_model_dir = os.path.join(tmp, 'model')
+    os.mkdir(tmp_model_dir)
+    if model_uri.lower().startswith('s3://'):
+        local_model_path = os.path.join(tmp, 'tar_file')
+        download_file_from_url(model_uri, local_model_path, sagemaker_session)
+    else:
+        local_model_path = model_uri.replace('file://', '')
+    with tarfile.open(name=local_model_path, mode='r:gz') as t:
+        t.extractall(path=tmp_model_dir)
+    return tmp_model_dir


 def download_file_from_url(url, dst, sagemaker_session):
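Note: repack_model no longer returns a generated S3 URI; the caller now supplies both the source model and the destination, and the work is split across _extract_model, _create_or_update_code_dir and _save_model. A hedged sketch of calling the new signature directly (all URIs and paths are placeholders):

    import sagemaker
    from sagemaker import utils

    session = sagemaker.Session()

    utils.repack_model(inference_script='inference.py',
                       source_directory=None,
                       dependencies=['my/libs/common'],
                       model_uri='s3://my-bucket/training-job/output/model.tar.gz',
                       repacked_model_uri='s3://my-bucket/serving/model.tar.gz',
                       sagemaker_session=session)
    # Intended result (per the docstring above): the repacked tarball holds the original
    # model artifacts plus a code/ directory with inference.py and the 'common' library.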
tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py  (new file, +12; path as referenced by tests/integ/test_tfs.py)

@@ -0,0 +1,12 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py  (new file, +27; path as referenced by tests/integ/test_tfs.py)

@@ -0,0 +1,27 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import json
+
+import dependency
+
+def input_handler(data, context):
+    data = json.loads(data.read().decode('utf-8'))
+    new_values = [x + 1 for x in data['instances']]
+    dumps = json.dumps({'instances': new_values})
+    return dumps
+
+
+def output_handler(data, context):
+    response_content_type = context.accept_header
+    prediction = data.content
+    return prediction, response_content_type

tests/data/tfs/tfs-test-model-with-inference/code/inference.py  (+1, -2)
@@ -1,4 +1,4 @@
-# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 import json

-
 def input_handler(data, context):
     data = json.loads(data.read().decode('utf-8'))
     new_values = [x + 1 for x in data['instances']]

tests/integ/test_tfs.py  (+10, -6)
@@ -84,8 +84,8 @@ def tfs_predictor_with_model_and_entry_point_same_tar(instance_type,


 @pytest.fixture(scope='module')
-def tfs_predictor_with_model_and_entry_point_separated(instance_type,
-                                                       sagemaker_session, tf_full_version):
+def tfs_predictor_with_model_and_entry_point_and_dependencies(instance_type,
+                                                              sagemaker_session, tf_full_version):
     endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')

     model_data = sagemaker_session.upload_data(
@@ -96,10 +96,14 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type,
     with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                  sagemaker_session):
         entry_point = os.path.join(tests.integ.DATA_DIR,
-                                   'tfs/tfs-test-model-with-inference/code/inference.py')
+                                   'tfs/tfs-test-entrypoint-and-dependencies/inference.py')
+        dependencies = [os.path.join(tests.integ.DATA_DIR,
+                                     'tfs/tfs-test-entrypoint-and-dependencies/dependency.py')]
+
         model = Model(entry_point=entry_point,
                       model_data=model_data,
                       role='SageMakerRole',
+                      dependencies=dependencies,
                       framework_version=tf_full_version,
                       sagemaker_session=sagemaker_session)
         predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
@@ -152,12 +156,12 @@ def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar):
     assert expected_result == result


-def test_predict_with_model_and_entry_point_separated(
-        tfs_predictor_with_model_and_entry_point_separated):
+def test_predict_with_model_and_entry_point_and_dependencies_separated(
+        tfs_predictor_with_model_and_entry_point_and_dependencies):
     input_data = {'instances': [1.0, 2.0, 5.0]}
     expected_result = {'predictions': [4.0, 4.5, 6.0]}

-    result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data)
+    result = tfs_predictor_with_model_and_entry_point_and_dependencies.predict(input_data)
     assert expected_result == result