@@ -89,30 +89,29 @@ def __init__(self):
89
89
parts = self ._sagemaker_port_range .split ("-" )
90
90
low = int (parts [0 ])
91
91
hi = int (parts [1 ])
92
- self ._tfs_grpc_port = []
93
- self ._tfs_rest_port = []
92
+ self ._tfs_grpc_ports = []
93
+ self ._tfs_rest_ports = []
94
94
if low + 2 * self ._tfs_instance_count > hi :
95
95
raise ValueError ("not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})"
96
96
.format (self ._sagemaker_port_range ))
97
- self ._tfs_grpc_port_range = "{}-{}" .format (low ,
98
- low + 2 * self ._tfs_instance_count )
99
- self ._tfs_rest_port_range = "{}-{}" .format (low + 1 ,
100
- low + 2 * self ._tfs_instance_count + 1 )
97
+ # select non-overlapping grpc and rest ports based on tfs instance count
101
98
for i in range (self ._tfs_instance_count ):
102
- self ._tfs_grpc_port .append (str (low + 2 * i ))
103
- self ._tfs_rest_port .append (str (low + 2 * i + 1 ))
104
- # set environment variable for python service
105
- os . environ [ "TFS_GRPC_PORT_RANGE" ] = self ._tfs_grpc_port_range
106
- os . environ [ "TFS_REST_PORT_RANGE" ] = self ._tfs_rest_port_range
99
+ self ._tfs_grpc_ports .append (str (low + 2 * i ))
100
+ self ._tfs_rest_ports .append (str (low + 2 * i + 1 ))
101
+ # concat selected ports respectively in order to pass them to python service
102
+ self . _tfs_grpc_concat_ports = self ._concat_ports ( self . _tfs_grpc_ports )
103
+ self . _tfs_rest_concat_ports = self ._concat_ports ( self . _tfs_rest_ports )
107
104
else :
108
105
# just use the standard default ports
109
- self ._tfs_grpc_port = ["9000" ]
110
- self ._tfs_rest_port = ["8501" ]
111
- self ._tfs_grpc_port_range = "9000-9000"
112
- self ._tfs_rest_port_range = "8501-8501"
113
- # set environment variable for python service
114
- os .environ ["TFS_GRPC_PORT_RANGE" ] = self ._tfs_grpc_port_range
115
- os .environ ["TFS_REST_PORT_RANGE" ] = self ._tfs_rest_port_range
106
+ self ._tfs_grpc_ports = ["9000" ]
107
+ self ._tfs_rest_ports = ["8501" ]
108
+ # provide single concat port here for default case
109
+ self ._tfs_grpc_concat_ports = "9000"
110
+ self ._tfs_rest_concat_ports = "8501"
111
+
112
+ # set environment variable for python service
113
+ os .environ ["TFS_GRPC_PORTS" ] = self ._tfs_grpc_concat_ports
114
+ os .environ ["TFS_REST_PORTS" ] = self ._tfs_rest_concat_ports
116
115
117
116
def _need_python_service (self ):
118
117
if os .path .exists (INFERENCE_PATH ):
@@ -121,6 +120,11 @@ def _need_python_service(self):
121
120
and os .environ .get ("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX" ):
122
121
self ._enable_python_service = True
123
122
123
+ def _concat_ports (self , ports ):
124
+ str_ports = [str (port ) for port in ports ]
125
+ concat_str_ports = "," .join (str_ports )
126
+ return concat_str_ports
127
+
124
128
def _create_tfs_config (self ):
125
129
models = tfs_utils .find_models ()
126
130
@@ -194,13 +198,13 @@ def _setup_gunicorn(self):
194
198
gunicorn_command = (
195
199
"gunicorn -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker "
196
200
"--workers {} --threads {} "
197
- "{}{} -e TFS_GRPC_PORT_RANGE ={} -e TFS_REST_PORT_RANGE ={} "
201
+ "{}{} -e TFS_GRPC_PORTS ={} -e TFS_REST_PORTS ={} "
198
202
"-e SAGEMAKER_MULTI_MODEL={} -e SAGEMAKER_SAFE_PORT_RANGE={} "
199
203
"-e SAGEMAKER_TFS_WAIT_TIME_SECONDS={} "
200
204
"python_service:app" ).format (self ._gunicorn_worker_class ,
201
205
self ._gunicorn_workers , self ._gunicorn_threads ,
202
206
python_path_option , "," .join (python_path_content ),
203
- self ._tfs_grpc_port_range , self ._tfs_rest_port_range ,
207
+ self ._tfs_grpc_concat_ports , self ._tfs_rest_concat_ports ,
204
208
self ._tfs_enable_multi_model_endpoint ,
205
209
self ._sagemaker_port_range ,
206
210
self ._tfs_wait_time_seconds )
@@ -230,7 +234,7 @@ def _download_scripts(self, bucket, prefix):
230
234
def _create_nginx_tfs_upstream (self ):
231
235
indentation = " "
232
236
tfs_upstream = ""
233
- for port in self ._tfs_rest_port :
237
+ for port in self ._tfs_rest_ports :
234
238
tfs_upstream += "{}server localhost:{};\n " .format (indentation , port )
235
239
tfs_upstream = tfs_upstream [len (indentation ):- 2 ]
236
240
@@ -334,7 +338,7 @@ def _wait_for_gunicorn(self):
334
338
335
339
def _wait_for_tfs (self ):
336
340
for i in range (self ._tfs_instance_count ):
337
- tfs_utils .wait_for_model (self ._tfs_rest_port [i ],
341
+ tfs_utils .wait_for_model (self ._tfs_rest_ports [i ],
338
342
self ._tfs_default_model_name , self ._tfs_wait_time_seconds )
339
343
340
344
@contextmanager
@@ -370,8 +374,8 @@ def _restart_single_tfs(self, pid):
370
374
371
375
def _start_single_tfs (self , instance_id ):
372
376
cmd = tfs_utils .tfs_command (
373
- self ._tfs_grpc_port [instance_id ],
374
- self ._tfs_rest_port [instance_id ],
377
+ self ._tfs_grpc_ports [instance_id ],
378
+ self ._tfs_rest_ports [instance_id ],
375
379
self ._tfs_config_path ,
376
380
self ._tfs_enable_batching ,
377
381
self ._tfs_batching_config_path ,
0 commit comments