This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit b97ce67

feature: add model_version_policy to model config (#155)
1 parent: 5a9690d

File tree: 17 files changed (+85, -62 lines)


docker/build_artifacts/sagemaker/serve.py

Lines changed: 11 additions & 7 deletions
@@ -99,15 +99,23 @@ def _create_tfs_config(self):
         config = "model_config_list: {\n"
         for m in models:
             config += "  config: {\n"
-            config += "    name: '{}',\n".format(os.path.basename(m))
-            config += "    base_path: '{}',\n".format(m)
+            config += "    name: '{}'\n".format(os.path.basename(m))
+            config += "    base_path: '{}'\n".format(m)
             config += "    model_platform: 'tensorflow'\n"
+
+            config += "    model_version_policy: {\n"
+            config += "      specific: {\n"
+            for version in tfs_utils.find_model_versions(m):
+                config += "        versions: {}\n".format(version)
+            config += "      }\n"
+            config += "    }\n"
+
             config += "  }\n"
         config += "}\n"
 
         log.info("tensorflow serving model config: \n%s\n", config)
 
-        with open("/sagemaker/model-config.cfg", "w") as f:
+        with open(self._tfs_config_path, "w") as f:
             f.write(config)
 
     def _setup_gunicorn(self):

@@ -259,10 +267,6 @@ def start(self):
         if self._tfs_enable_multi_model_endpoint:
             log.info("multi-model endpoint is enabled, TFS model servers will be started later")
         else:
-            tfs_utils.create_tfs_config(
-                self._tfs_default_model_name,
-                self._tfs_config_path
-            )
             self._create_tfs_config()
             self._start_tfs()
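
For reference, a minimal sketch (not part of the commit) of the TFS model config text that the updated _create_tfs_config emits. The model name, base path, and hardcoded version list are illustrative stand-ins for a real model directory scanned by tfs_utils.find_model_versions:

import os

def sketch_tfs_config(model_path, versions):
    # Mirrors the string-building loop above for a single model entry.
    config = "model_config_list: {\n"
    config += "  config: {\n"
    config += "    name: '{}'\n".format(os.path.basename(model_path))
    config += "    base_path: '{}'\n".format(model_path)
    config += "    model_platform: 'tensorflow'\n"
    config += "    model_version_policy: {\n"
    config += "      specific: {\n"
    for version in versions:
        config += "        versions: {}\n".format(version)
    config += "      }\n"
    config += "    }\n"
    config += "  }\n"
    config += "}\n"
    return config

print(sketch_tfs_config("/opt/ml/model/half_plus_three", ["123", "124"]))
# model_config_list: {
#   config: {
#     name: 'half_plus_three'
#     base_path: '/opt/ml/model/half_plus_three'
#     model_platform: 'tensorflow'
#     model_version_policy: {
#       specific: {
#         versions: 123
#         versions: 124
#       }
#     }
#   }
# }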

docker/build_artifacts/sagemaker/tfs_utils.py

Lines changed: 14 additions & 32 deletions
@@ -77,42 +77,20 @@ def parse_tfs_custom_attributes(req):
 def create_tfs_config_individual_model(model_name, base_path):
     config = "model_config_list: {\n"
     config += "  config: {\n"
-    config += "    name: '{}',\n".format(model_name)
-    config += "    base_path: '{}',\n".format(base_path)
+    config += "    name: '{}'\n".format(model_name)
+    config += "    base_path: '{}'\n".format(base_path)
     config += "    model_platform: 'tensorflow'\n"
-    config += "  }\n"
-    config += "}\n"
-    return config
-
 
-def create_tfs_config(
-    tfs_default_model_name,
-    tfs_config_path,
-):
-    models = find_models()
-    if not models:
-        raise ValueError("no SavedModel bundles found!")
+    config += "    model_version_policy: {\n"
+    config += "      specific: {\n"
+    for version in find_model_versions(base_path):
+        config += "        versions: {}\n".format(version)
+    config += "      }\n"
+    config += "    }\n"
 
-    if tfs_default_model_name == "None":
-        default_model = os.path.basename(models[0])
-        if default_model:
-            tfs_default_model_name = default_model
-            log.info("using default model name: {}".format(tfs_default_model_name))
-        else:
-            log.info("no default model detected")
-
-    # config (may) include duplicate 'config' keys, so we can't just dump a dict
-    config = "model_config_list: {\n"
-    for m in models:
-        config += "  config: {\n"
-        config += "    name: '{}',\n".format(os.path.basename(m))
-        config += "    base_path: '{}',\n".format(m)
-        config += "    model_platform: 'tensorflow'\n"
-        config += "  }\n"
+    config += "  }\n"
     config += "}\n"
-
-    with open(tfs_config_path, 'w') as f:
-        f.write(config)
+    return config
 
 
 def tfs_command(tfs_grpc_port,

@@ -142,6 +120,10 @@ def find_models():
     return models
 
 
+def find_model_versions(model_path):
+    return [version.lstrip("0") for version in os.listdir(model_path)]
+
+
 def _find_saved_model_files(path):
     for e in os.scandir(path):
         if e.is_dir():
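
A small usage sketch (not part of the commit) of the new find_model_versions helper, using a temporary directory in place of a real SavedModel tree; the version directory names are illustrative:

import os
import tempfile

def find_model_versions(model_path):
    # Each subdirectory of a model's base_path is treated as a version number;
    # leading zeros are stripped so the value reads as a plain integer in the config.
    return [version.lstrip("0") for version in os.listdir(model_path)]

base_path = tempfile.mkdtemp()
for version_dir in ("123", "0124"):
    os.mkdir(os.path.join(base_path, version_dir))

print(sorted(find_model_versions(base_path)))  # ['123', '124']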

test/integration/local/multi_model_endpoint_test_utils.py

Lines changed: 9 additions & 10 deletions
@@ -11,27 +11,26 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 
-import json
 import requests
 
 INVOCATION_URL = "http://localhost:8080/models/{}/invoke"
 MODELS_URL = "http://localhost:8080/models"
 DELETE_MODEL_URL = "http://localhost:8080/models/{}"
 
 
-def make_headers(content_type="application/json", method="predict"):
-    headers = {
+def make_headers(content_type="application/json", method="predict", version=None):
+    custom_attributes = "tfs-method={}".format(method)
+    if version:
+        custom_attributes += ",tfs-model-version={}".format(version)
+
+    return {
         "Content-Type": content_type,
-        "X-Amzn-SageMaker-Custom-Attributes": "tfs-method=%s" % method
+        "X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
     }
-    return headers
 
 
-def make_invocation_request(data, model_name, content_type="application/json"):
-    headers = {
-        "Content-Type": content_type,
-        "X-Amzn-SageMaker-Custom-Attributes": "tfs-method=predict"
-    }
+def make_invocation_request(data, model_name, content_type="application/json", version=None):
+    headers = make_headers(content_type=content_type, method="predict", version=version)
     response = requests.post(INVOCATION_URL.format(model_name), data=data, headers=headers)
     return response.status_code, response.content.decode("utf-8")
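
A quick sketch (not part of the commit) of the headers the updated make_headers produces when a version is passed; the version value is illustrative:

def make_headers(content_type="application/json", method="predict", version=None):
    custom_attributes = "tfs-method={}".format(method)
    if version:
        custom_attributes += ",tfs-model-version={}".format(version)

    return {
        "Content-Type": content_type,
        "X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
    }

print(make_headers(version="123"))
# {'Content-Type': 'application/json',
#  'X-Amzn-SageMaker-Custom-Attributes': 'tfs-method=predict,tfs-model-version=123'}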

test/integration/local/test_container.py

Lines changed: 18 additions & 3 deletions
@@ -71,11 +71,14 @@ def container(request, docker_base_name, tag, runtime_config):
     subprocess.check_call("docker rm -f sagemaker-tensorflow-serving-test".split())
 
 
-def make_request(data, content_type="application/json", method="predict"):
+def make_request(data, content_type="application/json", method="predict", version=None):
+    custom_attributes = "tfs-model-name=half_plus_three,tfs-method={}".format(method)
+    if version:
+        custom_attributes += ",tfs-model-version={}".format(version)
+
     headers = {
         "Content-Type": content_type,
-        "X-Amzn-SageMaker-Custom-Attributes":
-            "tfs-model-name=half_plus_three,tfs-method=%s" % method
+        "X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
     }
     response = requests.post(BASE_URL, data=data, headers=headers)
     return json.loads(response.content.decode("utf-8"))

@@ -101,6 +104,18 @@ def test_predict_twice():
     assert z == {"predictions": [3.5, 4.0, 5.5]}
 
 
+def test_predict_specific_versions():
+    x = {
+        "instances": [1.0, 2.0, 5.0]
+    }
+
+    y = make_request(json.dumps(x), version=123)
+    assert y == {"predictions": [3.5, 4.0, 5.5]}
+
+    y = make_request(json.dumps(x), version=124)
+    assert y == {"predictions": [3.5, 4.0, 5.5]}
+
+
 def test_predict_two_instances():
     x = {
         "instances": [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]]

test/integration/local/test_multi_model_endpoint.py

Lines changed: 5 additions & 4 deletions
@@ -165,10 +165,11 @@ def test_load_two_models():
     assert y1 == {"predictions": [2.5, 3.0, 4.5]}
 
     # make invocation request to the second model
-    code_invoke2, y2 = make_invocation_request(json.dumps(x), "half_plus_three")
-    y2 = json.loads(y2)
-    assert code_invoke2 == 200
-    assert y2 == {"predictions": [3.5, 4.0, 5.5]}
+    for ver in ("123", "124"):
+        code_invoke2, y2 = make_invocation_request(json.dumps(x), "half_plus_three", version=ver)
+        y2 = json.loads(y2)
+        assert code_invoke2 == 200
+        assert y2 == {"predictions": [3.5, 4.0, 5.5]}
 
     code_list, res3 = make_list_model_request()
     res3 = json.loads(res3)

test/integration/local/test_pre_post_processing.py

Lines changed: 15 additions & 4 deletions
@@ -77,12 +77,15 @@ def container(volume, docker_base_name, tag, runtime_config):
     subprocess.check_call("docker rm -f sagemaker-tensorflow-serving-test".split())
 
 
-def make_headers(content_type, method):
-    headers = {
+def make_headers(content_type, method, version=None):
+    custom_attributes = "tfs-model-name=half_plus_three,tfs-method={}".format(method)
+    if version:
+        custom_attributes += ",tfs-model-version={}".format(version)
+
+    return {
         "Content-Type": content_type,
-        "X-Amzn-SageMaker-Custom-Attributes": "tfs-model-name=half_plus_three,tfs-method=%s" % method
+        "X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
     }
-    return headers
 
 
 def test_predict_json():

@@ -118,6 +121,14 @@ def test_csv_input():
     assert response == {"predictions": [3.5, 4.0, 5.5]}
 
 
+def test_predict_specific_versions():
+    for version in ("123", "124"):
+        headers = make_headers("application/json", "predict", version=version)
+        data = "{\"instances\": [1.0, 2.0, 5.0]}"
+        response = requests.post(INVOCATIONS_URL, data=data, headers=headers).json()
+        assert response == {"predictions": [3.5, 4.0, 5.5]}
+
+
 def test_unsupported_content_type():
     headers = make_headers("unsupported-type", "predict")
     data = "aW1hZ2UgYnl0ZXM="

test/integration/local/test_pre_post_processing_mme.py

Lines changed: 11 additions & 0 deletions
@@ -135,6 +135,17 @@ def test_csv_input():
     assert response == {"predictions": [3.5, 4.0, 5.5]}
 
 
+@pytest.mark.skip_gpu
+def test_specific_versions():
+    for version in ("123", "124"):
+        headers = make_headers(content_type="text/csv", version=version)
+        data = "1.0,2.0,5.0"
+        response = requests.post(
+            INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers
+        ).json()
+        assert response == {"predictions": [3.5, 4.0, 5.5]}
+
+
 @pytest.mark.skip_gpu
 def test_unsupported_content_type():
     headers = make_headers("unsupported-type", "predict")

test/integration/local/test_tfs_batching.py

Lines changed: 0 additions & 2 deletions
@@ -13,8 +13,6 @@
 
 import os
 import subprocess
-import sys
-import time
 
 import pytest

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+asset-file-contents
Binary file not shown.
Binary file not shown.

test/resources/mme/half_plus_three/abcde/dummy.txt
Whitespace-only changes.

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+asset-file-contents
Binary file not shown.
Binary file not shown.
