Skip to content

Commit 63fe993

Browse files
Merge pull request #71 from davidthomas426/serve-from-original-model-artifact-dir
feature: Serve model directly from original model artifact directory …
2 parents e1221c2 + 666ebe2 commit 63fe993

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
extras["benchmark"] = ["boto3", "locust"]
8181

8282
extras["quality"] = [
83-
"black==21.4b0",
83+
"black>=21.10",
8484
"isort>=5.5.4",
8585
"flake8>=3.8.3",
8686
]

src/sagemaker_huggingface_inference_toolkit/mms_model_server.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import subprocess
1919

2020
from sagemaker_inference import environment, logging
21-
from sagemaker_inference.environment import model_dir
2221
from sagemaker_inference.model_server import (
2322
DEFAULT_MMS_LOG_FILE,
23+
DEFAULT_MMS_MODEL_NAME,
2424
ENABLE_MULTI_MODEL,
2525
MMS_CONFIG_FILE,
2626
REQUIREMENTS_PATH,
@@ -45,8 +45,8 @@
4545

4646
DEFAULT_HANDLER_SERVICE = handler_service.__name__
4747

48-
DEFAULT_MMS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
49-
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_MMS_MODEL_DIRECTORY
48+
DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
49+
DEFAULT_MODEL_STORE = "/"
5050

5151

5252
def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
@@ -64,28 +64,35 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
6464
Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``.
6565
6666
"""
67+
use_hf_hub = "HF_MODEL_ID" in os.environ
68+
model_store = DEFAULT_MODEL_STORE
6769
if ENABLE_MULTI_MODEL:
6870
if not os.getenv("SAGEMAKER_HANDLER"):
6971
os.environ["SAGEMAKER_HANDLER"] = handler_service
7072
_set_python_path()
71-
elif "HF_MODEL_ID" in os.environ:
73+
elif use_hf_hub:
74+
# Use different model store directory
75+
model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY
7276
if is_aws_neuron_available():
7377
raise ValueError(
7478
"Hugging Face Hub deployments are currently not supported with AWS Neuron and Inferentia."
7579
"You need to create a `inference.py` script to run your model using AWS Neuron"
7680
)
7781
storage_dir = _load_model_from_hub(
7882
model_id=os.environ["HF_MODEL_ID"],
79-
model_dir=DEFAULT_MMS_MODEL_DIRECTORY,
83+
model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
8084
revision=HF_MODEL_REVISION,
8185
use_auth_token=HF_API_TOKEN,
8286
)
8387
_adapt_to_mms_format(handler_service, storage_dir)
8488
else:
85-
_adapt_to_mms_format(handler_service, model_dir)
89+
_set_python_path()
8690

8791
env = environment.Environment()
88-
_create_model_server_config_file(env)
92+
93+
# Note: multi-model default config already sets default_service_handler
94+
handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
95+
_create_model_server_config_file(env, handler_service_for_config)
8996

9097
if os.path.exists(REQUIREMENTS_PATH):
9198
_install_requirements()
@@ -94,12 +101,14 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
94101
"multi-model-server",
95102
"--start",
96103
"--model-store",
97-
MODEL_STORE,
104+
model_store,
98105
"--mms-config",
99106
MMS_CONFIG_FILE,
100107
"--log-config",
101108
DEFAULT_MMS_LOG_FILE,
102109
]
110+
if not ENABLE_MULTI_MODEL and not use_hf_hub:
111+
multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]
103112

104113
logger.info(multi_model_server_cmd)
105114
subprocess.Popen(multi_model_server_cmd)
@@ -113,7 +122,7 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
113122

114123

115124
def _adapt_to_mms_format(handler_service, model_path):
116-
os.makedirs(DEFAULT_MMS_MODEL_DIRECTORY, exist_ok=True)
125+
os.makedirs(DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, exist_ok=True)
117126

118127
# gets the model from the path, default is model/
119128
model = pathlib.PurePath(model_path)
@@ -128,7 +137,7 @@ def _adapt_to_mms_format(handler_service, model_path):
128137
"--model-path",
129138
model_path,
130139
"--export-path",
131-
DEFAULT_MMS_MODEL_DIRECTORY,
140+
DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
132141
"--archive-format",
133142
"no-archive",
134143
"--f",

tests/unit/test_mms_model_server.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,24 @@ def test_start_mms_default_service_handler(
4747
env.return_value.startup_timeout = 10000
4848
mms_model_server.start_model_server()
4949

50-
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
51-
create_config.assert_called_once_with(env.return_value)
50+
# In this case, we should not rearchive the model
51+
adapt.assert_not_called()
52+
53+
create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
5254
exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
5355
install_requirements.assert_called_once_with()
5456

5557
multi_model_server_cmd = [
5658
"multi-model-server",
5759
"--start",
5860
"--model-store",
59-
mms_model_server.MODEL_STORE,
61+
mms_model_server.DEFAULT_MODEL_STORE,
6062
"--mms-config",
6163
mms_model_server.MMS_CONFIG_FILE,
6264
"--log-config",
6365
mms_model_server.DEFAULT_MMS_LOG_FILE,
66+
"--models",
67+
"{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
6468
]
6569

6670
subprocess_popen.assert_called_once_with(multi_model_server_cmd)
@@ -98,20 +102,24 @@ def test_start_mms_neuron(
98102
env.return_value.startup_timeout = 10000
99103
mms_model_server.start_model_server()
100104

101-
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
102-
create_config.assert_called_once_with(env.return_value)
105+
# In this case, we should not call model archiver
106+
adapt.assert_not_called()
107+
108+
create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
103109
exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
104110
install_requirements.assert_called_once_with()
105111

106112
multi_model_server_cmd = [
107113
"multi-model-server",
108114
"--start",
109115
"--model-store",
110-
mms_model_server.MODEL_STORE,
116+
mms_model_server.DEFAULT_MODEL_STORE,
111117
"--mms-config",
112118
mms_model_server.MMS_CONFIG_FILE,
113119
"--log-config",
114120
mms_model_server.DEFAULT_MMS_LOG_FILE,
121+
"--models",
122+
"{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
115123
]
116124

117125
subprocess_popen.assert_called_once_with(multi_model_server_cmd)
@@ -152,21 +160,23 @@ def test_start_mms_with_model_from_hub(
152160

153161
load_model_from_hub.assert_called_once_with(
154162
model_id=os.environ["HF_MODEL_ID"],
155-
model_dir=mms_model_server.DEFAULT_MMS_MODEL_DIRECTORY,
163+
model_dir=mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
156164
revision=transformers_utils.HF_MODEL_REVISION,
157165
use_auth_token=transformers_utils.HF_API_TOKEN,
158166
)
159167

168+
# When loading model from hub, we do call model archiver
160169
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub())
161-
create_config.assert_called_once_with(env.return_value)
170+
171+
create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
162172
exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH)
163173
install_requirements.assert_called_once_with()
164174

165175
multi_model_server_cmd = [
166176
"multi-model-server",
167177
"--start",
168178
"--model-store",
169-
mms_model_server.MODEL_STORE,
179+
mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
170180
"--mms-config",
171181
mms_model_server.MMS_CONFIG_FILE,
172182
"--log-config",
@@ -175,7 +185,7 @@ def test_start_mms_with_model_from_hub(
175185

176186
subprocess_popen.assert_called_once_with(multi_model_server_cmd)
177187
sigterm.assert_called_once_with(retrieve.return_value)
178-
os.remove(mms_model_server.DEFAULT_MMS_MODEL_DIRECTORY)
188+
os.remove(mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY)
179189

180190

181191
@patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)

0 commit comments

Comments
 (0)