feature: Support for TFS preprocessing #797
@@ -12,20 +12,24 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import contextlib
import errno
import os
import random
import re
import shutil
import sys
import tarfile
import tempfile
import time

from datetime import datetime
from functools import wraps
from six.moves.urllib import parse

import six

import sagemaker

ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'
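As a quick illustration of what this pattern matches (the account id and repository tag below are made up):

import re

uri = '123456789012.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-serving:1.12-cpu'
match = re.match(ECR_URI_PATTERN, uri)
# group 1 is the account id, group 5 the region, group 9 the 'repository:tag' part
assert match.group(1) == '123456789012'
assert match.group(5) == 'us-west-2'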
@@ -278,6 +282,107 @@ def create_tar_file(source_files, target=None):
    return filename


@contextlib.contextmanager
def _tmpdir(suffix='', prefix='tmp'):
    """Create a temporary directory with a context manager. The directory is deleted when the context exits.

    The prefix and suffix arguments are the same as for tempfile.mkdtemp().

    Args:
        suffix (str): If suffix is specified, the directory name will end with that suffix;
            otherwise there will be no suffix.
        prefix (str): If prefix is specified, the directory name will begin with that prefix;
            otherwise, a default prefix is used.

    Returns:
        str: path to the directory
    """
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None)
    try:
        yield tmp
    finally:
        # Clean up even if the caller raises inside the context.
        shutil.rmtree(tmp)
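A quick illustration of the context manager (the directory name is arbitrary):

with _tmpdir() as tmp:
    staging_dir = os.path.join(tmp, 'staging')
    os.mkdir(staging_dir)
    # ... work with files under staging_dir ...
# the temporary directory and its contents are removed here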
def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
    """Unpack a model tarball and create a new model tarball with the provided inference script.

    This function does the following:
    - uncompresses the model tarball from S3 or the local file system into a temp folder
    - replaces the inference code in the model with the new code provided
    - compresses the new model tarball and saves it in S3 or the local file system

    Args:
        inference_script (str): path or basename of the inference script that will be packed into the model
        source_directory (str): path including all the files that will be packed into the model
        model_uri (str): S3 or file system location of the original model tar
        sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.

    Returns:
        str: path to the new packed model
    """
    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()

    with _tmpdir() as tmp:
        tmp_model_dir = os.path.join(tmp, 'model')
        os.mkdir(tmp_model_dir)

        model_from_s3 = model_uri.startswith('s3://')
        if model_from_s3:
            local_model_uri = os.path.join(tmp, 'tar_file')

[review] not a uri
[reply] renamed to

            download_file_from_url(model_uri, local_model_uri, sagemaker_session)

            new_model_path = os.path.join(tmp, new_model_name)
        else:
            local_model_uri = model_uri.replace('file://', '')

[review] change var name, since it's not a uri anymore
[reply] renamed to

            new_model_path = os.path.join(os.path.dirname(local_model_uri), new_model_name)

[review] we are writing into the user dir as a sibling of the original model. we should stage things in a temp dir so we don't trip over other user files or leave junk around
[reply] You are correct that we need to do that. I am currently staging things in a temp folder: new_model_path is the location where the new model will be saved.
        with tarfile.open(name=local_model_uri, mode='r:gz') as t:
            t.extractall(path=tmp_model_dir)

        code_dir = os.path.join(tmp_model_dir, 'code')
        if os.path.exists(code_dir):
            shutil.rmtree(code_dir, ignore_errors=True)

        os.mkdir(code_dir)

        source_files = _list_files(inference_script, source_directory)

        for source_file in source_files:
            shutil.copy(source_file, code_dir)

[review] why list files and then copy one by one instead of just copying the source_dir to code? this would also eliminate L348
[reply] That's a good call. I am going to do that.

        files_to_compress = [os.path.join(tmp_model_dir, file)
                             for file in os.listdir(tmp_model_dir)]

[review] this is also weird. why not just tar the new model directory?
[reply] I need to tar the contents of the folder
[review] did you try with arcname=os.path.sep?

        tar_file = sagemaker.utils.create_tar_file(files_to_compress, new_model_path)
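A sketch of the arcname approach suggested in the thread above (illustrative, not part of this diff; the helper name is hypothetical): adding the directory itself with arcname='.' stores its contents at the archive root, so no file list needs to be built.

def _tar_directory_contents(directory, target):
    # Members are stored as '.', './code/inference.py', ... instead of
    # being nested under the directory's own basename.
    with tarfile.open(target, mode='w:gz') as t:
        t.add(directory, arcname='.')
    return target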
        if model_from_s3:
            url = parse.urlparse(model_uri)
            bucket, key = url.netloc, url.path.lstrip('/')
            new_key = key.replace(os.path.basename(key), new_model_name)

            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(tar_file)
            return 's3://%s/%s' % (bucket, new_key)
        else:
            return 'file://%s' % new_model_path
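A usage sketch for repack_model (the bucket, paths, and script name below are made up for illustration):

session = sagemaker.session.Session()

new_model = repack_model(inference_script='inference.py',
                         source_directory='src',
                         model_uri='s3://my-bucket/models/model.tar.gz',
                         sagemaker_session=session)
# new_model is something like 's3://my-bucket/models/model-<timestamp>.tar.gz'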
def _list_files(script, directory):

[review] shouldn't be necessary
[reply] I mentioned above why we need to list the files instead of passing the entire folder. Another thing that this function does is dealing with the use case in which directory is None and only the entry point should be copied to the tarball.
[review] is this still needed? it doesn't seem to be called anywhere anymore.
[reply] Good call!

    if directory is None:
        return [script]

    basedir = directory if directory else os.path.dirname(script)
    return [os.path.join(basedir, name) for name in os.listdir(basedir)]
def download_file_from_url(url, dst, sagemaker_session):
    url = parse.urlparse(url)
    bucket, key = url.netloc, url.path.lstrip('/')

    download_file(bucket, key, dst, sagemaker_session)
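For reference, this is how urlparse splits an S3 URL (the bucket and key below are illustrative):

url = parse.urlparse('s3://my-bucket/models/model.tar.gz')
bucket, key = url.netloc, url.path.lstrip('/')
# bucket == 'my-bucket', key == 'models/model.tar.gz'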
def download_file(bucket_name, path, target, sagemaker_session):
    """Download a single file from S3 into a local path
@@ -0,0 +1 @@
asset-file-contents
@@ -0,0 +1,26 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json


def input_handler(data, context):
    data = json.loads(data.read().decode('utf-8'))
    new_values = [x + 1 for x in data['instances']]
    dumps = json.dumps({'instances': new_values})
    return dumps


def output_handler(data, context):
    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type
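These handlers are the pre/post-processing hooks this PR packs into the model's code/ directory. A quick illustration of input_handler with a stand-in for the serving context (the _FakeContext class is hypothetical, only for this example):

import io

class _FakeContext(object):
    accept_header = 'application/json'   # the only attribute output_handler reads

body = io.BytesIO(b'{"instances": [1.0, 2.0, 5.0]}')
print(input_handler(body, _FakeContext()))   # -> '{"instances": [2.0, 3.0, 6.0]}'

This +1 shift on every instance is why the entry-point tests below expect {'predictions': [4.0, 4.5, 6.0]} rather than the generic [3.5, 4.0, 5.5].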
@@ -13,6 +13,9 @@
from __future__ import absolute_import

import botocore.exceptions
import os
import tempfile

import pytest
import sagemaker
import sagemaker.predictor
@@ -36,28 +39,85 @@ def instance_type(request):
def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
    endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'),
        key_prefix='tensorflow-serving/models')
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                 sagemaker_session):
        model = Model(model_data=model_data, role='SageMakerRole',
                      framework_version=tf_full_version,
                      sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
        yield predictor
def tar_dir(directory):
    tmp = tempfile.mkdtemp()

[review] should use a pytest tmpdir instead, so cleanup is automatic

    source_files = [os.path.join(directory, name) for name in os.listdir(directory)]

[review] shouldn't need to do this. method in utils should just take a dir path and do the right thing
[review] i think the create_tar_file should be changed... still not convinced we need to be generating these file lists this way (see arcname comment)

    return sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, 'model.tar.gz'))
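A sketch of the pytest-managed alternative the reviewer suggests (a hypothetical rewrite using the built-in tmpdir_factory fixture, which pytest cleans up automatically; call sites would then take the fixture and call the returned function):

@pytest.fixture(scope='module')
def tar_dir(tmpdir_factory):
    # tmpdir_factory is session-scoped, so it is usable from module-scoped fixtures.
    def _tar(directory):
        target = str(tmpdir_factory.mktemp('model').join('model.tar.gz'))
        source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
        return sagemaker.utils.create_tar_file(source_files, target)
    return _tar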
@pytest.fixture(scope='module')
def tfs_predictor_with_model_and_entry_point_same_tar(instance_type,
                                                      sagemaker_session,
                                                      tf_full_version):
    endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')

    model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'))

    model_data = sagemaker_session.upload_data(
        path=model_tar,
        key_prefix='tensorflow-serving/models')

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                 sagemaker_session):
        model = Model(model_data=model_data,
                      role='SageMakerRole',
                      framework_version=tf_full_version,
                      sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
        yield predictor
@pytest.fixture(scope='module')
def tfs_predictor_with_model_and_entry_point_separated(instance_type,
                                                       sagemaker_session, tf_full_version):
    endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')

    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR,
                          'tensorflow-serving-test-model.tar.gz'),
        key_prefix='tensorflow-serving/models')

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                 sagemaker_session):
        entry_point = os.path.join(tests.integ.DATA_DIR,
                                   'tfs/tfs-test-model-with-inference/code/inference.py')
        model = Model(entry_point=entry_point,
                      model_data=model_data,
                      role='SageMakerRole',
                      framework_version=tf_full_version,
                      sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
        yield predictor
@pytest.fixture(scope='module')
def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version):
    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
    instance_type = 'ml.c4.large'
    accelerator_type = 'ml.eia1.medium'
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'),
        key_prefix='tensorflow-serving/models')
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name,
                                                                 sagemaker_session):
        model = Model(model_data=model_data, role='SageMakerRole',
                      framework_version=tf_full_version,
                      sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name,
                                 accelerator_type=accelerator_type)
        yield predictor
@@ -81,6 +141,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator):
    assert expected_result == result


def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar):
    input_data = {'instances': [1.0, 2.0, 5.0]}
    expected_result = {'predictions': [4.0, 4.5, 6.0]}

    result = tfs_predictor_with_model_and_entry_point_same_tar.predict(input_data)
    assert expected_result == result


def test_predict_with_model_and_entry_point_separated(
        tfs_predictor_with_model_and_entry_point_separated):
    input_data = {'instances': [1.0, 2.0, 5.0]}
    expected_result = {'predictions': [4.0, 4.5, 6.0]}

    result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data)
    assert expected_result == result
def test_predict_generic_json(tfs_predictor):
    input_data = [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]]
    expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]}

[review] seems like a lot of extra blank lines in this method
[reply] I personally like to use white lines as a logical way to group correlated lines. It is all under PEP8 and linting.
[review] sure. they just seem kind of random here. like blank at 332 but not at 338.