Skip to content
This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit a980f35

Browse files
committed
feature: add model_version_policy to model config
1 parent 5a9690d commit a980f35

File tree

7 files changed

+33
-5
lines changed

7 files changed

+33
-5
lines changed

docker/build_artifacts/sagemaker/serve.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,17 @@ def _create_tfs_config(self):
9999
config = "model_config_list: {\n"
100100
for m in models:
101101
config += " config: {\n"
102-
config += " name: '{}',\n".format(os.path.basename(m))
103-
config += " base_path: '{}',\n".format(m)
102+
config += " name: '{}'\n".format(os.path.basename(m))
103+
config += " base_path: '{}'\n".format(m)
104104
config += " model_platform: 'tensorflow'\n"
105+
106+
config += " model_version_policy: {\n"
107+
config += " specific: {\n"
108+
for version in tfs_utils.find_model_versions(m):
109+
config += " versions: {}\n".format(version)
110+
config += " }\n"
111+
config += " }\n"
112+
105113
config += " }\n"
106114
config += "}\n"
107115

docker/build_artifacts/sagemaker/tfs_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def find_models():
142142
return models
143143

144144

145+
def find_model_versions(model_path):
146+
return [version.lstrip("0") for version in os.listdir(model_path)]
147+
148+
145149
def _find_saved_model_files(path):
146150
for e in os.scandir(path):
147151
if e.is_dir():

test/integration/local/test_container.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ def container(request, docker_base_name, tag, runtime_config):
7171
subprocess.check_call("docker rm -f sagemaker-tensorflow-serving-test".split())
7272

7373

74-
def make_request(data, content_type="application/json", method="predict"):
74+
def make_request(data, content_type="application/json", method="predict", version=None):
75+
custom_attributes = "tfs-model-name=half_plus_three,tfs-method={}".format(method)
76+
if version:
77+
custom_attributes += ",tfs-model-version={}".format(version)
78+
7579
headers = {
7680
"Content-Type": content_type,
77-
"X-Amzn-SageMaker-Custom-Attributes":
78-
"tfs-model-name=half_plus_three,tfs-method=%s" % method
81+
"X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
7982
}
8083
response = requests.post(BASE_URL, data=data, headers=headers)
8184
return json.loads(response.content.decode("utf-8"))
@@ -101,6 +104,18 @@ def test_predict_twice():
101104
assert z == {"predictions": [3.5, 4.0, 5.5]}
102105

103106

107+
def test_predict_specific_versions():
108+
x = {
109+
"instances": [1.0, 2.0, 5.0]
110+
}
111+
112+
y = make_request(json.dumps(x), version=123)
113+
assert y == {"predictions": [3.5, 4.0, 5.5]}
114+
115+
y = make_request(json.dumps(x), version=124)
116+
assert y == {"predictions": [3.5, 4.0, 5.5]}
117+
118+
104119
def test_predict_two_instances():
105120
x = {
106121
"instances": [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
asset-file-contents
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)