Skip to content

Commit f900670

Browse files
committed
Add local MME integration test
1 parent edc0cc3 commit f900670

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
import subprocess
18+
import sys
19+
import time
20+
21+
import pytest
22+
import requests
23+
24+
from integration import mme_path
25+
26+
BASE_URL = "http://0.0.0.0:8080/"
27+
PING_URL = BASE_URL + "ping"
28+
INVOCATION_URL = BASE_URL + "models/{}/invoke"
29+
MODELS_URL = BASE_URL + "models"
30+
DELETE_MODEL_URL = BASE_URL + "models/{}"
31+
32+
33+
@pytest.fixture(scope="module", autouse=True)
34+
def container(image_uri, use_gpu):
35+
try:
36+
gpu_option = "--gpus device=0" if use_gpu else ""
37+
resnet18_path = os.path.join(mme_path, 'resnet18')
38+
traced_resnet18_path = os.path.join(mme_path, 'traced_resnet18')
39+
40+
command = (
41+
"docker run -it --rm {} "
42+
"--name sagemaker-pytorch-inference-toolkit-mme-test "
43+
"-p 8080:8080 "
44+
"-v {}:/resnet18 "
45+
"-v {}:/traced_resnet18 "
46+
"-e SAGEMAKER_MULTI_MODEL=true {} serve"
47+
).format(gpu_option, resnet18_path, traced_resnet18_path, image_uri)
48+
49+
proc = subprocess.Popen(command.split(), stdout=sys.stdout, stderr=subprocess.STDOUT)
50+
51+
attempts = 0
52+
while attempts < 10:
53+
time.sleep(3)
54+
try:
55+
requests.get(PING_URL)
56+
break
57+
except Exception:
58+
attempts += 1
59+
pass
60+
yield proc.pid
61+
62+
finally:
63+
subprocess.check_call("docker rm -f sagemaker-pytorch-inference-toolkit-mme-test".split())
64+
65+
66+
def make_list_model_request():
67+
response = requests.get(MODELS_URL)
68+
return response.status_code, json.loads(response.content.decode("utf-8"))
69+
70+
71+
def make_load_model_request(data, content_type="application/json"):
72+
headers = {"Content-Type": content_type}
73+
response = requests.post(MODELS_URL, data=data, headers=headers)
74+
return response.status_code, json.loads(response.content.decode("utf-8"))
75+
76+
77+
def make_unload_model_request(model_name):
78+
response = requests.delete(DELETE_MODEL_URL.format(model_name))
79+
return response.status_code, json.loads(response.content.decode("utf-8"))
80+
81+
82+
def make_invocation_request(model_name, data, content_type="application/octet-stream"):
83+
headers = {"Content-Type": content_type}
84+
response = requests.post(INVOCATION_URL.format(model_name), data=data, headers=headers)
85+
return response.status_code, json.loads(response.content.decode("utf-8"))
86+
87+
88+
def test_ping():
89+
res = requests.get(PING_URL)
90+
assert res.status_code == 200
91+
92+
93+
def test_list_models_empty():
94+
code, models = make_list_model_request()
95+
assert code == 200
96+
assert models["models"] == []
97+
98+
99+
def test_load_models():
100+
data1 = {"model_name": "resnet18", "url": "/resnet18"}
101+
code1, content1 = make_load_model_request(data=json.dumps(data1))
102+
assert code1 == 200
103+
assert content1["status"] == 'Model "resnet18" Version: 1.0 registered with 1 initial workers'
104+
105+
code2, content2 = make_list_model_request()
106+
assert code2 == 200
107+
assert content2["models"] == [{"modelName": "resnet18", "modelUrl": "/resnet18"}]
108+
109+
data2 = {"model_name": "traced_resnet18", "url": "/traced_resnet18"}
110+
code3, content3 = make_load_model_request(data=json.dumps(data2))
111+
assert code3 == 200
112+
assert content3["status"] == 'Model "traced_resnet18" Version: 1.0 registered with 1 initial workers'
113+
114+
code4, content4 = make_list_model_request()
115+
assert code4 == 200
116+
assert content4["models"] == [
117+
{"modelName": "resnet18", "modelUrl": "/resnet18"},
118+
{"modelName": "traced_resnet18", "modelUrl": "/traced_resnet18"},
119+
]
120+
121+
122+
def test_unload_models():
123+
code1, content1 = make_unload_model_request("resnet18")
124+
assert code1 == 200
125+
assert content1["status"] == 'Model "resnet18" unregistered'
126+
127+
code2, content2 = make_list_model_request()
128+
assert code2 == 200
129+
assert content2["models"] == [{"modelName": "traced_resnet18", "modelUrl": "/traced_resnet18"}]
130+
131+
132+
def test_load_non_existing_model():
133+
data = {"model_name": "banana", "url": "/banana"}
134+
code, content = make_load_model_request(data=json.dumps(data))
135+
assert code == 404
136+
137+
138+
def test_unload_non_existing_model():
139+
# resnet18 is already unloaded
140+
code, content = make_unload_model_request("resnet18")
141+
assert code == 404
142+
143+
144+
def test_load_model_multiple_times():
145+
# traced_resnet18 is already loaded
146+
data = {"model_name": "traced_resnet18", "url": "traced_resnet18"}
147+
code, content = make_load_model_request(data=json.dumps(data))
148+
assert code == 409
149+
150+
151+
def test_invocation():
152+
data = {"model_name": "resnet18", "url": "/resnet18"}
153+
code, content = make_load_model_request(data=json.dumps(data))
154+
155+
image_url = (
156+
"https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/master/"
157+
"sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg"
158+
)
159+
img_data = requests.get(image_url).content
160+
with open("cat.jpg", "wb") as file_obj:
161+
file_obj.write(img_data)
162+
with open("cat.jpg", "rb") as f:
163+
payload = f.read()
164+
payload = bytearray(payload)
165+
166+
code1, predictions1 = make_invocation_request("resnet18", payload)
167+
assert code1 == 200
168+
assert len(predictions1) == 1000
169+
170+
code2, predictions2 = make_invocation_request("traced_resnet18", payload)
171+
assert code2 == 200
172+
assert len(predictions2) == 1000

0 commit comments

Comments
 (0)