@@ -47,20 +47,24 @@ def test_start_mms_default_service_handler(
47
47
env .return_value .startup_timeout = 10000
48
48
mms_model_server .start_model_server ()
49
49
50
- adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , model_dir )
51
- create_config .assert_called_once_with (env .return_value )
50
+ # In this case, we should not rearchive the model
51
+ adapt .assert_not_called ()
52
+
53
+ create_config .assert_called_once_with (env .return_value , mms_model_server .DEFAULT_HANDLER_SERVICE )
52
54
exists .assert_called_once_with (mms_model_server .REQUIREMENTS_PATH )
53
55
install_requirements .assert_called_once_with ()
54
56
55
57
multi_model_server_cmd = [
56
58
"multi-model-server" ,
57
59
"--start" ,
58
60
"--model-store" ,
59
- mms_model_server .MODEL_STORE ,
61
+ mms_model_server .DEFAULT_MODEL_STORE ,
60
62
"--mms-config" ,
61
63
mms_model_server .MMS_CONFIG_FILE ,
62
64
"--log-config" ,
63
65
mms_model_server .DEFAULT_MMS_LOG_FILE ,
66
+ "--models" ,
67
+ "{}={}" .format (mms_model_server .DEFAULT_MMS_MODEL_NAME , model_dir ),
64
68
]
65
69
66
70
subprocess_popen .assert_called_once_with (multi_model_server_cmd )
@@ -98,20 +102,24 @@ def test_start_mms_neuron(
98
102
env .return_value .startup_timeout = 10000
99
103
mms_model_server .start_model_server ()
100
104
101
- adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , model_dir )
102
- create_config .assert_called_once_with (env .return_value )
105
+ # In this case, we should not call the model archiver
106
+ adapt .assert_not_called ()
107
+
108
+ create_config .assert_called_once_with (env .return_value , mms_model_server .DEFAULT_HANDLER_SERVICE )
103
109
exists .assert_called_once_with (mms_model_server .REQUIREMENTS_PATH )
104
110
install_requirements .assert_called_once_with ()
105
111
106
112
multi_model_server_cmd = [
107
113
"multi-model-server" ,
108
114
"--start" ,
109
115
"--model-store" ,
110
- mms_model_server .MODEL_STORE ,
116
+ mms_model_server .DEFAULT_MODEL_STORE ,
111
117
"--mms-config" ,
112
118
mms_model_server .MMS_CONFIG_FILE ,
113
119
"--log-config" ,
114
120
mms_model_server .DEFAULT_MMS_LOG_FILE ,
121
+ "--models" ,
122
+ "{}={}" .format (mms_model_server .DEFAULT_MMS_MODEL_NAME , model_dir ),
115
123
]
116
124
117
125
subprocess_popen .assert_called_once_with (multi_model_server_cmd )
@@ -152,21 +160,23 @@ def test_start_mms_with_model_from_hub(
152
160
153
161
load_model_from_hub .assert_called_once_with (
154
162
model_id = os .environ ["HF_MODEL_ID" ],
155
- model_dir = mms_model_server .DEFAULT_MMS_MODEL_DIRECTORY ,
163
+ model_dir = mms_model_server .DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY ,
156
164
revision = transformers_utils .HF_MODEL_REVISION ,
157
165
use_auth_token = transformers_utils .HF_API_TOKEN ,
158
166
)
159
167
168
+ # When loading a model from the hub, we do call the model archiver
160
169
adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , load_model_from_hub ())
161
- create_config .assert_called_once_with (env .return_value )
170
+
171
+ create_config .assert_called_once_with (env .return_value , mms_model_server .DEFAULT_HANDLER_SERVICE )
162
172
exists .assert_called_with (mms_model_server .REQUIREMENTS_PATH )
163
173
install_requirements .assert_called_once_with ()
164
174
165
175
multi_model_server_cmd = [
166
176
"multi-model-server" ,
167
177
"--start" ,
168
178
"--model-store" ,
169
- mms_model_server .MODEL_STORE ,
179
+ mms_model_server .DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY ,
170
180
"--mms-config" ,
171
181
mms_model_server .MMS_CONFIG_FILE ,
172
182
"--log-config" ,
@@ -175,7 +185,7 @@ def test_start_mms_with_model_from_hub(
175
185
176
186
subprocess_popen .assert_called_once_with (multi_model_server_cmd )
177
187
sigterm .assert_called_once_with (retrieve .return_value )
178
- os .remove (mms_model_server .DEFAULT_MMS_MODEL_DIRECTORY )
188
+ os .remove (mms_model_server .DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY )
179
189
180
190
181
191
@patch ("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available" , return_value = True )
0 commit comments