Skip to content
This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit 7387dc5

Browse files
committed
fix: Use default model name when model name is None
1 parent b28564b commit 7387dc5

File tree

6 files changed

+210
-1
lines changed

6 files changed

+210
-1
lines changed

docker/build_artifacts/sagemaker/tfs_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def parse_request(req, rest_port, grpc_port, default_model_name, model_name=None
4141
tfs_uri = make_tfs_uri(rest_port, tfs_attributes, default_model_name, model_name)
4242

4343
if not model_name:
44-
model_name = tfs_attributes.get("tfs-model-name")
44+
model_name = tfs_attributes.get("tfs-model-name") or default_model_name
4545

4646
context = Context(model_name,
4747
tfs_attributes.get("tfs-model-version"),

test/integration/sagemaker/test_tfs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,29 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import json
1314
import os
15+
import tarfile
1416

17+
import boto3
1518
import pytest
19+
import sagemaker
20+
import urllib.request
1621

1722
import util
1823

24+
from packaging.version import Version
25+
26+
from sagemaker.tensorflow import TensorFlowModel
27+
from sagemaker.utils import name_from_base
28+
29+
from timeout import timeout_and_delete_endpoint
30+
1931
NON_P3_REGIONS = ["ap-southeast-1", "ap-southeast-2", "ap-south-1",
2032
"ca-central-1", "eu-central-1", "eu-west-2", "us-west-1"]
2133

34+
RESOURCES_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources"))
35+
2236

2337
@pytest.fixture(params=os.environ["TEST_VERSIONS"].split(","))
2438
def version(request):
@@ -75,6 +89,28 @@ def python_model_with_lib(region, boto_session):
7589
"test/data/python-with-lib.tar.gz")
7690

7791

92+
@pytest.fixture(scope="session")
def resnet_model_tar_path():
    """Build ``model.tar.gz`` for the resnet50_v1 test model and return its path.

    Any archive left over from a previous run is removed first. The model
    files are downloaded from the ``aws-dlc-sample-models`` S3 bucket into the
    local ``model`` directory, then packed together with the local ``code``
    directory into a gzip-compressed tarball.
    """
    model_root = os.path.join(RESOURCES_PATH, "models", "resnet50_v1")
    tar_path = os.path.join(model_root, "model.tar.gz")
    if os.path.exists(tar_path):
        os.remove(tar_path)

    s3_prefix = "tensorflow/resnet50_v1/model"
    bucket = boto3.resource("s3").Bucket("aws-dlc-sample-models")
    for s3_object in bucket.objects.filter(Prefix=s3_prefix):
        key = s3_object.key
        destination = os.path.join(model_root, "model", os.path.relpath(key, s3_prefix))
        # The parent directory is created even for keys ending in "/";
        # only those keys themselves are not downloaded.
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        if not key.endswith("/"):
            bucket.download_file(key, destination)

    with tarfile.open(tar_path, "w:gz") as archive:
        for folder in ("code", "model"):
            archive.add(os.path.join(model_root, folder), arcname=folder)
    return tar_path
112+
113+
78114
def test_tfs_model(boto_session, sagemaker_client,
79115
sagemaker_runtime_client, model_name, tfs_model,
80116
image_uri, instance_type, accelerator_type):
@@ -135,3 +171,43 @@ def test_python_model_with_lib(boto_session, sagemaker_client,
135171
# python service adds this to tfs response
136172
assert output_data["python"] is True
137173
assert output_data["dummy_module"] == "0.1"
174+
175+
176+
def test_resnet_with_inference_handler(
    boto_session, image_uri, instance_type, resnet_model_tar_path, framework_version
):
    """Deploy the resnet model with a custom ``inference.py`` and invoke it.

    The test uploads the model archive, deploys it to an endpoint, sends a
    JPEG image as ``application/x-image`` and checks that the response holds
    three probability values. The endpoint is deleted when the
    ``timeout_and_delete_endpoint`` block exits.

    Skipped for framework versions >= 2.6.
    """
    if Version(framework_version) >= Version("2.6"):
        pytest.skip(
            "The inference script based test currently uses v1 compat features, making it incompatible with TF>=2.6"
        )
    sagemaker_session = sagemaker.Session(boto_session=boto_session)
    # S3 key prefixes always use "/", so spell it out instead of relying on
    # the local OS path separator.
    model_data = sagemaker_session.upload_data(
        path=resnet_model_tar_path, key_prefix="tensorflow-inference/resnet"
    )
    endpoint_name = name_from_base("tensorflow-inference")

    tensorflow_model = TensorFlowModel(
        model_data=model_data,
        role="SageMakerRole",
        entry_point="inference.py",
        image_uri=image_uri,
        sagemaker_session=sagemaker_session,
    )

    with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=30):
        tensorflow_model.deploy(
            initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
        )
        kitten_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
        # Read the image straight into memory so no kitten.jpg is left behind
        # in the working directory after the test run.
        with urllib.request.urlopen(kitten_url) as image_response:
            kitten_image = image_response.read()

        runtime_client = sagemaker_session.sagemaker_runtime_client
        response = runtime_client.invoke_endpoint(
            EndpointName=endpoint_name, ContentType='application/x-image', Body=kitten_image
        )
        result = json.loads(response['Body'].read().decode('utf-8'))

        assert len(result["outputs"]["probabilities"]["floatVal"]) == 3

test/integration/sagemaker/timeout.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# TODO: this is used in all containers and sdk. We should move it to container support or sdk test utils.
14+
from __future__ import absolute_import
15+
import signal
16+
from contextlib import contextmanager
17+
import logging
18+
19+
from botocore.exceptions import ClientError
20+
21+
LOGGER = logging.getLogger('timeout')
22+
23+
24+
class TimeoutError(Exception):
    """Raised by the SIGALRM handler in ``timeout`` when the time limit expires.

    Note: within this module the name shadows the builtin ``TimeoutError``.
    """
26+
27+
28+
@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
    """Add a signal-based timeout to any block of code.

    If multiple time units are specified, they are added together to
    determine the time limit. A total of zero schedules no alarm, so the
    block then runs without a time limit.

    The previously installed ``SIGALRM`` handler is restored when the block
    exits, so the timeout does not leak into code that runs afterwards.

    Usage:
        with timeout(seconds=5):
            my_slow_function(...)

    Args:
        - seconds: The time limit, in seconds.
        - minutes: The time limit, in minutes.
        - hours: The time limit, in hours.

    Raises:
        TimeoutError: inside the managed block once the limit has elapsed.
    """

    limit = seconds + 60 * minutes + 3600 * hours

    def handler(signum, frame):
        raise TimeoutError('timed out after {} seconds'.format(limit))

    # signal.signal returns the handler that was active before ours; keep it
    # so it can be put back in the finally clause.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(limit)

        yield
    finally:
        # Cancel any pending alarm before swapping the handler back.
        signal.alarm(0)
        # A None value means the old handler was not installed from Python
        # and cannot be re-installed through signal.signal.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
53+
54+
55+
@contextmanager
def timeout_and_delete_endpoint(endpoint_name, sagemaker_session,
                                seconds=0, minutes=0, hours=0):
    """Run a block under a time limit and delete the endpoint afterwards.

    The managed block runs inside ``timeout``. When it exits, normally or
    through an exception, ``sagemaker_session.delete_endpoint`` is called for
    ``endpoint_name``.

    Deletion is best effort: a ``ClientError`` raised while deleting is never
    propagated, so it cannot replace an exception raised by the block itself.
    ``ValidationException`` errors are ignored silently; any other
    ``ClientError`` is logged as a warning instead of being dropped.

    Args:
        - endpoint_name: Name of the endpoint to delete on exit.
        - sagemaker_session: Object providing ``delete_endpoint``.
        - seconds, minutes, hours: Time limit, forwarded to ``timeout``.

    Yields:
        A one-element list holding the value produced by ``timeout``.
    """
    with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
        try:
            yield [t]
        finally:
            try:
                sagemaker_session.delete_endpoint(endpoint_name)
                LOGGER.info("deleted endpoint {}".format(endpoint_name))
            except ClientError as ce:
                # Raising here would overwrite the exception from the block,
                # so the failure is only reported.
                # NOTE(review): ValidationException presumably means the
                # endpoint does not exist - confirm.
                if ce.response['Error']['Code'] != 'ValidationException':
                    LOGGER.warning("failed to delete endpoint %s: %s", endpoint_name, ce)
69+
70+
71+
@contextmanager
def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, seconds=0, minutes=0, hours=0):
    """Run a block under a time limit and delete the named endpoint afterwards.

    This is an alias of ``timeout_and_delete_endpoint`` and delegates to it,
    so the timeout and cleanup logic lives in one place only.

    Args:
        - endpoint_name: Name of the endpoint to delete on exit.
        - sagemaker_session: Object providing ``delete_endpoint``.
        - seconds, minutes, hours: Time limit; the units are added together.

    Yields:
        The value yielded by ``timeout_and_delete_endpoint``.
    """
    with timeout_and_delete_endpoint(
        endpoint_name, sagemaker_session, seconds=seconds, minutes=minutes, hours=hours
    ) as yielded:
        yield yielded
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import io
2+
3+
import grpc
4+
import gzip
5+
import numpy as np
6+
import tensorflow as tf
7+
8+
from google.protobuf.json_format import MessageToJson
9+
from PIL import Image
10+
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
11+
12+
prediction_services = {}
13+
compression_algo = gzip
14+
15+
16+
def handler(data, context):
    """Turn an uploaded image into a gRPC Predict call and return JSON.

    Args:
        data: Object whose ``read()`` returns the encoded image bytes.
        context: Object providing ``model_name`` and ``grpc_port``.

    Returns:
        A ``(body, content_type)`` tuple: the Predict response rendered with
        ``MessageToJson`` and the string ``"application/json"``.
    """
    raw_bytes = data.read()
    picture = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
    batch_size = 1
    pixels = np.asarray(picture.resize((224, 224)))
    # Add a leading batch axis and repeat the image batch_size times.
    batch = np.concatenate([pixels[np.newaxis, :, :]] * batch_size)

    request = predict_pb2.PredictRequest()
    model_spec = request.model_spec
    model_spec.name = context.model_name
    model_spec.signature_name = 'serving_default'
    tensor = tf.compat.v1.make_tensor_proto(batch, shape=batch.shape, dtype=tf.float32)
    request.inputs['images'].CopyFrom(tensor)

    # Call Predict gRPC service
    # NOTE(review): 60.0 is presumably the call timeout in seconds - confirm.
    stub = get_prediction_service(context)
    result = stub.Predict(request, 60.0)
    print("Returning the response for grpc port: {}".format(context.grpc_port))

    # Return response
    return MessageToJson(result), "application/json"
37+
38+
39+
def get_prediction_service(context):
    """Return the PredictionService stub for ``context.grpc_port``.

    Stubs are kept in the module-level ``prediction_services`` dict, keyed by
    port. On the first request for a port, an insecure channel to
    ``localhost:<port>`` is opened and a stub is created for it.
    """
    port = context.grpc_port
    stub = prediction_services.get(port)
    if stub is None:
        channel = grpc.insecure_channel("localhost:{}".format(port))
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        prediction_services[port] = stub
    return stub
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
numpy
2+
Pillow
3+
tensorflow==1.15.5

tox.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ deps =
6060
pytest
6161
pytest-xdist
6262
boto3
63+
packaging
6364
requests
65+
sagemaker
66+
urllib3
6467

6568
[testenv:flake8]
6669
deps =

0 commit comments

Comments
 (0)