diff --git a/setup.py b/setup.py
index 95f4e0b..dc72fa7 100644
--- a/setup.py
+++ b/setup.py
@@ -80,7 +80,7 @@
 extras["benchmark"] = ["boto3", "locust"]
 
 extras["quality"] = [
-    "black==21.4b0",
+    "black>=21.10",
     "isort>=5.5.4",
     "flake8>=3.8.3",
 ]
diff --git a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py
index d949617..e2bc592 100644
--- a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py
+++ b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py
@@ -18,9 +18,9 @@
 import subprocess
 
 from sagemaker_inference import environment, logging
-from sagemaker_inference.environment import model_dir
 from sagemaker_inference.model_server import (
     DEFAULT_MMS_LOG_FILE,
+    DEFAULT_MMS_MODEL_NAME,
     ENABLE_MULTI_MODEL,
     MMS_CONFIG_FILE,
     REQUIREMENTS_PATH,
@@ -45,8 +45,8 @@
 DEFAULT_HANDLER_SERVICE = handler_service.__name__
-DEFAULT_MMS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
-MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_MMS_MODEL_DIRECTORY
+DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
+DEFAULT_MODEL_STORE = "/"
 
 
 def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
@@ -64,11 +64,15 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
             Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``.
 
     """
+    use_hf_hub = "HF_MODEL_ID" in os.environ
+    model_store = DEFAULT_MODEL_STORE
     if ENABLE_MULTI_MODEL:
         if not os.getenv("SAGEMAKER_HANDLER"):
             os.environ["SAGEMAKER_HANDLER"] = handler_service
         _set_python_path()
-    elif "HF_MODEL_ID" in os.environ:
+    elif use_hf_hub:
+        # Use different model store directory
+        model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY
         if is_aws_neuron_available():
             raise ValueError(
                 "Hugging Face Hub deployments are currently not supported with AWS Neuron and Inferentia."
@@ -76,16 +80,19 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
             )
         storage_dir = _load_model_from_hub(
             model_id=os.environ["HF_MODEL_ID"],
-            model_dir=DEFAULT_MMS_MODEL_DIRECTORY,
+            model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
             revision=HF_MODEL_REVISION,
             use_auth_token=HF_API_TOKEN,
         )
         _adapt_to_mms_format(handler_service, storage_dir)
     else:
-        _adapt_to_mms_format(handler_service, model_dir)
+        _set_python_path()
 
     env = environment.Environment()
-    _create_model_server_config_file(env)
+
+    # Note: multi-model default config already sets default_service_handler
+    handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
+    _create_model_server_config_file(env, handler_service_for_config)
 
     if os.path.exists(REQUIREMENTS_PATH):
         _install_requirements()
@@ -94,12 +101,14 @@
         "multi-model-server",
         "--start",
         "--model-store",
-        MODEL_STORE,
+        model_store,
         "--mms-config",
         MMS_CONFIG_FILE,
         "--log-config",
         DEFAULT_MMS_LOG_FILE,
     ]
+    if not ENABLE_MULTI_MODEL and not use_hf_hub:
+        multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]
 
     logger.info(multi_model_server_cmd)
     subprocess.Popen(multi_model_server_cmd)
@@ -113,7 +122,7 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
 
 
 def _adapt_to_mms_format(handler_service, model_path):
-    os.makedirs(DEFAULT_MMS_MODEL_DIRECTORY, exist_ok=True)
+    os.makedirs(DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, exist_ok=True)
 
     # gets the model from the path, default is model/
     model = pathlib.PurePath(model_path)
@@ -128,7 +137,7 @@ def _adapt_to_mms_format(handler_service, model_path):
         "--model-path",
         model_path,
         "--export-path",
-        DEFAULT_MMS_MODEL_DIRECTORY,
+        DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
         "--archive-format",
         "no-archive",
         "--f",
diff --git a/tests/unit/test_mms_model_server.py b/tests/unit/test_mms_model_server.py
index 4221242..07693af 100644
--- a/tests/unit/test_mms_model_server.py
+++ b/tests/unit/test_mms_model_server.py
@@ -47,8 +47,10 @@ def test_start_mms_default_service_handler(
     env.return_value.startup_timeout = 10000
     mms_model_server.start_model_server()
 
-    adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
-    create_config.assert_called_once_with(env.return_value)
+    # In this case, we should not rearchive the model
+    adapt.assert_not_called()
+
+    create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
     exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
     install_requirements.assert_called_once_with()
@@ -56,11 +58,13 @@
         "multi-model-server",
         "--start",
         "--model-store",
-        mms_model_server.MODEL_STORE,
+        mms_model_server.DEFAULT_MODEL_STORE,
         "--mms-config",
         mms_model_server.MMS_CONFIG_FILE,
         "--log-config",
         mms_model_server.DEFAULT_MMS_LOG_FILE,
+        "--models",
+        "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
     ]
 
     subprocess_popen.assert_called_once_with(multi_model_server_cmd)
@@ -98,8 +102,10 @@ def test_start_mms_neuron(
     env.return_value.startup_timeout = 10000
     mms_model_server.start_model_server()
 
-    adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
-    create_config.assert_called_once_with(env.return_value)
+    # In this case, we should not call model archiver
+    adapt.assert_not_called()
+
+    create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
     exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
     install_requirements.assert_called_once_with()
@@ -107,11 +113,13 @@ def test_start_mms_neuron(
         "multi-model-server",
         "--start",
         "--model-store",
-        mms_model_server.MODEL_STORE,
+        mms_model_server.DEFAULT_MODEL_STORE,
         "--mms-config",
         mms_model_server.MMS_CONFIG_FILE,
         "--log-config",
         mms_model_server.DEFAULT_MMS_LOG_FILE,
+        "--models",
+        "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
     ]
 
     subprocess_popen.assert_called_once_with(multi_model_server_cmd)
@@ -152,13 +160,15 @@ def test_start_mms_with_model_from_hub(
     load_model_from_hub.assert_called_once_with(
         model_id=os.environ["HF_MODEL_ID"],
-        model_dir=mms_model_server.DEFAULT_MMS_MODEL_DIRECTORY,
+        model_dir=mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
         revision=transformers_utils.HF_MODEL_REVISION,
         use_auth_token=transformers_utils.HF_API_TOKEN,
     )
+    # When loading model from hub, we do call model archiver
     adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub())
-    create_config.assert_called_once_with(env.return_value)
+
+    create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
     exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH)
     install_requirements.assert_called_once_with()
@@ -166,7 +176,7 @@ def test_start_mms_with_model_from_hub(
         "multi-model-server",
         "--start",
         "--model-store",
-        mms_model_server.MODEL_STORE,
+        mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
         "--mms-config",
         mms_model_server.MMS_CONFIG_FILE,
         "--log-config",
@@ -175,7 +185,7 @@ def test_start_mms_with_model_from_hub(
     subprocess_popen.assert_called_once_with(multi_model_server_cmd)
     sigterm.assert_called_once_with(retrieve.return_value)
-    os.remove(mms_model_server.DEFAULT_MMS_MODEL_DIRECTORY)
+    os.remove(mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY)
 
 
 @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)