17
17
import os
18
18
import subprocess
19
19
import time
20
+ import sys
20
21
21
22
import falcon
22
23
import requests
27
28
import tfs_utils
28
29
29
30
SAGEMAKER_MULTI_MODEL_ENABLED = os .environ .get ('SAGEMAKER_MULTI_MODEL' , 'false' ).lower () == 'true'
30
- INFERENCE_SCRIPT_PATH = '/opt/ml/{}/code/inference.py' .format ('models'
31
- if SAGEMAKER_MULTI_MODEL_ENABLED
32
- else 'model' )
33
- PYTHON_PROCESSING_ENABLED = os .path .exists (INFERENCE_SCRIPT_PATH )
31
+ INFERENCE_SCRIPT_PATH = '/opt/ml/model/code/inference.py'
32
+
34
33
SAGEMAKER_BATCHING_ENABLED = os .environ .get ('SAGEMAKER_TFS_ENABLE_BATCHING' , 'false' ).lower ()
35
34
MODEL_CONFIG_FILE_PATH = '/sagemaker/model-config.cfg'
36
35
TFS_GRPC_PORT = os .environ .get ('TFS_GRPC_PORT' )
@@ -64,21 +63,24 @@ def __init__(self):
64
63
self ._model_tfs_grpc_port = {}
65
64
self ._model_tfs_pid = {}
66
65
self ._tfs_ports = self ._parse_sagemaker_port_range (SAGEMAKER_TFS_PORT_RANGE )
66
+ # If Multi-Model mode is enabled, dependencies/handlers will be imported
67
+ # during the _handle_load_model_post()
68
+ self .model_handlers = {}
67
69
else :
68
70
self ._tfs_grpc_port = TFS_GRPC_PORT
69
71
self ._tfs_rest_port = TFS_REST_PORT
70
72
73
+ if os .path .exists (INFERENCE_SCRIPT_PATH ):
74
+ self ._handler , self ._input_handler , self ._output_handler = self ._import_handlers ()
75
+ self ._handlers = self ._make_handler (self ._handler ,
76
+ self ._input_handler ,
77
+ self ._output_handler )
78
+ else :
79
+ self ._handlers = default_handler
80
+
71
81
self ._tfs_enable_batching = SAGEMAKER_BATCHING_ENABLED == 'true'
72
82
self ._tfs_default_model_name = os .environ .get ('TFS_DEFAULT_MODEL_NAME' , "None" )
73
83
74
- if PYTHON_PROCESSING_ENABLED :
75
- self ._handler , self ._input_handler , self ._output_handler = self ._import_handlers ()
76
- self ._handlers = self ._make_handler (self ._handler ,
77
- self ._input_handler ,
78
- self ._output_handler )
79
- else :
80
- self ._handlers = default_handler
81
-
82
84
def on_post (self , req , res , model_name = None ):
83
85
log .info (req .uri )
84
86
if model_name or "invocations" in req .uri :
@@ -129,6 +131,9 @@ def _handle_load_model_post(self, res, data): # noqa: C901
129
131
# validate model files are in the specified base_path
130
132
if self .validate_model_dir (base_path ):
131
133
try :
134
+ # install custom dependencies, import handlers
135
+ self ._import_custom_modules (model_name )
136
+
132
137
tfs_config = tfs_utils .create_tfs_config_individual_model (model_name , base_path )
133
138
tfs_config_file = '/sagemaker/tfs-config/{}/model-config.cfg' .format (model_name )
134
139
log .info ('tensorflow serving model config: \n %s\n ' , tfs_config )
@@ -197,6 +202,33 @@ def _handle_load_model_post(self, res, data): # noqa: C901
197
202
model_name )
198
203
})
199
204
205
+ def _import_custom_modules (self , model_name ):
206
+ inference_script_path = "/opt/ml/models/{}/model/code/inference.py" .format (model_name )
207
+ requirements_file_path = "/opt/ml/models/{}/model/code/requirements.txt" .format (model_name )
208
+ python_lib_path = "/opt/ml/models/{}/model/code/lib" .format (model_name )
209
+
210
+ if os .path .exists (requirements_file_path ):
211
+ log .info ("pip install dependencies from requirements.txt" )
212
+ pip_install_cmd = "pip3 install -r {}" .format (requirements_file_path )
213
+ try :
214
+ subprocess .check_call (pip_install_cmd .split ())
215
+ except subprocess .CalledProcessError :
216
+ log .error ('failed to install required packages, exiting.' )
217
+ raise ChildProcessError ('failed to install required packages.' )
218
+
219
+ if os .path .exists (python_lib_path ):
220
+ log .info ("add Python code library path" )
221
+ sys .path .append (python_lib_path )
222
+
223
+ if os .path .exists (inference_script_path ):
224
+ handler , input_handler , output_handler = self ._import_handlers (model_name )
225
+ model_handlers = self ._make_handler (handler ,
226
+ input_handler ,
227
+ output_handler )
228
+ self .model_handlers [model_name ] = model_handlers
229
+ else :
230
+ self .model_handlers [model_name ] = default_handler
231
+
200
232
    def _cleanup_config_file(self, config_file):
        """Delete a generated TFS model-config file; no-op when it is absent."""
        if os.path.exists(config_file):
            os.remove(config_file)
@@ -249,16 +281,24 @@ def _handle_invocation_post(self, req, res, model_name=None):
249
281
250
282
try :
251
283
res .status = falcon .HTTP_200
252
- res .body , res .content_type = self ._handlers (data , context )
284
+ if SAGEMAKER_MULTI_MODEL_ENABLED :
285
+ with lock ():
286
+ handlers = self .model_handlers [model_name ]
287
+ res .body , res .content_type = handlers (data , context )
288
+ else :
289
+ res .body , res .content_type = self ._handlers (data , context )
253
290
except Exception as e : # pylint: disable=broad-except
254
291
log .exception ('exception handling request: {}' .format (e ))
255
292
res .status = falcon .HTTP_500
256
293
res .body = json .dumps ({
257
294
'error' : str (e )
258
295
}).encode ('utf-8' ) # pylint: disable=E1101
259
296
260
- def _import_handlers (self ):
261
- spec = importlib .util .spec_from_file_location ('inference' , INFERENCE_SCRIPT_PATH )
297
+ def _import_handlers (self , model_name = None ):
298
+ inference_script = INFERENCE_SCRIPT_PATH
299
+ if model_name :
300
+ inference_script = "/opt/ml/models/{}/model/code/inference.py" .format (model_name )
301
+ spec = importlib .util .spec_from_file_location ('inference' , inference_script )
262
302
inference = importlib .util .module_from_spec (spec )
263
303
spec .loader .exec_module (inference )
264
304
@@ -358,7 +398,6 @@ def validate_model_dir(self, model_path):
358
398
versions = []
359
399
for _ , dirs , _ in os .walk (model_path ):
360
400
for dirname in dirs :
361
- log .info ("dirname: {}" .format (dirname ))
362
401
if dirname .isdigit ():
363
402
versions .append (dirname )
364
403
return self .validate_model_versions (versions )
@@ -383,7 +422,6 @@ def on_get(self, req, res): # pylint: disable=W0613
383
422
384
423
class ServiceResources :
385
424
def __init__ (self ):
386
- self ._enable_python_processing = PYTHON_PROCESSING_ENABLED
387
425
self ._enable_model_manager = SAGEMAKER_MULTI_MODEL_ENABLED
388
426
self ._python_service_resource = PythonServiceResource ()
389
427
self ._ping_resource = PingResource ()
0 commit comments