Skip to content
This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit 8897c75

Browse files
authored
Create test_multi_tfs.py (#205)
* Create test_multi_tfs.py * Update test_multi_tfs.py * modify some numbers
1 parent 34b630d commit 8897c75

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
import json
15+
import os
16+
import subprocess
17+
import sys
18+
import time
19+
20+
import pytest
21+
import requests
22+
23+
BASE_URL = "http://localhost:8080/invocations"
24+
25+
26+
@pytest.fixture(scope="session", autouse=True)
27+
def volume():
28+
try:
29+
model_dir = os.path.abspath("test/resources/models")
30+
subprocess.check_call(
31+
"docker volume create --name multi_tfs_model_volume --opt type=none "
32+
"--opt device={} --opt o=bind".format(model_dir).split())
33+
yield model_dir
34+
finally:
35+
subprocess.check_call("docker volume rm multi_tfs_model_volume".split())
36+
37+
38+
@pytest.fixture(scope="module", autouse=True, params=[True, False])
39+
def container(request, docker_base_name, tag, runtime_config):
40+
try:
41+
if request.param:
42+
batching_config = " -e SAGEMAKER_TFS_ENABLE_BATCHING=true"
43+
else:
44+
batching_config = ""
45+
command = (
46+
"docker run {}--name sagemaker-tensorflow-serving-test -p 8080:8080"
47+
" --mount type=volume,source=multi_tfs_model_volume,target=/opt/ml/model,readonly"
48+
" -e SAGEMAKER_TFS_NGINX_LOGLEVEL=info"
49+
" -e SAGEMAKER_BIND_TO_PORT=8080"
50+
" -e SAGEMAKER_SAFE_PORT_RANGE=9000-9999"
51+
" -e SAGEMAKER_TFS_INSTANCE_COUNT=2"
52+
" -e SAGEMAKER_GUNICORN_WORKERS=4"
53+
" -e SAGEMAKER_TFS_INTER_OP_PARALLELISM=1"
54+
" -e SAGEMAKER_TFS_INTRA_OP_PARALLELISM=1"
55+
" {}"
56+
" {}:{} serve"
57+
).format(runtime_config, batching_config, docker_base_name, tag)
58+
59+
proc = subprocess.Popen(command.split(), stdout=sys.stdout, stderr=subprocess.STDOUT)
60+
61+
attempts = 0
62+
63+
while attempts < 40:
64+
time.sleep(3)
65+
try:
66+
res_code = requests.get("http://localhost:8080/ping").status_code
67+
if res_code == 200:
68+
break
69+
except:
70+
attempts += 1
71+
pass
72+
73+
yield proc.pid
74+
finally:
75+
subprocess.check_call("docker rm -f sagemaker-tensorflow-serving-test".split())
76+
77+
78+
def make_request(data, content_type="application/json", method="predict", version=None):
79+
custom_attributes = "tfs-model-name=half_plus_three,tfs-method={}".format(method)
80+
if version:
81+
custom_attributes += ",tfs-model-version={}".format(version)
82+
83+
headers = {
84+
"Content-Type": content_type,
85+
"X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
86+
}
87+
response = requests.post(BASE_URL, data=data, headers=headers)
88+
return json.loads(response.content.decode("utf-8"))
89+
90+
91+
def test_predict():
92+
x = {
93+
"instances": [1.0, 2.0, 5.0]
94+
}
95+
96+
y = make_request(json.dumps(x))
97+
assert y == {"predictions": [3.5, 4.0, 5.5]}

0 commit comments

Comments
 (0)