diff --git a/docker/build_artifacts/sagemaker/python_service.py b/docker/build_artifacts/sagemaker/python_service.py index 35d40eb8..e294e5fc 100644 --- a/docker/build_artifacts/sagemaker/python_service.py +++ b/docker/build_artifacts/sagemaker/python_service.py @@ -31,8 +31,8 @@ SAGEMAKER_BATCHING_ENABLED = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower() MODEL_CONFIG_FILE_PATH = "/sagemaker/model-config.cfg" -TFS_GRPC_PORT_RANGE = os.environ.get("TFS_GRPC_PORT_RANGE") -TFS_REST_PORT_RANGE = os.environ.get("TFS_REST_PORT_RANGE") +TFS_GRPC_PORTS = os.environ.get("TFS_GRPC_PORTS") +TFS_REST_PORTS = os.environ.get("TFS_REST_PORTS") SAGEMAKER_TFS_PORT_RANGE = os.environ.get("SAGEMAKER_SAFE_PORT_RANGE") TFS_INSTANCE_COUNT = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", "1")) @@ -69,8 +69,8 @@ def __init__(self): # during the _handle_load_model_post() self.model_handlers = {} else: - self._tfs_grpc_ports = self._parse_sagemaker_port_range(TFS_GRPC_PORT_RANGE) - self._tfs_rest_ports = self._parse_sagemaker_port_range(TFS_REST_PORT_RANGE) + self._tfs_grpc_ports = self._parse_concat_ports(TFS_GRPC_PORTS) + self._tfs_rest_ports = self._parse_concat_ports(TFS_REST_PORTS) self._channels = {} for grpc_port in self._tfs_grpc_ports: @@ -98,16 +98,11 @@ def on_post(self, req, res, model_name=None): data = json.loads(req.stream.read().decode("utf-8")) self._handle_load_model_post(res, data) - def _parse_sagemaker_port_range(self, port_range): - lower, upper = port_range.split('-') - lower = int(lower) - upper = int(upper) - if lower == upper: - return [lower] - return [lower + 2 * i for i in range(TFS_INSTANCE_COUNT)] + def _parse_concat_ports(self, concat_ports): + return concat_ports.split(",") def _pick_port(self, ports): - return str(random.choice(ports)) + return random.choice(ports) def _parse_sagemaker_port_range_mme(self, port_range): lower, upper = port_range.split('-') @@ -254,7 +249,7 @@ def _handle_invocation_post(self, req, res, model_name=None): rest_port = self._pick_port(self._tfs_rest_ports) data, context = tfs_utils.parse_request(req, rest_port, grpc_port, self._tfs_default_model_name, - channel=self._channels[int(grpc_port)]) + channel=self._channels[grpc_port]) try: res.status = falcon.HTTP_200 diff --git a/docker/build_artifacts/sagemaker/serve.py b/docker/build_artifacts/sagemaker/serve.py index d834142a..f8b87614 100644 --- a/docker/build_artifacts/sagemaker/serve.py +++ b/docker/build_artifacts/sagemaker/serve.py @@ -89,30 +89,29 @@ def __init__(self): parts = self._sagemaker_port_range.split("-") low = int(parts[0]) hi = int(parts[1]) - self._tfs_grpc_port = [] - self._tfs_rest_port = [] + self._tfs_grpc_ports = [] + self._tfs_rest_ports = [] if low + 2 * self._tfs_instance_count > hi: raise ValueError("not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})" .format(self._sagemaker_port_range)) - self._tfs_grpc_port_range = "{}-{}".format(low, - low + 2 * self._tfs_instance_count) - self._tfs_rest_port_range = "{}-{}".format(low + 1, - low + 2 * self._tfs_instance_count + 1) + # select non-overlapping grpc and rest ports based on tfs instance count for i in range(self._tfs_instance_count): - self._tfs_grpc_port.append(str(low + 2 * i)) - self._tfs_rest_port.append(str(low + 2 * i + 1)) - # set environment variable for python service - os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range - os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range + self._tfs_grpc_ports.append(str(low + 2 * i)) + self._tfs_rest_ports.append(str(low + 2 * i + 1)) + # concat selected ports respectively in order to pass them to python service + self._tfs_grpc_concat_ports = self._concat_ports(self._tfs_grpc_ports) + self._tfs_rest_concat_ports = self._concat_ports(self._tfs_rest_ports) else: # just use the standard default ports - self._tfs_grpc_port = ["9000"] - self._tfs_rest_port = ["8501"] - self._tfs_grpc_port_range = "9000-9000" - self._tfs_rest_port_range = "8501-8501" - # set environment variable for python service - os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range - os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range + self._tfs_grpc_ports = ["9000"] + self._tfs_rest_ports = ["8501"] + # provide single concat port here for default case + self._tfs_grpc_concat_ports = "9000" + self._tfs_rest_concat_ports = "8501" + + # set environment variable for python service + os.environ["TFS_GRPC_PORTS"] = self._tfs_grpc_concat_ports + os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports def _need_python_service(self): if os.path.exists(INFERENCE_PATH): @@ -121,6 +120,11 @@ def _need_python_service(self): and os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"): self._enable_python_service = True + def _concat_ports(self, ports): + str_ports = [str(port) for port in ports] + concat_str_ports = ",".join(str_ports) + return concat_str_ports + def _create_tfs_config(self): models = tfs_utils.find_models() @@ -194,13 +198,13 @@ def _setup_gunicorn(self): gunicorn_command = ( "gunicorn -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker " "--workers {} --threads {} " - "{}{} -e TFS_GRPC_PORT_RANGE={} -e TFS_REST_PORT_RANGE={} " + "{}{} -e TFS_GRPC_PORTS={} -e TFS_REST_PORTS={} " "-e SAGEMAKER_MULTI_MODEL={} -e SAGEMAKER_SAFE_PORT_RANGE={} " "-e SAGEMAKER_TFS_WAIT_TIME_SECONDS={} " "python_service:app").format(self._gunicorn_worker_class, self._gunicorn_workers, self._gunicorn_threads, python_path_option, ",".join(python_path_content), - self._tfs_grpc_port_range, self._tfs_rest_port_range, + self._tfs_grpc_concat_ports, self._tfs_rest_concat_ports, self._tfs_enable_multi_model_endpoint, self._sagemaker_port_range, self._tfs_wait_time_seconds) @@ -230,7 +234,7 @@ def _download_scripts(self, bucket, prefix): def _create_nginx_tfs_upstream(self): indentation = " " tfs_upstream = "" - for port in self._tfs_rest_port: + for port in self._tfs_rest_ports: tfs_upstream += "{}server localhost:{};\n".format(indentation, port) tfs_upstream = tfs_upstream[len(indentation):-2] @@ -334,7 +338,7 @@ def _wait_for_gunicorn(self): def _wait_for_tfs(self): for i in range(self._tfs_instance_count): - tfs_utils.wait_for_model(self._tfs_rest_port[i], + tfs_utils.wait_for_model(self._tfs_rest_ports[i], self._tfs_default_model_name, self._tfs_wait_time_seconds) @contextmanager @@ -370,8 +374,8 @@ def _restart_single_tfs(self, pid): def _start_single_tfs(self, instance_id): cmd = tfs_utils.tfs_command( - self._tfs_grpc_port[instance_id], - self._tfs_rest_port[instance_id], + self._tfs_grpc_ports[instance_id], + self._tfs_rest_ports[instance_id], self._tfs_config_path, self._tfs_enable_batching, self._tfs_batching_config_path,