
Commit 909dccc

mvsusppengk19 authored and committed
feature: Support for TFS preprocessing (aws#797)
1 parent 6ccb807 commit 909dccc

12 files changed, +550 -132 lines

12 files changed

+550
-132
lines changed

src/sagemaker/local/image.py

+8-4
@@ -235,10 +235,14 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
         sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)

         # Tar Artifacts -> model.tar.gz and output.tar.gz
-        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
-        output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
-        sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
-        sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
+        model_files = [os.path.join(model_artifacts, name) for name in
+                       os.listdir(model_artifacts)]
+        output_files = [os.path.join(output_artifacts, name) for name in
+                        os.listdir(output_artifacts)]
+        sagemaker.utils.create_tar_file(model_files,
+                                        os.path.join(compressed_artifacts, 'model.tar.gz'))
+        sagemaker.utils.create_tar_file(output_files,
+                                        os.path.join(compressed_artifacts, 'output.tar.gz'))

         if output_data_config['S3OutputPath'] == '':
             output_data = 'file://%s' % compressed_artifacts
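For illustration, a minimal sketch of what this local-mode code now does when tarring artifacts (the directory paths below are placeholders, not values from the commit, and must exist for the snippet to run):

    import os
    import sagemaker.utils

    # Placeholder folders standing in for the temp directories local mode creates.
    model_artifacts = '/tmp/sagemaker-local/artifacts/model'
    compressed_artifacts = '/tmp/sagemaker-local/compressed_artifacts'

    model_files = [os.path.join(model_artifacts, name)
                   for name in os.listdir(model_artifacts)]
    tar_path = sagemaker.utils.create_tar_file(
        model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
    # tar_path -> /tmp/sagemaker-local/compressed_artifacts/model.tar.gz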

src/sagemaker/model.py

+6-2
@@ -37,7 +37,7 @@ class Model(object):
     """A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""

     def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, name=None, vpc_config=None,
-                 sagemaker_session=None):
+                 sagemaker_session=None, enable_network_isolation=False):
         """Initialize an SageMaker ``Model``.

         Args:
@@ -58,6 +58,9 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n
                 * 'SecurityGroupIds' (list[str]): List of security group ids.
             sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker
                 interactions (default: None). If not specified, one is created using the default AWS configuration chain.
+            enable_network_isolation (Boolean): Default False. if True, enables network isolation in the endpoint,
+                isolating the model container. No inbound or outbound network calls can be made to or from the
+                model container.
         """
         self.model_data = model_data
         self.image = image
@@ -69,6 +72,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n
         self.sagemaker_session = sagemaker_session
         self._model_name = None
         self._is_compiled_model = False
+        self._enable_network_isolation = enable_network_isolation

     def prepare_container_def(self, instance_type, accelerator_type=None):  # pylint: disable=unused-argument
         """Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type.
@@ -92,7 +96,7 @@ def enable_network_isolation(self):
         Returns:
             bool: If network isolation should be enabled or not.
         """
-        return False
+        return self._enable_network_isolation

     def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None):
         """Create a SageMaker Model Entity

src/sagemaker/tensorflow/serving.py

+14-4
@@ -13,6 +13,7 @@
 from __future__ import absolute_import

 import logging
+
 import sagemaker
 from sagemaker.content_types import CONTENT_TYPE_JSON
 from sagemaker.fw_utils import create_image_uri
@@ -88,7 +89,7 @@ def predict(self, data, initial_args=None):
         return super(Predictor, self).predict(data, args)


-class Model(sagemaker.Model):
+class Model(sagemaker.model.FrameworkModel):
     FRAMEWORK_NAME = 'tensorflow-serving'
     LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL'
     LOG_LEVEL_MAP = {
@@ -99,7 +100,7 @@ class Model(sagemaker.Model):
         logging.CRITICAL: 'crit',
     }

-    def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
+    def __init__(self, model_data, role, entry_point=None, image=None, framework_version=TF_VERSION,
                  container_log_level=None, predictor_cls=Predictor, **kwargs):
         """Initialize a Model.

@@ -118,14 +119,23 @@ def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
             **kwargs: Keyword arguments passed to the ``Model`` initializer.
         """
         super(Model, self).__init__(model_data=model_data, role=role, image=image,
-                                    predictor_cls=predictor_cls, **kwargs)
+                                    predictor_cls=predictor_cls, entry_point=entry_point, **kwargs)
         self._framework_version = framework_version
         self._container_log_level = container_log_level

     def prepare_container_def(self, instance_type, accelerator_type=None):
         image = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()
-        return sagemaker.container_def(image, self.model_data, env)
+
+        if self.entry_point:
+            model_data = sagemaker.utils.repack_model(self.entry_point,
+                                                      self.source_dir,
+                                                      self.model_data,
+                                                      self.sagemaker_session)
+        else:
+            model_data = self.model_data
+
+        return sagemaker.container_def(image, model_data, env)

     def _get_container_env(self):
         if not self._container_log_level:
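A hedged usage sketch of the new entry_point support (the bucket, script name, framework version, and instance type are placeholders; the script is expected to define the TFS input_handler/output_handler shown in the test data below):

    from sagemaker.tensorflow.serving import Model

    model = Model(model_data='s3://my-bucket/tfs-model.tar.gz',  # placeholder
                  role='SageMakerRole',
                  entry_point='inference.py',                    # placeholder pre/post-processing script
                  framework_version='1.12')                      # placeholder version

    # With entry_point set, prepare_container_def() calls sagemaker.utils.repack_model()
    # so the deployed model.tar.gz carries the script under code/.
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')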

src/sagemaker/utils.py

+90-3
@@ -12,20 +12,24 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import

+import contextlib
 import errno
 import os
 import random
 import re
+import shutil
 import sys
 import tarfile
 import tempfile
 import time

 from datetime import datetime
 from functools import wraps
+from six.moves.urllib import parse

 import six

+import sagemaker

 ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'

@@ -258,13 +262,10 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):

 def create_tar_file(source_files, target=None):
     """Create a tar file containing all the source_files
-
     Args:
         source_files (List[str]): List of file paths that will be contained in the tar file
-
     Returns:
         (str): path to created tar file
-
     """
     if target:
         filename = target
@@ -278,6 +279,92 @@ def create_tar_file(source_files, target=None):
     return filename


+@contextlib.contextmanager
+def _tmpdir(suffix='', prefix='tmp'):
+    """Create a temporary directory with a context manager. The file is deleted when the context exits.
+
+    The prefix, suffix, and dir arguments are the same as for mkstemp().
+
+    Args:
+        suffix (str): If suffix is specified, the file name will end with that suffix, otherwise there will be no
+            suffix.
+        prefix (str): If prefix is specified, the file name will begin with that prefix; otherwise,
+            a default prefix is used.
+        dir (str): If dir is specified, the file will be created in that directory; otherwise, a default directory is
+            used.
+    Returns:
+        str: path to the directory
+    """
+    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None)
+    yield tmp
+    shutil.rmtree(tmp)
+
+
+def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
+    """Unpack model tarball and creates a new model tarball with the provided code script.
+
+    This function does the following:
+    - uncompresses model tarball from S3 or local system into a temp folder
+    - replaces the inference code from the model with the new code provided
+    - compresses the new model tarball and saves it in S3 or local file system
+
+    Args:
+        inference_script (str): path or basename of the inference script that will be packed into the model
+        source_directory (str): path including all the files that will be packed into the model
+        model_uri (str): S3 or file system location of the original model tar
+        sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.
+
+    Returns:
+        str: path to the new packed model
+    """
+    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
+
+    with _tmpdir() as tmp:
+        tmp_model_dir = os.path.join(tmp, 'model')
+        os.mkdir(tmp_model_dir)
+
+        model_from_s3 = model_uri.startswith('s3://')
+        if model_from_s3:
+            local_model_path = os.path.join(tmp, 'tar_file')
+            download_file_from_url(model_uri, local_model_path, sagemaker_session)
+
+            new_model_path = os.path.join(tmp, new_model_name)
+        else:
+            local_model_path = model_uri.replace('file://', '')
+            new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
+
+        with tarfile.open(name=local_model_path, mode='r:gz') as t:
+            t.extractall(path=tmp_model_dir)
+
+        code_dir = os.path.join(tmp_model_dir, 'code')
+        if os.path.exists(code_dir):
+            shutil.rmtree(code_dir, ignore_errors=True)
+
+        dirname = source_directory if source_directory else os.path.dirname(inference_script)
+
+        shutil.copytree(dirname, code_dir)
+
+        with tarfile.open(new_model_path, mode='w:gz') as t:
+            t.add(tmp_model_dir, arcname=os.path.sep)
+
+        if model_from_s3:
+            url = parse.urlparse(model_uri)
+            bucket, key = url.netloc, url.path.lstrip('/')
+            new_key = key.replace(os.path.basename(key), new_model_name)
+
+            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path)
+            return 's3://%s/%s' % (bucket, new_key)
+        else:
+            return 'file://%s' % new_model_path
+
+
+def download_file_from_url(url, dst, sagemaker_session):
+    url = parse.urlparse(url)
+    bucket, key = url.netloc, url.path.lstrip('/')
+
+    download_file(bucket, key, dst, sagemaker_session)
+
+
 def download_file(bucket_name, path, target, sagemaker_session):
     """Download a Single File from S3 into a local path
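A sketch of calling the new helper directly with a local model (the paths, directory, and script name are placeholders; with a file:// URI no S3 access happens, so the session can be None):

    import sagemaker.utils

    new_uri = sagemaker.utils.repack_model(
        inference_script='inference.py',        # placeholder script
        source_directory='my_code_dir',         # placeholder dir copied into code/ inside the tarball
        model_uri='file:///tmp/model.tar.gz',   # placeholder local model tarball
        sagemaker_session=None)                 # only used when model_uri is s3://
    # new_uri -> 'file:///tmp/model-<timestamp>.tar.gz'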
@@ -0,0 +1 @@
+asset-file-contents
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,26 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import json
+
+
+def input_handler(data, context):
+    data = json.loads(data.read().decode('utf-8'))
+    new_values = [x + 1 for x in data['instances']]
+    dumps = json.dumps({'instances': new_values})
+    return dumps
+
+
+def output_handler(data, context):
+    response_content_type = context.accept_header
+    prediction = data.content
+    return prediction, response_content_type
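A quick self-contained check of the handler above, run outside SageMaker in the same module; _FakeRequest is a stand-in for the request object the TFS container normally passes in:

    import io
    import json

    class _FakeRequest(object):
        """Minimal stand-in exposing the read() method input_handler relies on."""
        def __init__(self, body):
            self._body = io.BytesIO(body.encode('utf-8'))

        def read(self):
            return self._body.read()

    payload = json.dumps({'instances': [1.0, 2.0, 5.0]})
    print(input_handler(_FakeRequest(payload), context=None))
    # {"instances": [2.0, 3.0, 6.0]}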

tests/integ/test_tfs.py

+86-6
@@ -12,7 +12,11 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import

+import tarfile
+
 import botocore.exceptions
+import os
+
 import pytest
 import sagemaker
 import sagemaker.predictor
@@ -36,28 +40,87 @@ def instance_type(request):
 def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
     endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
     model_data = sagemaker_session.upload_data(
-        path='tests/data/tensorflow-serving-test-model.tar.gz',
+        path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'),
         key_prefix='tensorflow-serving/models')
-    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
+                                                                 sagemaker_session):
         model = Model(model_data=model_data, role='SageMakerRole',
                       framework_version=tf_full_version,
                       sagemaker_session=sagemaker_session)
         predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
         yield predictor


+def tar_dir(directory, tmpdir):
+    target = os.path.join(str(tmpdir), 'model.tar.gz')
+
+    with tarfile.open(target, mode='w:gz') as t:
+        t.add(directory, arcname=os.path.sep)
+    return target
+
+
+@pytest.fixture
+def tfs_predictor_with_model_and_entry_point_same_tar(instance_type,
+                                                      sagemaker_session,
+                                                      tf_full_version,
+                                                      tmpdir):
+    endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
+
+    model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'),
+                        tmpdir)
+
+    model_data = sagemaker_session.upload_data(
+        path=model_tar,
+        key_prefix='tensorflow-serving/models')
+
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
+                                                                 sagemaker_session):
+        model = Model(model_data=model_data,
+                      role='SageMakerRole',
+                      framework_version=tf_full_version,
+                      sagemaker_session=sagemaker_session)
+        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
+        yield predictor
+
+
+@pytest.fixture(scope='module')
+def tfs_predictor_with_model_and_entry_point_separated(instance_type,
+                                                       sagemaker_session, tf_full_version):
+    endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
+
+    model_data = sagemaker_session.upload_data(
+        path=os.path.join(tests.integ.DATA_DIR,
+                          'tensorflow-serving-test-model.tar.gz'),
+        key_prefix='tensorflow-serving/models')
+
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
+                                                                 sagemaker_session):
+        entry_point = os.path.join(tests.integ.DATA_DIR,
+                                   'tfs/tfs-test-model-with-inference/code/inference.py')
+        model = Model(entry_point=entry_point,
+                      model_data=model_data,
+                      role='SageMakerRole',
+                      framework_version=tf_full_version,
+                      sagemaker_session=sagemaker_session)
+        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
+        yield predictor
+
+
 @pytest.fixture(scope='module')
 def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version):
     endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
     instance_type = 'ml.c4.large'
     accelerator_type = 'ml.eia1.medium'
-    model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz',
-                                               key_prefix='tensorflow-serving/models')
-    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
+    model_data = sagemaker_session.upload_data(
+        path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'),
+        key_prefix='tensorflow-serving/models')
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
+                                                                 sagemaker_session):
         model = Model(model_data=model_data, role='SageMakerRole',
                       framework_version=tf_full_version,
                       sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, accelerator_type=accelerator_type)
+        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name,
+                                 accelerator_type=accelerator_type)
         yield predictor


@@ -81,6 +144,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator):
     assert expected_result == result


+def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar):
+    input_data = {'instances': [1.0, 2.0, 5.0]}
+    expected_result = {'predictions': [4.0, 4.5, 6.0]}
+
+    result = tfs_predictor_with_model_and_entry_point_same_tar.predict(input_data)
+    assert expected_result == result
+
+
+def test_predict_with_model_and_entry_point_separated(
+        tfs_predictor_with_model_and_entry_point_separated):
+    input_data = {'instances': [1.0, 2.0, 5.0]}
+    expected_result = {'predictions': [4.0, 4.5, 6.0]}
+
+    result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data)
+    assert expected_result == result
+
+
 def test_predict_generic_json(tfs_predictor):
     input_data = [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]]
     expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]}
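On the expected values in the new tests: the generic JSON test sends [1.0, 2.0, 5.0] and expects [3.5, 4.0, 5.5], which is consistent with the test model computing roughly x / 2 + 3 (an inference from the numbers, not stated anywhere in the commit). With the entry point in place, input_handler adds 1 to each instance before the model runs, which shifts the expectations to [4.0, 4.5, 6.0]:

    f = lambda x: x / 2 + 3                     # inferred behavior of the test model
    print([f(x + 1) for x in [1.0, 2.0, 5.0]])  # [4.0, 4.5, 6.0], matching expected_result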
