Skip to content
This repository was archived by the owner on May 23, 2024. It is now read-only.

fix: modify the way port number passing #210

Merged
merged 8 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions docker/build_artifacts/sagemaker/python_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

SAGEMAKER_BATCHING_ENABLED = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower()
MODEL_CONFIG_FILE_PATH = "/sagemaker/model-config.cfg"
TFS_GRPC_PORT_RANGE = os.environ.get("TFS_GRPC_PORT_RANGE")
TFS_REST_PORT_RANGE = os.environ.get("TFS_REST_PORT_RANGE")
TFS_GRPC_PORTS = os.environ.get("TFS_GRPC_PORTS")
TFS_REST_PORTS = os.environ.get("TFS_REST_PORTS")
SAGEMAKER_TFS_PORT_RANGE = os.environ.get("SAGEMAKER_SAFE_PORT_RANGE")
TFS_INSTANCE_COUNT = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", "1"))

Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(self):
# during the _handle_load_model_post()
self.model_handlers = {}
else:
self._tfs_grpc_ports = self._parse_sagemaker_port_range(TFS_GRPC_PORT_RANGE)
self._tfs_rest_ports = self._parse_sagemaker_port_range(TFS_REST_PORT_RANGE)
self._tfs_grpc_ports = self._parse_concat_ports(TFS_GRPC_PORTS)
self._tfs_rest_ports = self._parse_concat_ports(TFS_REST_PORTS)

self._channels = {}
for grpc_port in self._tfs_grpc_ports:
Expand Down Expand Up @@ -98,16 +98,11 @@ def on_post(self, req, res, model_name=None):
data = json.loads(req.stream.read().decode("utf-8"))
self._handle_load_model_post(res, data)

def _parse_sagemaker_port_range(self, port_range):
lower, upper = port_range.split('-')
lower = int(lower)
upper = int(upper)
if lower == upper:
return [lower]
return [lower + 2 * i for i in range(TFS_INSTANCE_COUNT)]
def _parse_concat_ports(self, concat_ports):
return concat_ports.split(",")

def _pick_port(self, ports):
return str(random.choice(ports))
return random.choice(ports)

def _parse_sagemaker_port_range_mme(self, port_range):
lower, upper = port_range.split('-')
Expand Down Expand Up @@ -254,7 +249,7 @@ def _handle_invocation_post(self, req, res, model_name=None):
rest_port = self._pick_port(self._tfs_rest_ports)
data, context = tfs_utils.parse_request(req, rest_port, grpc_port,
self._tfs_default_model_name,
channel=self._channels[int(grpc_port)])
channel=self._channels[grpc_port])

try:
res.status = falcon.HTTP_200
Expand Down
52 changes: 28 additions & 24 deletions docker/build_artifacts/sagemaker/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,30 +89,29 @@ def __init__(self):
parts = self._sagemaker_port_range.split("-")
low = int(parts[0])
hi = int(parts[1])
self._tfs_grpc_port = []
self._tfs_rest_port = []
self._tfs_grpc_ports = []
self._tfs_rest_ports = []
if low + 2 * self._tfs_instance_count > hi:
raise ValueError("not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})"
.format(self._sagemaker_port_range))
self._tfs_grpc_port_range = "{}-{}".format(low,
low + 2 * self._tfs_instance_count)
self._tfs_rest_port_range = "{}-{}".format(low + 1,
low + 2 * self._tfs_instance_count + 1)
# select non-overlapping grpc and rest ports based on tfs instance count
for i in range(self._tfs_instance_count):
self._tfs_grpc_port.append(str(low + 2 * i))
self._tfs_rest_port.append(str(low + 2 * i + 1))
# set environment variable for python service
os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range
os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range
self._tfs_grpc_ports.append(str(low + 2 * i))
self._tfs_rest_ports.append(str(low + 2 * i + 1))
# concat selected ports respectively in order to pass them to python service
self._tfs_grpc_concat_ports = self._concat_ports(self._tfs_grpc_ports)
self._tfs_rest_concat_ports = self._concat_ports(self._tfs_rest_ports)
else:
# just use the standard default ports
self._tfs_grpc_port = ["9000"]
self._tfs_rest_port = ["8501"]
self._tfs_grpc_port_range = "9000-9000"
self._tfs_rest_port_range = "8501-8501"
# set environment variable for python service
os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range
os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range
self._tfs_grpc_ports = ["9000"]
self._tfs_rest_ports = ["8501"]
# provide single concat port here for default case
self._tfs_grpc_concat_ports = "9000"
self._tfs_rest_concat_ports = "8501"

# set environment variable for python service
os.environ["TFS_GRPC_PORTS"] = self._tfs_grpc_concat_ports
os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports

def _need_python_service(self):
if os.path.exists(INFERENCE_PATH):
Expand All @@ -121,6 +120,11 @@ def _need_python_service(self):
and os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"):
self._enable_python_service = True

def _concat_ports(self, ports):
str_ports = [str(port) for port in ports]
concat_str_ports = ",".join(str_ports)
return concat_str_ports

def _create_tfs_config(self):
models = tfs_utils.find_models()

Expand Down Expand Up @@ -194,13 +198,13 @@ def _setup_gunicorn(self):
gunicorn_command = (
"gunicorn -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker "
"--workers {} --threads {} "
"{}{} -e TFS_GRPC_PORT_RANGE={} -e TFS_REST_PORT_RANGE={} "
"{}{} -e TFS_GRPC_PORTS={} -e TFS_REST_PORTS={} "
"-e SAGEMAKER_MULTI_MODEL={} -e SAGEMAKER_SAFE_PORT_RANGE={} "
"-e SAGEMAKER_TFS_WAIT_TIME_SECONDS={} "
"python_service:app").format(self._gunicorn_worker_class,
self._gunicorn_workers, self._gunicorn_threads,
python_path_option, ",".join(python_path_content),
self._tfs_grpc_port_range, self._tfs_rest_port_range,
self._tfs_grpc_concat_ports, self._tfs_rest_concat_ports,
self._tfs_enable_multi_model_endpoint,
self._sagemaker_port_range,
self._tfs_wait_time_seconds)
Expand Down Expand Up @@ -230,7 +234,7 @@ def _download_scripts(self, bucket, prefix):
def _create_nginx_tfs_upstream(self):
indentation = " "
tfs_upstream = ""
for port in self._tfs_rest_port:
for port in self._tfs_rest_ports:
tfs_upstream += "{}server localhost:{};\n".format(indentation, port)
tfs_upstream = tfs_upstream[len(indentation):-2]

Expand Down Expand Up @@ -334,7 +338,7 @@ def _wait_for_gunicorn(self):

def _wait_for_tfs(self):
for i in range(self._tfs_instance_count):
tfs_utils.wait_for_model(self._tfs_rest_port[i],
tfs_utils.wait_for_model(self._tfs_rest_ports[i],
self._tfs_default_model_name, self._tfs_wait_time_seconds)

@contextmanager
Expand Down Expand Up @@ -370,8 +374,8 @@ def _restart_single_tfs(self, pid):

def _start_single_tfs(self, instance_id):
cmd = tfs_utils.tfs_command(
self._tfs_grpc_port[instance_id],
self._tfs_rest_port[instance_id],
self._tfs_grpc_ports[instance_id],
self._tfs_rest_ports[instance_id],
self._tfs_config_path,
self._tfs_enable_batching,
self._tfs_batching_config_path,
Expand Down