This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit f3433a6

Author: Chuyang Deng
Commit message: update tfs pre-post-processing file path and test
1 parent 8ee6c5f commit f3433a6

File tree: 3 files changed, +60 −23 lines

docker/build_artifacts/sagemaker/python_service.py

Lines changed: 55 additions & 17 deletions
@@ -17,6 +17,7 @@
 import os
 import subprocess
 import time
+import sys
 
 import falcon
 import requests
@@ -27,10 +28,8 @@
 import tfs_utils
 
 SAGEMAKER_MULTI_MODEL_ENABLED = os.environ.get('SAGEMAKER_MULTI_MODEL', 'false').lower() == 'true'
-INFERENCE_SCRIPT_PATH = '/opt/ml/{}/code/inference.py'.format('models'
-                                                              if SAGEMAKER_MULTI_MODEL_ENABLED
-                                                              else 'model')
-PYTHON_PROCESSING_ENABLED = os.path.exists(INFERENCE_SCRIPT_PATH)
+INFERENCE_SCRIPT_PATH = '/opt/ml/model/code/inference.py'
+
 SAGEMAKER_BATCHING_ENABLED = os.environ.get('SAGEMAKER_TFS_ENABLE_BATCHING', 'false').lower()
 MODEL_CONFIG_FILE_PATH = '/sagemaker/model-config.cfg'
 TFS_GRPC_PORT = os.environ.get('TFS_GRPC_PORT')
@@ -64,21 +63,24 @@ def __init__(self):
             self._model_tfs_grpc_port = {}
             self._model_tfs_pid = {}
             self._tfs_ports = self._parse_sagemaker_port_range(SAGEMAKER_TFS_PORT_RANGE)
+            # If Multi-Model mode is enabled, dependencies/handlers will be imported
+            # during the _handle_load_model_post()
+            self.model_handlers = {}
         else:
             self._tfs_grpc_port = TFS_GRPC_PORT
             self._tfs_rest_port = TFS_REST_PORT
 
+            if os.path.exists(INFERENCE_SCRIPT_PATH):
+                self._handler, self._input_handler, self._output_handler = self._import_handlers()
+                self._handlers = self._make_handler(self._handler,
+                                                    self._input_handler,
+                                                    self._output_handler)
+            else:
+                self._handlers = default_handler
+
         self._tfs_enable_batching = SAGEMAKER_BATCHING_ENABLED == 'true'
         self._tfs_default_model_name = os.environ.get('TFS_DEFAULT_MODEL_NAME', "None")
 
-        if PYTHON_PROCESSING_ENABLED:
-            self._handler, self._input_handler, self._output_handler = self._import_handlers()
-            self._handlers = self._make_handler(self._handler,
-                                                self._input_handler,
-                                                self._output_handler)
-        else:
-            self._handlers = default_handler
-
     def on_post(self, req, res, model_name=None):
         log.info(req.uri)
         if model_name or "invocations" in req.uri:
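
Note: the handler trio wired up in __init__ comes from a user-supplied inference.py. For reference, a minimal sketch of such a script is below; the input_handler/output_handler names and the (data, context) signatures follow the interface _import_handlers() looks for, while the JSON handling itself is purely illustrative.

    # inference.py -- illustrative sketch of a pre/post-processing script.
    # The function names and signatures follow the container's convention;
    # the payload handling is a made-up example, not code from this repo.
    import json


    def input_handler(data, context):
        """Convert the raw request body into a TFS REST 'instances' payload."""
        if context.request_content_type == 'application/json':
            payload = json.loads(data.read().decode('utf-8'))
            return json.dumps({'instances': payload})
        raise ValueError('unsupported content type: {}'.format(
            context.request_content_type))


    def output_handler(response, context):
        """Pass the TensorFlow Serving response through unchanged."""
        return response.content, context.accept_header
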
@@ -129,6 +131,9 @@ def _handle_load_model_post(self, res, data): # noqa: C901
         # validate model files are in the specified base_path
         if self.validate_model_dir(base_path):
             try:
+                # install custom dependencies, import handlers
+                self._import_custom_modules(model_name)
+
                 tfs_config = tfs_utils.create_tfs_config_individual_model(model_name, base_path)
                 tfs_config_file = '/sagemaker/tfs-config/{}/model-config.cfg'.format(model_name)
                 log.info('tensorflow serving model config: \n%s\n', tfs_config)
@@ -197,6 +202,33 @@ def _handle_load_model_post(self, res, data): # noqa: C901
                                  model_name)
             })
 
+    def _import_custom_modules(self, model_name):
+        inference_script_path = "/opt/ml/models/{}/model/code/inference.py".format(model_name)
+        requirements_file_path = "/opt/ml/models/{}/model/code/requirements.txt".format(model_name)
+        python_lib_path = "/opt/ml/models/{}/model/code/lib".format(model_name)
+
+        if os.path.exists(requirements_file_path):
+            log.info("pip install dependencies from requirements.txt")
+            pip_install_cmd = "pip3 install -r {}".format(requirements_file_path)
+            try:
+                subprocess.check_call(pip_install_cmd.split())
+            except subprocess.CalledProcessError:
+                log.error('failed to install required packages, exiting.')
+                raise ChildProcessError('failed to install required packages.')
+
+        if os.path.exists(python_lib_path):
+            log.info("add Python code library path")
+            sys.path.append(python_lib_path)
+
+        if os.path.exists(inference_script_path):
+            handler, input_handler, output_handler = self._import_handlers(model_name)
+            model_handlers = self._make_handler(handler,
+                                                input_handler,
+                                                output_handler)
+            self.model_handlers[model_name] = model_handlers
+        else:
+            self.model_handlers[model_name] = default_handler
+
     def _cleanup_config_file(self, config_file):
         if os.path.exists(config_file):
             os.remove(config_file)
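
Note: the three format strings above fix the per-model directory layout this commit standardizes on. For a hypothetical model named my_model, _import_custom_modules() probes the following (tree assembled from those paths; the numeric version-directory requirement comes from validate_model_dir below):

    /opt/ml/models/my_model/model/
    ├── 1/                    # numeric SavedModel version directory
    └── code/
        ├── inference.py      # optional custom handlers
        ├── requirements.txt  # optional pip dependencies, installed at load time
        └── lib/              # optional extra modules, appended to sys.path
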
@@ -249,16 +281,24 @@ def _handle_invocation_post(self, req, res, model_name=None):
 
         try:
             res.status = falcon.HTTP_200
-            res.body, res.content_type = self._handlers(data, context)
+            if SAGEMAKER_MULTI_MODEL_ENABLED:
+                with lock():
+                    handlers = self.model_handlers[model_name]
+                    res.body, res.content_type = handlers(data, context)
+            else:
+                res.body, res.content_type = self._handlers(data, context)
         except Exception as e:  # pylint: disable=broad-except
             log.exception('exception handling request: {}'.format(e))
             res.status = falcon.HTTP_500
             res.body = json.dumps({
                 'error': str(e)
             }).encode('utf-8')  # pylint: disable=E1101
 
-    def _import_handlers(self):
-        spec = importlib.util.spec_from_file_location('inference', INFERENCE_SCRIPT_PATH)
+    def _import_handlers(self, model_name=None):
+        inference_script = INFERENCE_SCRIPT_PATH
+        if model_name:
+            inference_script = "/opt/ml/models/{}/model/code/inference.py".format(model_name)
+        spec = importlib.util.spec_from_file_location('inference', inference_script)
         inference = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(inference)
 
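Note: the three importlib calls retained above are the standard recipe for loading a module from an arbitrary file path. A self-contained sketch of the same mechanism, with an example path standing in for the container's:

    # Minimal sketch of loading a Python module from a file path, as
    # _import_handlers() does. The path below is illustrative only.
    import importlib.util


    def load_module(name, path):
        # Build a spec from the file location, materialize the module,
        # then execute its top-level code so its attributes become usable.
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module


    inference = load_module('inference', '/tmp/example/inference.py')
    handler = getattr(inference, 'handler', None)  # optional attribute lookup
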
@@ -358,7 +398,6 @@ def validate_model_dir(self, model_path):
         versions = []
         for _, dirs, _ in os.walk(model_path):
             for dirname in dirs:
-                log.info("dirname: {}".format(dirname))
                 if dirname.isdigit():
                     versions.append(dirname)
         return self.validate_model_versions(versions)
@@ -383,7 +422,6 @@ def on_get(self, req, res): # pylint: disable=W0613
 
 class ServiceResources:
     def __init__(self):
-        self._enable_python_processing = PYTHON_PROCESSING_ENABLED
         self._enable_model_manager = SAGEMAKER_MULTI_MODEL_ENABLED
         self._python_service_resource = PythonServiceResource()
         self._ping_resource = PingResource()
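
Note: taken together, the changes in this file replace the single module-level handler with a per-model registry guarded by a lock. A condensed sketch of the pattern, using threading.Lock as a stand-in for the container's lock() helper (which this diff does not show) and a placeholder default_handler:

    # Sketch of the per-model handler registry introduced above.
    import threading

    _lock = threading.Lock()
    model_handlers = {}


    def default_handler(data, context):
        # Placeholder pass-through, not the repo's implementation.
        return data, 'application/json'


    def register(model_name, handler=None):
        # Models that ship no inference.py fall back to the default.
        model_handlers[model_name] = handler or default_handler


    def dispatch(model_name, data, context):
        with _lock:
            handler = model_handlers[model_name]
        return handler(data, context)
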

test/integration/local/test_multi_model_endpoint.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@
 import pytest
 import requests
 
-from multi_model_endpoint_test_utils import make_invocation_request, make_list_model_request, \
-    make_get_model_request, make_load_model_request, make_unload_model_request
+from multi_model_endpoint_test_utils import make_invocation_request, make_list_model_request,\
+    make_load_model_request, make_unload_model_request
 
 PING_URL = 'http://localhost:8080/ping'
 
test/integration/local/test_pre_post_processing_mme.py

Lines changed: 3 additions & 4 deletions
@@ -22,8 +22,7 @@
 
 import requests
 
-from multi_model_endpoint_test_utils import make_invocation_request, make_list_model_request, \
-    make_get_model_request, make_load_model_request, make_unload_model_request, make_headers
+from multi_model_endpoint_test_utils import make_load_model_request, make_headers
 
 
 PING_URL = 'http://localhost:8080/ping'
@@ -57,7 +56,7 @@ def container(volume, docker_base_name, tag, runtime_config):
     try:
         command = (
             'docker run {}--name sagemaker-tensorflow-serving-test -p 8080:8080'
-            ' --mount type=volume,source={},target=/opt/ml/models,readonly'
+            ' --mount type=volume,source={},target=/opt/ml/models/half_plus_three/model,readonly'
            ' -e SAGEMAKER_TFS_NGINX_LOGLEVEL=info'
            ' -e SAGEMAKER_BIND_TO_PORT=8080'
            ' -e SAGEMAKER_SAFE_PORT_RANGE=9000-9999'
@@ -87,7 +86,7 @@
 def model():
     model_data = {
         'model_name': MODEL_NAME,
-        'url': '/opt/ml/models/half_plus_three'
+        'url': '/opt/ml/models/half_plus_three/model/half_plus_three'
     }
     make_load_model_request(json.dumps(model_data))
     return MODEL_NAME
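
Note: for reference, exercising a locally running container against the new nested path would look roughly like the following. The /models and /models/{name}/invoke routes follow the MME convention these tests rely on, but the utility module's internals are not shown in this diff, so treat the exact routes as assumptions.

    # Hypothetical end-to-end check against a local container on port 8080.
    import json

    import requests

    BASE = 'http://localhost:8080'

    # Load the model, pointing at the nested path introduced in this commit.
    load = requests.post(BASE + '/models', data=json.dumps({
        'model_name': 'half_plus_three',
        'url': '/opt/ml/models/half_plus_three/model/half_plus_three',
    }))
    assert load.status_code == 200

    # Invoke it: half_plus_three computes x / 2 + 3 per instance.
    resp = requests.post(BASE + '/models/half_plus_three/invoke',
                         data=json.dumps({'instances': [1.0, 2.0, 5.0]}),
                         headers={'Content-Type': 'application/json'})
    print(resp.json())  # expected predictions: [3.5, 4.0, 5.5]
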

0 commit comments