This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit 275c8d9

Feature: Support multiple inference.py files and universal inference.py file along with universal requirements.txt file
1 parent 1a265db

76 files changed (+1002 −48 lines)

Note: large commits hide some content by default, so only a subset of the 76 changed files is shown below.

README.md

Lines changed: 14 additions & 3 deletions
@@ -164,8 +164,6 @@ For example:
 
 ## Pre/Post-Processing
 
-**NOTE: There is currently no support for pre-/post-processing with multi-model containers.**
-
 SageMaker TensorFlow Serving Container supports the following Content-Types for requests:
 
 * `application/json` (default)
@@ -672,7 +670,7 @@ Only 90% of the ports will be utilized and each loaded model will be allocated w
 For example, if the ``SAGEMAKER_SAFE_PORT_RANGE`` is between 9000 to 9999, the maximum number of models that can be loaded to the endpoint at the same time would be 499 ((9999 - 9000) * 0.9 / 2).
 
 ### Using Multi-Model Endpoint with Pre/Post-Processing
-Multi-Model Endpoint can be used together with Pre/Post-Processing. Each model will need its own ``inference.py`` otherwise default handlers will be used. An example of the directory structure of Multi-Model Endpoint and Pre/Post-Processing would look like this:
+Multi-Model Endpoint can be used together with Pre/Post-Processing. Each model can either have its own ``inference.py`` or use a universal ``inference.py``. If both model-specific and universal ``inference.py`` files are provided, then the model-specific ``inference.py`` file is used. If both files are absent, then the default handlers will be used. An example of the directory structure of Multi-Model Endpoint with a model-specific ``inference.py`` file would look like this:
 
 /opt/ml/models/model1/model
 |--[model_version_number]
@@ -687,7 +685,20 @@ Multi-Model Endpoint can be used together with Pre/Post-Processing. Each model w
 |--lib
 |--external_module
 |--inference.py
+Another example of the directory structure of Multi-Model Endpoint with a universal ``inference.py`` file is as follows:
 
+/opt/ml/models/model1/model
+|--[model_version_number]
+|--variables
+|--saved_model.pb
+/opt/ml/models/model2/model
+|--[model_version_number]
+|--assets
+|--variables
+|--saved_model.pb
+code
+|--requirements.txt
+|--inference.py
 ## Contributing
 
 Please read [CONTRIBUTING.md](https://github.com/aws/sagemaker-tensorflow-serving-container/blob/master/CONTRIBUTING.md)
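
For reference, both a model-specific and a universal ``inference.py`` implement the same pre/post-processing interface described earlier in this README (``input_handler``/``output_handler``, or a single ``handler``). Below is a minimal sketch of a universal script; the handler names follow the documented interface, while the body (a simple JSON passthrough) is illustrative only:

```python
# Minimal sketch of a universal inference.py at /opt/ml/code/inference.py
# (a model-specific copy would live at /opt/ml/models/<model>/model/code/inference.py).
# Handler names follow the container's documented interface; the body is illustrative.


def input_handler(data, context):
    """Pre-process the request before it is forwarded to TensorFlow Serving."""
    if context.request_content_type == "application/json":
        d = data.read().decode("utf-8")
        return d if len(d) else ""
    raise ValueError("Unsupported content type: {}".format(context.request_content_type))


def output_handler(response, context):
    """Post-process the TensorFlow Serving response before returning it to the client."""
    response_content_type = context.accept_header
    prediction = response.content
    return prediction, response_content_type
```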

docker/build_artifacts/sagemaker/python_service.py

Lines changed: 31 additions & 5 deletions
@@ -17,6 +17,7 @@
 import os
 import subprocess
 import grpc
+import sys
 
 import falcon
 import requests
@@ -26,7 +27,7 @@
 import tfs_utils
 
 SAGEMAKER_MULTI_MODEL_ENABLED = os.environ.get("SAGEMAKER_MULTI_MODEL", "false").lower() == "true"
-MODEL_DIR = "models" if SAGEMAKER_MULTI_MODEL_ENABLED else "model"
+MODEL_DIR = "" if SAGEMAKER_MULTI_MODEL_ENABLED else "model"
 INFERENCE_SCRIPT_PATH = f"/opt/ml/{MODEL_DIR}/code/inference.py"
 
 SAGEMAKER_BATCHING_ENABLED = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower()
@@ -64,6 +65,7 @@ def __init__(self):
             self._model_tfs_grpc_port = {}
             self._model_tfs_pid = {}
             self._tfs_ports = self._parse_sagemaker_port_range_mme(SAGEMAKER_TFS_PORT_RANGE)
+            self._default_handlers_enabled = False
             # If Multi-Model mode is enabled, dependencies/handlers will be imported
             # during the _handle_load_model_post()
             self.model_handlers = {}
@@ -85,6 +87,7 @@ def __init__(self):
             )
         else:
             self._handlers = default_handler
+            self._default_handlers_enabled = True
 
         self._tfs_enable_batching = SAGEMAKER_BATCHING_ENABLED == "true"
         self._tfs_default_model_name = os.environ.get("TFS_DEFAULT_MODEL_NAME", "None")
@@ -143,6 +146,7 @@ def _handle_load_model_post(self, res, data): # noqa: C901
         # validate model files are in the specified base_path
         if self.validate_model_dir(base_path):
             try:
+                self._import_custom_modules(model_name)
                 tfs_config = tfs_utils.create_tfs_config_individual_model(model_name, base_path)
                 tfs_config_file = "/sagemaker/tfs-config/{}/model-config.cfg".format(model_name)
                 log.info("tensorflow serving model config: \n%s\n", tfs_config)
@@ -221,6 +225,17 @@ def _handle_load_model_post(self, res, data): # noqa: C901
                 }
             )
 
+    def _import_custom_modules(self, model_name):
+        inference_script_path = "/opt/ml/models/{}/model/code/inference.py".format(model_name)
+        python_lib_path = "/opt/ml/models/{}/model/code/lib".format(model_name)
+        if os.path.exists(python_lib_path):
+            log.info("add Python code library path")
+            sys.path.append(python_lib_path)
+        if os.path.exists(inference_script_path):
+            handler, input_handler, output_handler = self._import_handlers(inference_script_path)
+            model_handlers = self._make_handler(handler, input_handler, output_handler)
+            self.model_handlers[model_name] = model_handlers
+
     def _cleanup_config_file(self, config_file):
         if os.path.exists(config_file):
             os.remove(config_file)
@@ -264,8 +279,20 @@ def _handle_invocation_post(self, req, res, model_name=None):
 
         try:
             res.status = falcon.HTTP_200
-
-            res.body, res.content_type = self._handlers(data, context)
+            handlers = self._handlers
+            if SAGEMAKER_MULTI_MODEL_ENABLED and model_name in self.model_handlers:
+                inference_script_path = "/opt/ml/models/{}/model/code/" \
+                                        "inference.py".format(model_name)
+                log.info("Inference script found at path {}.".format(inference_script_path))
+                log.info("Inference script exists, importing handlers.")
+                handlers = self.model_handlers[model_name]
+            elif not self._default_handlers_enabled:
+                log.info("Universal inference script found at path "
+                         "{}.".format(INFERENCE_SCRIPT_PATH))
+                log.info("Universal inference script exists, importing handlers.")
+            else:
+                log.info("Inference script does not exist, using default handlers.")
+            res.body, res.content_type = handlers(data, context)
         except Exception as e:  # pylint: disable=broad-except
             log.exception("exception handling request: {}".format(e))
             res.status = falcon.HTTP_500
@@ -276,8 +303,7 @@ def _setup_channel(self, grpc_port):
         log.info("Creating grpc channel for port: %s", grpc_port)
         self._channels[grpc_port] = grpc.insecure_channel("localhost:{}".format(grpc_port))
 
-    def _import_handlers(self):
-        inference_script = INFERENCE_SCRIPT_PATH
+    def _import_handlers(self, inference_script=INFERENCE_SCRIPT_PATH):
         spec = importlib.util.spec_from_file_location("inference", inference_script)
         inference = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(inference)
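
Taken together, the changes to ``python_service.py`` give each invocation a three-level handler precedence: a model-specific ``inference.py`` imported when the model is loaded, else the universal ``inference.py`` imported at container start-up, else the default pass-through handlers. A standalone sketch of that precedence follows; the names (``resolve_handlers``, ``per_model``, ``universal``, ``default``) are illustrative and not part of the module:

```python
from typing import Callable, Dict, Optional, Tuple

# A handler takes (data, context) and returns (body, content_type), as in python_service.py.
Handler = Callable[[object, object], Tuple[bytes, str]]


def resolve_handlers(model_name: str,
                     per_model: Dict[str, Handler],
                     universal: Optional[Handler],
                     default: Handler) -> Handler:
    """Handler precedence implied by the diff: model-specific > universal > default."""
    if model_name in per_model:
        # imported by _import_custom_modules() when the model was loaded
        return per_model[model_name]
    if universal is not None:
        # universal /opt/ml/code/inference.py imported at start-up
        return universal
    # no inference.py anywhere: fall back to the default pass-through handlers
    return default
```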

docker/build_artifacts/sagemaker/serve.py

Lines changed: 5 additions & 5 deletions
@@ -28,9 +28,8 @@
 JS_INVOCATIONS = "js_content tensorflowServing.invocations"
 GUNICORN_PING = "proxy_pass http://gunicorn_upstream/ping"
 GUNICORN_INVOCATIONS = "proxy_pass http://gunicorn_upstream/invocations"
-MULTI_MODEL = "s" if os.environ.get("SAGEMAKER_MULTI_MODEL", "False").lower() == "true" else ""
-MODEL_DIR = f"model{MULTI_MODEL}"
-CODE_DIR = "/opt/ml/{}/code".format(MODEL_DIR)
+MODEL_DIR = "" if os.environ.get("SAGEMAKER_MULTI_MODEL", "False").lower() == "true" else "model"
+CODE_DIR = f"/opt/ml/{MODEL_DIR}/code"
 PYTHON_LIB_PATH = os.path.join(CODE_DIR, "lib")
 REQUIREMENTS_PATH = os.path.join(CODE_DIR, "requirements.txt")
 INFERENCE_PATH = os.path.join(CODE_DIR, "inference.py")
@@ -134,7 +133,8 @@ def __init__(self):
         os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports
 
     def _need_python_service(self):
-        if os.path.exists(INFERENCE_PATH):
+        if (os.path.exists(INFERENCE_PATH) or os.path.exists(REQUIREMENTS_PATH)
+                or os.path.exists(PYTHON_LIB_PATH)):
             self._enable_python_service = True
         if os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET") and os.environ.get(
             "SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"
@@ -256,7 +256,7 @@ def _download_scripts(self, bucket, prefix):
         paginator = client.get_paginator("list_objects")
         for result in paginator.paginate(Bucket=bucket, Delimiter="/", Prefix=prefix):
             for file in result.get("Contents", []):
-                destination = os.path.join(CODE_DIR, file.get("Key"))
+                destination = os.path.join(CODE_DIR, file.get("Key").split("/")[-1])
                 if not os.path.exists(os.path.dirname(destination)):
                     os.makedirs(os.path.dirname(destination))
                 resource.meta.client.download_file(bucket, file.get("Key"), destination)
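
With ``MODEL_DIR`` set to an empty string in multi-model mode, the universal code location collapses to ``/opt/ml//code`` (equivalent to ``/opt/ml/code`` on POSIX paths), and ``_download_scripts`` now flattens S3 keys into that directory. A small sketch of the resulting paths, derived from the constants above and using a made-up S3 key:

```python
import os

# Values follow the constants in serve.py above for multi-model mode (SAGEMAKER_MULTI_MODEL=true)
MODEL_DIR = ""
CODE_DIR = f"/opt/ml/{MODEL_DIR}/code"            # "/opt/ml//code", i.e. /opt/ml/code on POSIX
INFERENCE_PATH = os.path.join(CODE_DIR, "inference.py")
REQUIREMENTS_PATH = os.path.join(CODE_DIR, "requirements.txt")

# _download_scripts keeps only the key's base name, so a universal script stored under a
# hypothetical key "universal/inference.py" lands directly in CODE_DIR rather than in a subfolder.
key = "universal/inference.py"
destination = os.path.join(CODE_DIR, key.split("/")[-1])
print(INFERENCE_PATH)   # /opt/ml//code/inference.py
print(destination)      # /opt/ml//code/inference.py
```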

test/integration/local/test_pre_post_processing_mme.py renamed to test/integration/local/test_pre_post_processing_mme1.py

Lines changed: 52 additions & 35 deletions
@@ -11,9 +11,11 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 
+# In this test, only a universal inference.py file is provided. It is expected that the handlers
+# from the universal inference.py file will be used by both models.
+
 import json
 import os
-import shutil
 import subprocess
 import sys
 import time
@@ -27,27 +29,27 @@
 
 PING_URL = "http://localhost:8080/ping"
 INVOCATION_URL = "http://localhost:8080/models/{}/invoke"
-MODEL_NAME = "half_plus_three"
+MODEL_NAMES = ["half_plus_three", "half_plus_two"]
 
 
 @pytest.fixture(scope="session", autouse=True)
 def volume():
     try:
-        model_dir = os.path.abspath("test/resources/mme_universal_script")
+        model_dir = os.path.abspath("test/resources/mme1")
         subprocess.check_call(
-            "docker volume create --name model_volume_mme --opt type=none "
+            "docker volume create --name model_volume_mme1 --opt type=none "
             "--opt device={} --opt o=bind".format(model_dir).split())
         yield model_dir
     finally:
-        subprocess.check_call("docker volume rm model_volume_mme".split())
+        subprocess.check_call("docker volume rm model_volume_mme1".split())
 
 
 @pytest.fixture(scope="module", autouse=True)
 def container(docker_base_name, tag, runtime_config):
     try:
         command = (
             "docker run {}--name sagemaker-tensorflow-serving-test -p 8080:8080"
-            " --mount type=volume,source=model_volume_mme,target=/opt/ml/models,readonly"
+            " --mount type=volume,source=model_volume_mme1,target=/opt/ml/models,readonly"
             " -e SAGEMAKER_TFS_NGINX_LOGLEVEL=info"
             " -e SAGEMAKER_BIND_TO_PORT=8080"
             " -e SAGEMAKER_SAFE_PORT_RANGE=9000-9999"
@@ -74,13 +76,14 @@ def container(docker_base_name, tag, runtime_config):
 
 
 @pytest.fixture
-def model():
-    model_data = {
-        "model_name": MODEL_NAME,
-        "url": "/opt/ml/models/half_plus_three/model/half_plus_three"
-    }
-    make_load_model_request(json.dumps(model_data))
-    return MODEL_NAME
+def models():
+    for MODEL_NAME in MODEL_NAMES:
+        model_data = {
+            "model_name": MODEL_NAME,
+            "url": "/opt/ml/models/{}/model/{}".format(MODEL_NAME, MODEL_NAME)
+        }
+        make_load_model_request(json.dumps(model_data))
+    return MODEL_NAMES
 
 
 @pytest.mark.skip_gpu
@@ -90,20 +93,25 @@ def test_ping_service():
 
 
 @pytest.mark.skip_gpu
-def test_predict_json(model):
+def test_predict_json(models):
     headers = make_headers()
     data = "{\"instances\": [1.0, 2.0, 5.0]}"
-    response = requests.post(INVOCATION_URL.format(model), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
+    responses = []
+    for model in models:
+        response = requests.post(INVOCATION_URL.format(model), data=data, headers=headers).json()
+        responses.append(response)
+    assert responses[0] == {"predictions": [3.5, 4.0, 5.5]}
+    assert responses[1] == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_zero_content():
     headers = make_headers()
     x = ""
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers)
-    assert 500 == response.status_code
-    assert "document is empty" in response.text
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers)
+        assert 500 == response.status_code
+        assert "document is empty" in response.text
 
 
 @pytest.mark.skip_gpu
@@ -113,34 +121,43 @@ def test_large_input():
     with open(data_file, "r") as file:
         x = file.read()
     headers = make_headers(content_type="text/csv")
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers).json()
-    predictions = response["predictions"]
-    assert len(predictions) == 753936
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers).json()
+        predictions = response["predictions"]
+        assert len(predictions) == 753936
 
 
 @pytest.mark.skip_gpu
 def test_csv_input():
     headers = make_headers(content_type="text/csv")
     data = "1.0,2.0,5.0"
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
-
+    responses = []
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers).json()
+        responses.append(response)
+    assert responses[0] == {"predictions": [3.5, 4.0, 5.5]}
+    assert responses[1] == {"predictions": [2.5, 3.0, 4.5]}
 
 @pytest.mark.skip_gpu
 def test_specific_versions():
-    for version in ("123", "124"):
-        headers = make_headers(content_type="text/csv", version=version)
-        data = "1.0,2.0,5.0"
-        response = requests.post(
-            INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers
-        ).json()
-        assert response == {"predictions": [3.5, 4.0, 5.5]}
+    for MODEL_NAME in MODEL_NAMES:
+        for version in ("123", "124"):
+            headers = make_headers(content_type="text/csv", version=version)
+            data = "1.0,2.0,5.0"
+            response = requests.post(
+                INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers
+            ).json()
+            if MODEL_NAME == "half_plus_three":
+                assert response == {"predictions": [3.5, 4.0, 5.5]}
+            else:
+                assert response == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_unsupported_content_type():
     headers = make_headers("unsupported-type", "predict")
     data = "aW1hZ2UgYnl0ZXM="
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers)
-    assert 500 == response.status_code
-    assert "unsupported content type" in response.text
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers)
+        assert 500 == response.status_code
+        assert "unsupported content type" in response.text

0 commit comments
