Commit c4c4e83

feature: Support huggingface hub model_id for DJL Models (#3753)
1 parent bd51517 · commit c4c4e83

File tree: 4 files changed (+161, -35 lines)


doc/frameworks/djl/using_djl.rst (+44, -5)

@@ -29,7 +29,7 @@ You can either deploy your model using DeepSpeed or HuggingFace Accelerate, or l
 
     # Create a DJL Model, backend is chosen automatically
     djl_model = DJLModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -46,7 +46,7 @@ If you want to use a specific backend, then you can create an instance of the co
 
     # Create a model using the DeepSpeed backend
     deepspeed_model = DeepSpeedModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="bf16",
         task="text-generation",
@@ -56,7 +56,7 @@ If you want to use a specific backend, then you can create an instance of the co
     # Create a model using the HuggingFace Accelerate backend
 
     hf_accelerate_model = HuggingFaceAccelerateModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -91,9 +91,37 @@ model server configuration.
 Model Artifacts
 ---------------
 
+DJL Serving supports two ways to load models for inference:
+
+1. A HuggingFace Hub model id.
+2. Uncompressed model artifacts stored in an S3 bucket.
+
+HuggingFace Hub model id
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using a HuggingFace Hub model id is the easiest way to get started deploying large models via DJL Serving on SageMaker.
+DJL Serving will use this model id to download the model at runtime via the HuggingFace Transformers ``from_pretrained`` API.
+This method makes it easy to deploy models quickly, but for very large models the download time can become unreasonable.
+
+For example, you can deploy the EleutherAI gpt-j-6B model like this:
+
+.. code::
+
+    model = DJLModel(
+        "EleutherAI/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
+
+Uncompressed Model Artifacts stored in an S3 bucket
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For models larger than 20GB (total checkpoint size), we recommend that you store the model in S3.
+Download times will be much faster compared to downloading from the HuggingFace Hub at runtime.
 DJL Serving Models expect a different model structure than most of the other frameworks in the SageMaker Python SDK.
 Specifically, DJLModels do not support loading models stored in tar.gz format.
-You must provide an Amazon S3 url pointing to uncompressed model artifacts (bucket and prefix).
 This is because DJL Serving is optimized for large models, and it implements a fast downloading mechanism for large models that requires the artifacts to be uncompressed.
 
 For example, let's say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub.
@@ -107,7 +135,18 @@ You can download the model and upload to S3 like this:
 
     # Upload to S3
     aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B
 
-You would then pass "s3://my_bucket/gpt-j-6B" as ``model_s3_uri`` to the ``DJLModel``.
+You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this:
+
+.. code::
+
+    model = DJLModel(
+        "s3://my_bucket/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
 
 For language models we expect that the model weights, model config, and tokenizer config are provided in S3. The model
 should be loadable from the HuggingFace Transformers AutoModelFor<Task>.from_pretrained API, where task
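
A practical corollary of the 20GB guidance above: you can check a Hub checkpoint's total size before picking a loading path. A minimal sketch, assuming the separate huggingface_hub client library and its model_info API (neither is part of this commit):

    # Sketch: estimate a Hub checkpoint's total size (assumes huggingface_hub).
    from huggingface_hub import HfApi

    info = HfApi().model_info("EleutherAI/gpt-j-6B", files_metadata=True)
    total_gb = sum((f.size or 0) for f in info.siblings) / 1e9
    print(f"~{total_gb:.1f} GB")  # above ~20GB, prefer uploading to S3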

src/sagemaker/djl_inference/defaults.py (+2)

@@ -15,6 +15,8 @@
 
 STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion"
 
+VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"]
+
 DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
     "bloom",
     "opt",

src/sagemaker/djl_inference/model.py (+60, -26)

@@ -16,6 +16,9 @@
 import json
 import logging
 import os.path
+import urllib.request
+from json import JSONDecodeError
+from urllib.error import HTTPError, URLError
 from enum import Enum
 from typing import Optional, Union, Dict, Any
 
@@ -134,10 +137,10 @@ def _read_existing_serving_properties(directory: str):
 
 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +154,53 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config
 
 
+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
+    model_config = None
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
+        config_file_url = config_url_prefix + config
+        try:
+            with urllib.request.urlopen(config_file_url) as response:
+                model_config = json.load(response)
+                break
+        except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
+            logger.warning(
+                "Exception encountered while trying to read config file %s. Details: %s",
+                config_file_url,
+                e,
+            )
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
 
     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""
 
-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
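
The new Hub lookup only fetches the small config file, never the weights: it probes the Hub's raw-file endpoint for each candidate name and parses the first response that succeeds. A standalone sketch of the same logic (the helper name is hypothetical; the URL scheme matches the code above):

    import json
    import urllib.request
    from urllib.error import HTTPError, URLError

    def peek_hub_config(model_id: str) -> dict:
        # Hypothetical helper mirroring _get_model_config_properties_from_hf above
        for name in ("config.json", "model_index.json"):
            url = f"https://huggingface.co/{model_id}/raw/main/{name}"
            try:
                with urllib.request.urlopen(url) as response:
                    return json.load(response)  # first readable config wins
            except (HTTPError, URLError, TimeoutError, json.JSONDecodeError):
                continue  # fall through to the next candidate file
        raise ValueError(f"no config.json or model_index.json found for {model_id}")

    # peek_hub_config("EleutherAI/gpt-j-6B")["model_type"] would be "gptj", so
    # __new__ can pick an engine subclass without downloading any weights.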
@@ -196,7 +226,7 @@ def __new__(
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +246,9 @@ def __init__(
         """Initialize a DJLModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +316,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +560,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
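
This branch is the only place where the two identifier forms diverge in the generated model server configuration: an s3:// uri becomes option.s3url, and anything else is passed through as option.model_id for DJL Serving to resolve against the Hub. A hedged sketch of the observable behavior (it mirrors the unit test below; constructing the model fetches the config):

    # Sketch: serving.properties output for a Hub model id
    model = DJLModel("EleutherAI/gpt-j-6B", "my_sagemaker_role", number_of_partitions=2)
    props = model.generate_serving_properties()
    assert props["option.model_id"] == "EleutherAI/gpt-j-6B"
    assert "option.s3url" not in props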
@@ -593,7 +627,7 @@ class DeepSpeedModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +640,11 @@ def __init__(
         """Initialize a DeepSpeedModel
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +681,7 @@ def __init__(
         """
 
         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +744,7 @@ class HuggingFaceAccelerateModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +756,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +794,7 @@ def __init__(
         """
 
         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,
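
Because only the first positional parameter was renamed, existing positional call sites keep working, and both subclasses now accept a Hub id as well. A hedged end-to-end sketch (role name and instance type are placeholders):

    from sagemaker.djl_inference import DeepSpeedModel

    # A Hub model id instead of an S3 uri; DJL Serving downloads it at runtime
    model = DeepSpeedModel(
        "EleutherAI/gpt-j-6B",
        "my_sagemaker_role",
        tensor_parallel_degree=2,
        data_type="bf16",
    )
    predictor = model.deploy("ml.g5.12xlarge")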

tests/unit/test_djl_inference.py (+55, -4)

@@ -15,8 +15,10 @@
 import logging
 
 import json
+from json import JSONDecodeError
+
 import pytest
-from mock import Mock
+from mock import Mock, MagicMock
 from mock import patch, mock_open
 
 from sagemaker.djl_inference import (
@@ -31,6 +33,7 @@
 
 VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model"
 INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz"
+HF_MODEL_ID = "hf_hub_model_id"
 ENTRY_POINT = "entrypoint.py"
 SOURCE_DIR = "source_dir/"
 ENV = {"ENV_VAR": "env_value"}
@@ -73,12 +76,60 @@ def test_create_model_invalid_s3_uri():
         "DJLModel does not support model artifacts in tar.gz"
     )
 
-    with pytest.raises(ValueError) as invalid_s3_data:
+
+@patch("urllib.request.urlopen")
+def test_create_model_valid_hf_hub_model_id(
+    mock_urlopen,
+    sagemaker_session,
+):
+    model_config = {
+        "model_type": "opt",
+        "num_attention_heads": 4,
+    }
+
+    cm = MagicMock()
+    cm.getcode.return_value = 200
+    cm.read.return_value = json.dumps(model_config).encode("utf-8")
+    cm.__enter__.return_value = cm
+    mock_urlopen.return_value = cm
+    model = DJLModel(
+        HF_MODEL_ID,
+        ROLE,
+        sagemaker_session=sagemaker_session,
+        number_of_partitions=4,
+    )
+    assert model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED
+    expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    mock_urlopen.assert_any_call(expected_url)
+
+    serving_properties = model.generate_serving_properties()
+    assert serving_properties["option.model_id"] == HF_MODEL_ID
+    assert "option.s3url" not in serving_properties
+
+
+@patch("json.load")
+@patch("urllib.request.urlopen")
+def test_create_model_invalid_hf_hub_model_id(
+    mock_urlopen,
+    json_load,
+    sagemaker_session,
+):
+    expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    with pytest.raises(ValueError) as invalid_model_id:
+        cm = MagicMock()
+        cm.__enter__.return_value = cm
+        mock_urlopen.return_value = cm
+        json_load.side_effect = JSONDecodeError("", "", 0)
         _ = DJLModel(
-            SOURCE_DIR,
+            HF_MODEL_ID,
             ROLE,
+            sagemaker_session=sagemaker_session,
+            number_of_partitions=4,
         )
-    assert str(invalid_s3_data.value).startswith("DJLModel only supports loading model artifacts")
+    mock_urlopen.assert_any_call(expected_url)
+    assert str(invalid_model_id.value).startswith(
+        "Did not find a config.json or model_index.json file in huggingface hub"
+    )
 
 
 @patch("sagemaker.s3.S3Downloader.read_file")
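
The subtle part of these tests is stubbing a context manager: _get_model_config_properties_from_hf calls urlopen in a with block, so the mock must return itself from __enter__. A minimal self-contained illustration of the pattern (uses unittest.mock, equivalent to the mock package imported above):

    import json
    from unittest.mock import MagicMock

    cm = MagicMock()
    cm.read.return_value = json.dumps({"model_type": "opt"}).encode("utf-8")
    cm.__enter__.return_value = cm  # `with urlopen(...) as response:` yields cm

    with cm as response:
        print(json.loads(response.read()))  # {'model_type': 'opt'}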
