diff --git a/doc/frameworks/djl/using_djl.rst b/doc/frameworks/djl/using_djl.rst
index 9359e25f29..dc77cae405 100644
--- a/doc/frameworks/djl/using_djl.rst
+++ b/doc/frameworks/djl/using_djl.rst
@@ -29,7 +29,7 @@ You can either deploy your model using DeepSpeed or HuggingFace Accelerate, or l
 
     # Create a DJL Model, backend is chosen automatically
     djl_model = DJLModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
        "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -46,7 +46,7 @@ If you want to use a specific backend, then you can create an instance of the co
 
     # Create a model using the DeepSpeed backend
     deepspeed_model = DeepSpeedModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="bf16",
         task="text-generation",
@@ -56,7 +56,7 @@ If you want to use a specific backend, then you can create an instance of the co
 
     # Create a model using the HuggingFace Accelerate backend
     hf_accelerate_model = HuggingFaceAccelerateModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -91,9 +91,38 @@ model server configuration.
 
 Model Artifacts
 ---------------
 
+DJL Serving supports two ways to load models for inference:
+
+1. A HuggingFace Hub model id.
+2. Uncompressed model artifacts stored in an S3 bucket.
+
+HuggingFace Hub model id
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using a HuggingFace Hub model id is the easiest way to get started with deploying large models via DJL Serving on SageMaker.
+DJL Serving will use this model id to download the model at runtime via the HuggingFace Transformers ``from_pretrained`` API.
+This method makes it easy to deploy models quickly, but for very large models the download time can become prohibitively long.
+
+For example, you can deploy the EleutherAI gpt-j-6B model like this:
+
+.. code::
+
+    model = DJLModel(
+        "EleutherAI/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
+
+Uncompressed Model Artifacts stored in an S3 bucket
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For models that are larger than 20GB (total checkpoint size), we recommend that you store the model in S3.
+Download times will be much faster compared to downloading from the HuggingFace Hub at runtime.
 DJL Serving Models expect a different model structure than most of the other frameworks in the
 SageMaker Python SDK. Specifically, DJLModels do not support loading models stored in tar.gz format.
-You must provide an Amazon S3 url pointing to uncompressed model artifacts (bucket and prefix).
 This is because DJL Serving is optimized for large models, and it implements a fast downloading
 mechanism for large models that require the artifacts be uncompressed.
 For example, lets say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub.
@@ -107,7 +136,18 @@ You can download the model and upload to S3 like this:
 
     # Upload to S3
     aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B
 
-You would then pass "s3://my_bucket/gpt-j-6B" as ``model_s3_uri`` to the ``DJLModel``.
+You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this:
+
+.. code::
+
+    model = DJLModel(
+        "s3://my_bucket/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
 
 For language models we expect that the model weights, model config, and tokenizer config are
 provided in S3. The model should be loadable from the HuggingFace Transformers
 AutoModelFor.from_pretrained API, where task
diff --git a/src/sagemaker/djl_inference/defaults.py b/src/sagemaker/djl_inference/defaults.py
index 1917fb8655..04cdb5844e 100644
--- a/src/sagemaker/djl_inference/defaults.py
+++ b/src/sagemaker/djl_inference/defaults.py
@@ -15,6 +15,8 @@
 
 STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion"
 
+VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"]
+
 DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
     "bloom",
     "opt",
diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py
index 6a0402832c..d57420a31d 100644
--- a/src/sagemaker/djl_inference/model.py
+++ b/src/sagemaker/djl_inference/model.py
@@ -16,6 +16,9 @@
 import json
 import logging
 import os.path
+import urllib.request
+from json import JSONDecodeError
+from urllib.error import HTTPError, URLError
 from enum import Enum
 from typing import Optional, Union, Dict, Any
 
@@ -134,10 +137,10 @@ def _read_existing_serving_properties(directory: str):
 
 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +154,53 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config
 
 
+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
+    model_config = None
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
+        config_file_url = config_url_prefix + config
+        try:
+            with urllib.request.urlopen(config_file_url) as response:
+                model_config = json.load(response)
+                break
+        except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
+            logger.warning(
+                "Exception encountered while trying to read config file %s. Details: %s",
+                config_file_url,
+                e,
+            )
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file on the HuggingFace Hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion models) for this model on the HuggingFace Hub."
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
 
     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""
 
-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
"Please store the model in uncompressed format and provide the s3 uri of the " "containing folder" ) - model_config = _get_model_config_properties_from_s3(model_s3_uri) + if model_id.startswith("s3://"): + model_config = _get_model_config_properties_from_s3(model_id) + else: + model_config = _get_model_config_properties_from_hf(model_id) if model_config.get("_class_name") == "StableDiffusionPipeline": model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE num_heads = 0 @@ -196,7 +226,7 @@ def __new__( def __init__( self, - model_s3_uri: str, + model_id: str, role: str, djl_version: Optional[str] = None, task: Optional[str] = None, @@ -216,8 +246,9 @@ def __init__( """Initialize a DJLModel. Args: - model_s3_uri (str): The Amazon S3 location containing the uncompressed model - artifacts. The model artifacts are expected to be in HuggingFace pre-trained model + model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location + containing the uncompressed model artifacts (i.e. not a tar.gz file). + The model artifacts are expected to be in HuggingFace pre-trained model format (i.e. model should be loadable from the huggingface transformers from_pretrained api, and should also include tokenizer configs if applicable). role (str): An AWS IAM role specified with either the name or full ARN. The Amazon @@ -285,13 +316,13 @@ def __init__( if kwargs.get("model_data"): logger.warning( "DJLModels do not use model_data parameter. model_data parameter will be ignored." - "You only need to set model_S3_uri and ensure it points to uncompressed model " - "artifacts." + "You only need to set model_id and ensure it points to uncompressed model " + "artifacts in s3, or a valid HuggingFace Hub model_id." ) super(DJLModel, self).__init__( None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs ) - self.model_s3_uri = model_s3_uri + self.model_id = model_id self.djl_version = djl_version self.task = task self.data_type = data_type @@ -529,7 +560,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str] serving_properties = {} serving_properties["engine"] = self.engine.value[0] # pylint: disable=E1101 serving_properties["option.entryPoint"] = self.engine.value[1] # pylint: disable=E1101 - serving_properties["option.s3url"] = self.model_s3_uri + if self.model_id.startswith("s3://"): + serving_properties["option.s3url"] = self.model_id + else: + serving_properties["option.model_id"] = self.model_id if self.number_of_partitions: serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions if self.entry_point: @@ -593,7 +627,7 @@ class DeepSpeedModel(DJLModel): def __init__( self, - model_s3_uri: str, + model_id: str, role: str, tensor_parallel_degree: Optional[int] = None, max_tokens: Optional[int] = None, @@ -606,11 +640,11 @@ def __init__( """Initialize a DeepSpeedModel Args: - model_s3_uri (str): The Amazon S3 location containing the uncompressed model - artifacts. The model artifacts are expected to be in HuggingFace pre-trained model + model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location + containing the uncompressed model artifacts (i.e. not a tar.gz file). + The model artifacts are expected to be in HuggingFace pre-trained model format (i.e. model should be loadable from the huggingface transformers - from_pretrained - api, and should also include tokenizer configs if applicable). + from_pretrained api, and should also include tokenizer configs if applicable). 
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this
                 role to access model artifacts. After the endpoint is created,
@@ -647,7 +681,7 @@ def __init__(
         """
 
         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +744,7 @@ class HuggingFaceAccelerateModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +756,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): Either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this
                 role to access model artifacts. After the endpoint is created,
@@ -760,7 +794,7 @@ def __init__(
         """
 
         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,
diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py
index 65d031921c..c01bcabeea 100644
--- a/tests/unit/test_djl_inference.py
+++ b/tests/unit/test_djl_inference.py
@@ -15,8 +15,10 @@
 import logging
 import json
 
+from json import JSONDecodeError
+
 import pytest
-from mock import Mock
+from mock import Mock, MagicMock
 from mock import patch, mock_open
 
 from sagemaker.djl_inference import (
@@ -31,6 +33,7 @@
 
 VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model"
 INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz"
+HF_MODEL_ID = "hf_hub_model_id"
 ENTRY_POINT = "entrypoint.py"
 SOURCE_DIR = "source_dir/"
 ENV = {"ENV_VAR": "env_value"}
@@ -70,12 +73,60 @@ def test_create_model_invalid_s3_uri():
         "DJLModel does not support model artifacts in tar.gz"
     )
 
-    with pytest.raises(ValueError) as invalid_s3_data:
+
+@patch("urllib.request.urlopen")
+def test_create_model_valid_hf_hub_model_id(
+    mock_urlopen,
+    sagemaker_session,
+):
+    model_config = {
+        "model_type": "opt",
+        "num_attention_heads": 4,
+    }
+
+    cm = MagicMock()
+    cm.getcode.return_value = 200
+    cm.read.return_value = json.dumps(model_config).encode("utf-8")
+    cm.__enter__.return_value = cm
+    mock_urlopen.return_value = cm
+    model = DJLModel(
+        HF_MODEL_ID,
+        ROLE,
+        sagemaker_session=sagemaker_session,
+        number_of_partitions=4,
+    )
+    assert model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED
+    expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    mock_urlopen.assert_any_call(expected_url)
+
+    serving_properties = model.generate_serving_properties()
+    assert serving_properties["option.model_id"] == HF_MODEL_ID
+    assert "option.s3url" not in serving_properties
+
+
+@patch("json.load")
+@patch("urllib.request.urlopen")
+def test_create_model_invalid_hf_hub_model_id(
+    mock_urlopen,
+    json_load,
+    sagemaker_session,
+):
+    expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    with pytest.raises(ValueError) as invalid_model_id:
+        cm = MagicMock()
+        cm.__enter__.return_value = cm
+        mock_urlopen.return_value = cm
+        json_load.side_effect = JSONDecodeError("", "", 0)
         _ = DJLModel(
-            SOURCE_DIR,
+            HF_MODEL_ID,
             ROLE,
+            sagemaker_session=sagemaker_session,
+            number_of_partitions=4,
         )
-    assert str(invalid_s3_data.value).startswith("DJLModel only supports loading model artifacts")
+    mock_urlopen.assert_any_call(expected_url)
+    assert str(invalid_model_id.value).startswith(
+        "Did not find a config.json or model_index.json file on the HuggingFace Hub"
+    )
 
 
 @patch("sagemaker.s3.S3Downloader.read_file")