This repository was archived by the owner on May 23, 2024. It is now read-only.

Commit aba0963 (1 parent: 1a265db)

Feature: Support multiple inference.py files and a universal inference.py file, along with a universal requirements.txt file

5 files changed: +57 -25 lines

docker/build_artifacts/sagemaker/serve.py (22 additions, 2 deletions)
@@ -308,6 +308,14 @@ def _enable_per_process_gpu_memory_fraction(self):
 
         return False
 
+    def _get_number_of_gpu_on_host(self):
+        nvidia_smi_exist = os.path.exists("/usr/bin/nvidia-smi")
+        if nvidia_smi_exist:
+            return len(subprocess.check_output(['nvidia-smi', '-L'])
+                       .decode('utf-8').strip().split('\n'))
+
+        return 0
+
     def _calculate_per_process_gpu_memory_fraction(self):
         return round((1 - self._tfs_gpu_margin) / float(self._tfs_instance_count), 4)
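Note on the helper above: `nvidia-smi -L` prints one line per visible device, so counting output lines yields the GPU count. A minimal standalone sketch of the same idea (the sample output in the comments is illustrative, not taken from this repo):

    import os
    import subprocess

    def count_gpus_on_host():
        # Typical `nvidia-smi -L` output, one device per line:
        #   GPU 0: Tesla T4 (UUID: GPU-...)
        #   GPU 1: Tesla T4 (UUID: GPU-...)
        if not os.path.exists("/usr/bin/nvidia-smi"):
            return 0  # no NVIDIA driver on this host, assume CPU-only
        output = subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip()
        return len(output.split("\n")) if output else 0

The extra empty-output guard is defensive; the committed version assumes `nvidia-smi -L` always reports at least one device when the binary exists.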

@@ -420,8 +428,20 @@ def _start_single_tfs(self, instance_id):
             tfs_gpu_memory_fraction=self._calculate_per_process_gpu_memory_fraction(),
         )
         log.info("tensorflow serving command: {}".format(cmd))
-        p = subprocess.Popen(cmd.split())
-        log.info("started tensorflow serving (pid: %d)", p.pid)
+
+        num_gpus = self._get_number_of_gpu_on_host()
+        if num_gpus > 1:
+            # utilizing multi-gpu
+            worker_env = os.environ.copy()
+            worker_env["CUDA_VISIBLE_DEVICES"] = str(instance_id % num_gpus)
+            p = subprocess.Popen(cmd.split(), env=worker_env)
+            log.info("started tensorflow serving (pid: {}) on GPU {}"
+                     .format(p.pid, instance_id % num_gpus))
+        else:
+            # cpu and single gpu
+            p = subprocess.Popen(cmd.split())
+            log.info("started tensorflow serving (pid: {})".format(p.pid))
+
         return p
 
     def _monitor(self):
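The branch above implements round-robin GPU pinning: worker `instance_id` is restricted to device `instance_id % num_gpus` via `CUDA_VISIBLE_DEVICES`, and because the variable is set on a copy of `os.environ`, only that child process sees it. A quick illustration of the assignment pattern (worker and GPU counts are hypothetical):

    # Each TFS worker sees exactly one device: GPU (instance_id % num_gpus).
    num_gpus = 4      # hypothetical host with 4 GPUs
    num_workers = 6   # hypothetical TFS instance count
    for instance_id in range(num_workers):
        print("worker {} -> CUDA_VISIBLE_DEVICES={}".format(instance_id, instance_id % num_gpus))
    # worker 0 -> CUDA_VISIBLE_DEVICES=0
    # worker 1 -> CUDA_VISIBLE_DEVICES=1
    # worker 2 -> CUDA_VISIBLE_DEVICES=2
    # worker 3 -> CUDA_VISIBLE_DEVICES=3
    # worker 4 -> CUDA_VISIBLE_DEVICES=0
    # worker 5 -> CUDA_VISIBLE_DEVICES=1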

test/integration/local/test_pre_post_processing_mme.py (35 additions, 23 deletions)
@@ -27,7 +27,7 @@
 
 PING_URL = "http://localhost:8080/ping"
 INVOCATION_URL = "http://localhost:8080/models/{}/invoke"
-MODEL_NAME = "half_plus_three"
+MODEL_NAMES = ["half_plus_three", "half_plus_two"]
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -74,13 +74,14 @@ def container(docker_base_name, tag, runtime_config):
 
 
 @pytest.fixture
-def model():
-    model_data = {
-        "model_name": MODEL_NAME,
-        "url": "/opt/ml/models/half_plus_three/model/half_plus_three"
-    }
-    make_load_model_request(json.dumps(model_data))
-    return MODEL_NAME
+def models():
+    for MODEL_NAME in MODEL_NAMES:
+        model_data = {
+            "model_name": MODEL_NAME,
+            "url": "/opt/ml/models/{}/model/{}".format(MODEL_NAME, MODEL_NAME)
+        }
+        make_load_model_request(json.dumps(model_data))
+    return MODEL_NAMES
 
 
 @pytest.mark.skip_gpu
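The reworked fixture registers every entry in MODEL_NAMES with the multi-model container before any test runs. `make_load_model_request` is defined elsewhere in the suite; a plausible sketch of such a loader, assuming the container's management endpoint lives at /models (both the URL and the helper body here are assumptions, not shown in this diff):

    import json
    import requests

    MODELS_URL = "http://localhost:8080/models"  # assumed management endpoint

    def make_load_model_request(data, success_code=200):
        # POST the model name and its path inside the container;
        # the server loads the model and starts routing invocations to it.
        response = requests.post(MODELS_URL, data=data)
        assert success_code == response.status_code, response.content

    make_load_model_request(json.dumps({
        "model_name": "half_plus_two",
        "url": "/opt/ml/models/half_plus_two/model/half_plus_two",
    }))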
@@ -90,20 +91,25 @@ def test_ping_service():
 
 
 @pytest.mark.skip_gpu
-def test_predict_json(model):
+def test_predict_json(models):
     headers = make_headers()
     data = "{\"instances\": [1.0, 2.0, 5.0]}"
-    response = requests.post(INVOCATION_URL.format(model), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
+    responses = []
+    for model in models:
+        response = requests.post(INVOCATION_URL.format(model), data=data, headers=headers).json()
+        responses.append(response)
+    assert responses[0] == {"predictions": [3.5, 4.0, 5.5]}
+    assert responses[1] == {"predictions": [2.5, 3.0, 4.5]}
 
 
 @pytest.mark.skip_gpu
 def test_zero_content():
     headers = make_headers()
     x = ""
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers)
-    assert 500 == response.status_code
-    assert "document is empty" in response.text
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers)
+        assert 500 == response.status_code
+        assert "document is empty" in response.text
 
 
 @pytest.mark.skip_gpu
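The two expected payloads follow from the toy models themselves: half_plus_three computes x / 2 + 3 and half_plus_two computes x / 2 + 2, so the shared input [1.0, 2.0, 5.0] should produce exactly the two prediction lists asserted above. A one-line sanity check:

    inputs = [1.0, 2.0, 5.0]
    assert [x / 2 + 3 for x in inputs] == [3.5, 4.0, 5.5]  # half_plus_three
    assert [x / 2 + 2 for x in inputs] == [2.5, 3.0, 4.5]  # half_plus_two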
@@ -113,21 +119,26 @@ def test_large_input():
     with open(data_file, "r") as file:
         x = file.read()
     headers = make_headers(content_type="text/csv")
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers).json()
-    predictions = response["predictions"]
-    assert len(predictions) == 753936
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=x, headers=headers).json()
+        predictions = response["predictions"]
+        assert len(predictions) == 753936
 
 
 @pytest.mark.skip_gpu
 def test_csv_input():
     headers = make_headers(content_type="text/csv")
     data = "1.0,2.0,5.0"
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers).json()
-    assert response == {"predictions": [3.5, 4.0, 5.5]}
-
+    responses = []
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers).json()
+        responses.append(response)
+    assert responses[0] == {"predictions": [3.5, 4.0, 5.5]}
+    assert responses[1] == {"predictions": [2.5, 3.0, 4.5]}
 
 @pytest.mark.skip_gpu
 def test_specific_versions():
+    MODEL_NAME = MODEL_NAMES[0]
     for version in ("123", "124"):
         headers = make_headers(content_type="text/csv", version=version)
         data = "1.0,2.0,5.0"
@@ -141,6 +152,7 @@ def test_specific_versions():
 def test_unsupported_content_type():
     headers = make_headers("unsupported-type", "predict")
     data = "aW1hZ2UgYnl0ZXM="
-    response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers)
-    assert 500 == response.status_code
-    assert "unsupported content type" in response.text
+    for MODEL_NAME in MODEL_NAMES:
+        response = requests.post(INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers)
+        assert 500 == response.status_code
+        assert "unsupported content type" in response.text
