
Commit 8860ba2

Feature: Support multiple inference.py files and universal inference.py file along with universal requirements.txt file
1 parent 1a265db commit 8860ba2
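For context, the per-model scripts this commit starts loading follow the SageMaker TensorFlow Serving pre/post-processing contract (a combined handler(data, context), or an input_handler/output_handler pair). A minimal sketch of such a script is shown below; the half_plus_two location and the handler bodies are illustrative assumptions, not part of this commit.

# Hypothetical /opt/ml/models/half_plus_two/model/code/inference.py
import json


def input_handler(data, context):
    # Pre-process the request body into a TF Serving REST payload.
    if context.request_content_type == "application/json":
        return data.read().decode("utf-8")
    if context.request_content_type == "text/csv":
        # e.g. "1.0,2.0,5.0" -> {"instances": [1.0, 2.0, 5.0]}
        values = [float(v) for v in data.read().decode("utf-8").split(",")]
        return json.dumps({"instances": values})
    raise ValueError("unsupported content type {}".format(context.request_content_type))


def output_handler(response, context):
    # Post-process the TF Serving response into (body, content_type).
    if response.status_code != 200:
        raise ValueError(response.content.decode("utf-8"))
    return response.content, "application/json"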

File tree: 6 files changed (+104, −26 lines)


docker/build_artifacts/sagemaker/python_service.py

Lines changed: 22 additions & 3 deletions
@@ -17,6 +17,7 @@
 import os
 import subprocess
 import grpc
+import sys
 
 import falcon
 import requests
@@ -143,6 +144,7 @@ def _handle_load_model_post(self, res, data):  # noqa: C901
         # validate model files are in the specified base_path
         if self.validate_model_dir(base_path):
             try:
+                self._import_custom_modules(model_name)
                 tfs_config = tfs_utils.create_tfs_config_individual_model(model_name, base_path)
                 tfs_config_file = "/sagemaker/tfs-config/{}/model-config.cfg".format(model_name)
                 log.info("tensorflow serving model config: \n%s\n", tfs_config)
@@ -221,6 +223,21 @@ def _handle_load_model_post(self, res, data):  # noqa: C901
             }
         )
 
+    def _import_custom_modules(self, model_name):
+        inference_script_path = "/opt/ml/models/{}/model/code/inference.py".format(model_name)
+        python_lib_path = "/opt/ml/models/{}/model/code/lib".format(model_name)
+
+        if os.path.exists(python_lib_path):
+            log.info("add Python code library path")
+            sys.path.append(python_lib_path)
+
+        if os.path.exists(inference_script_path):
+            handler, input_handler, output_handler = self._import_handlers(inference_script_path)
+            model_handlers = self._make_handler(handler, input_handler, output_handler)
+            self.model_handlers[model_name] = model_handlers
+        else:
+            self.model_handlers[model_name] = default_handler
+
     def _cleanup_config_file(self, config_file):
         if os.path.exists(config_file):
             os.remove(config_file)
@@ -264,8 +281,11 @@ def _handle_invocation_post(self, req, res, model_name=None):
 
         try:
            res.status = falcon.HTTP_200
+            handlers = self._handlers
+            if SAGEMAKER_MULTI_MODEL_ENABLED and self.model_handlers:
+                handlers = self.model_handlers[model_name]
+            res.body, res.content_type = handlers(data, context)
 
-            res.body, res.content_type = self._handlers(data, context)
         except Exception as e:  # pylint: disable=broad-except
             log.exception("exception handling request: {}".format(e))
             res.status = falcon.HTTP_500
@@ -276,8 +296,7 @@ def _setup_channel(self, grpc_port):
         log.info("Creating grpc channel for port: %s", grpc_port)
         self._channels[grpc_port] = grpc.insecure_channel("localhost:{}".format(grpc_port))
 
-    def _import_handlers(self):
-        inference_script = INFERENCE_SCRIPT_PATH
+    def _import_handlers(self, inference_script=INFERENCE_SCRIPT_PATH):
         spec = importlib.util.spec_from_file_location("inference", inference_script)
         inference = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(inference)
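The new _import_custom_modules method hands the imported callables to the pre-existing _make_handler. As an illustrative sketch only (the exact body is an assumption and is not part of this commit's diff), that composition presumably works along these lines:

# Sketch of how _make_handler is expected to combine the imported callables.
import requests


def make_handler(custom_handler, custom_input_handler, custom_output_handler):
    # If inference.py defines a combined handler(data, context), use it directly.
    if custom_handler:
        return custom_handler

    # Otherwise wrap input_handler/output_handler around the TF Serving REST call.
    def handler(data, context):
        processed_input = custom_input_handler(data, context)
        response = requests.post(context.rest_uri, data=processed_input)
        return custom_output_handler(response, context)

    return handler

At invocation time, the per-model callable stored in self.model_handlers[model_name] is then used instead of the single shared self._handlers when multi-model mode is enabled.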

docker/build_artifacts/sagemaker/serve.py

Lines changed: 2 additions & 1 deletion
@@ -134,7 +134,8 @@ def __init__(self):
         os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports
 
     def _need_python_service(self):
-        if os.path.exists(INFERENCE_PATH):
+        if (os.path.exists(INFERENCE_PATH) or os.path.exists(REQUIREMENTS_PATH)
+                or os.path.exists(PYTHON_LIB_PATH)):
             self._enable_python_service = True
         if os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET") and os.environ.get(
             "SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"

test/integration/local/test_pre_post_processing_mme.py

Lines changed: 80 additions & 22 deletions
@@ -24,10 +24,10 @@
 
 from multi_model_endpoint_test_utils import make_load_model_request, make_headers
 
-
 PING_URL = "http://localhost:8080/ping"
 INVOCATION_URL = "http://localhost:8080/models/{}/invoke"
-MODEL_NAME = "half_plus_three"
+MODEL_NAME_1 = "half_plus_three"
+MODEL_NAME_2 = "half_plus_two"
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -74,13 +74,22 @@ def container(docker_base_name, tag, runtime_config):
 
 
 @pytest.fixture
-def model():
+def model1():
     model_data = {
-        "model_name": MODEL_NAME,
+        "model_name": MODEL_NAME_1,
         "url": "/opt/ml/models/half_plus_three/model/half_plus_three"
     }
     make_load_model_request(json.dumps(model_data))
-    return MODEL_NAME
+    return MODEL_NAME_1
+
+@pytest.fixture
+def model2():
+    model_data = {
+        "model_name": MODEL_NAME_2,
+        "url": "/opt/ml/models/half_plus_two/model/half_plus_two"
+    }
+    make_load_model_request(json.dumps(model_data))
+    return MODEL_NAME_2
 
 
 @pytest.mark.skip_gpu
@@ -90,20 +99,37 @@ def test_ping_service():
 
 
 @pytest.mark.skip_gpu
-def test_predict_json(model):
+def test_predict_json(model1, model2):
     headers = make_headers()
     data = "{\"instances\": [1.0, 2.0, 5.0]}"
-    response = requests.post(INVOCATION_URL.format(model), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
+    response1 = requests.post(INVOCATION_URL.format(model1), data=data, headers=headers).json()
+    print("Response 1:")
+    print(response1)
+    assert response1 == {"predictions": [3.5, 4.0, 5.5]}
+    response2 = requests.post(INVOCATION_URL.format(model2), data=data, headers=headers).json()
+    print("Response 2:")
+    print(response2)
+    assert response2 == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_zero_content():
     headers = make_headers()
     x = ""
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers)
-    assert 500 == response.status_code
-    assert "document is empty" in response.text
+    response1 = requests.post(INVOCATION_URL.format(MODEL_NAME_1), data=x, headers=headers)
+    print("Response 1 status code:")
+    print(response1.status_code)
+    print("Response 1 text:")
+    print(response1.text)
+    assert 500 == response1.status_code
+    assert "document is empty" in response1.text
+    response2 = requests.post(INVOCATION_URL.format(MODEL_NAME_2), data=x, headers=headers)
+    print("Response 2 status code:")
+    print(response2.status_code)
+    print("Response 2 text:")
+    print(response2.text)
+    assert 500 == response2.status_code
+    assert "document is empty" in response2.text
 
 
 @pytest.mark.skip_gpu
@@ -113,34 +139,66 @@ def test_large_input():
     with open(data_file, "r") as file:
         x = file.read()
     headers = make_headers(content_type="text/csv")
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers).json()
-    predictions = response["predictions"]
-    assert len(predictions) == 753936
+    response1 = requests.post(INVOCATION_URL.format(MODEL_NAME_1), data=x, headers=headers).json()
+    predictions1 = response1["predictions"]
+    print("Response 1:")
+    print(response1)
+    assert len(predictions1) == 753936
+    response2 = requests.post(INVOCATION_URL.format(MODEL_NAME_2), data=x, headers=headers).json()
+    print("Response 2:")
+    print(response2)
+    predictions2 = response2["predictions"]
+    assert len(predictions2) == 753936
 
 
 @pytest.mark.skip_gpu
 def test_csv_input():
     headers = make_headers(content_type="text/csv")
     data = "1.0,2.0,5.0"
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
+    response1 = requests.post(INVOCATION_URL.format(MODEL_NAME_1), data=data, headers=headers).json()
+    print("Response 1:")
+    print(response1)
+    assert response1 == {"predictions": [3.5, 4.0, 5.5]}
+    response2 = requests.post(INVOCATION_URL.format(MODEL_NAME_2), data=data, headers=headers).json()
+    print("Response 2:")
+    print(response2)
+    assert response2 == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_specific_versions():
     for version in ("123", "124"):
         headers = make_headers(content_type="text/csv", version=version)
         data = "1.0,2.0,5.0"
-        response = requests.post(
-            INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers
+        response1 = requests.post(
+            INVOCATION_URL.format(MODEL_NAME_1), data=data, headers=headers
+        ).json()
+        print("Response 1")
+        print(response1)
+        assert response1 == {"predictions": [3.5, 4.0, 5.5]}
+        response2 = requests.post(
+            INVOCATION_URL.format(MODEL_NAME_2), data=data, headers=headers
         ).json()
-        assert response == {"predictions": [3.5, 4.0, 5.5]}
+        print("Response 2:")
+        print(response2)
+        assert response2 == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_unsupported_content_type():
     headers = make_headers("unsupported-type", "predict")
     data = "aW1hZ2UgYnl0ZXM="
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers)
-    assert 500 == response.status_code
-    assert "unsupported content type" in response.text
+    response1 = requests.post(INVOCATION_URL.format(MODEL_NAME_1), data=data, headers=headers)
+    print("Response 1 status code:")
+    print(response1.status_code)
+    print("Response 1 text:")
+    print(response1.text)
+    assert 500 == response1.status_code
+    assert "unsupported content type" in response1.text
+    response2 = requests.post(INVOCATION_URL.format(MODEL_NAME_2), data=data, headers=headers)
+    print("Response 2 status code:")
+    print(response2.status_code)
+    print("Response 2 text:")
+    print(response2.text)
+    assert 500 == response2.status_code
+    assert "unsupported content type" in response2.text
Binary file not shown.
Binary file not shown.
