Commit 310ac7f
[change]: Support huggingface hub model_id for DJL Models
1 parent 04e3f60 commit 310ac7f

File tree

4 files changed: +164 −34 lines changed
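In short, this commit routes a single ``model_id`` argument to one of two config loaders. A condensed sketch of the dispatch introduced in src/sagemaker/djl_inference/model.py (condensed from the diff below, not the full ``__new__`` body):

    # Condensed from the model.py diff on this page: one argument, two loading paths.
    if model_id.startswith("s3://"):
        model_config = _get_model_config_properties_from_s3(model_id)  # read config from S3
    else:
        model_config = _get_model_config_properties_from_hf(model_id)  # fetch config from the Hub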

doc/frameworks/djl/using_djl.rst

Lines changed: 44 additions & 5 deletions
@@ -29,7 +29,7 @@ You can either deploy your model using DeepSpeed or HuggingFace Accelerate, or l
 
     # Create a DJL Model, backend is chosen automatically
     djl_model = DJLModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",

@@ -46,7 +46,7 @@ If you want to use a specific backend, then you can create an instance of the co
 
     # Create a model using the DeepSpeed backend
     deepspeed_model = DeepSpeedModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="bf16",
         task="text-generation",

@@ -56,7 +56,7 @@ If you want to use a specific backend, then you can create an instance of the co
     # Create a model using the HuggingFace Accelerate backend
 
     hf_accelerate_model = HuggingFaceAccelerateModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
        "my_sagemaker_role",
        data_type="fp16",
        task="text-generation",
@@ -91,9 +91,37 @@ model server configuration.
 Model Artifacts
 ---------------
 
+DJL Serving supports two ways to load models for inference.
+
+1. A HuggingFace Hub model id.
+2. Uncompressed model artifacts stored in an S3 bucket.
+
+HuggingFace Hub model id
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using a HuggingFace Hub model id is the easiest way to get started with deploying large models via DJL Serving on SageMaker.
+DJL Serving will use this model id to download the model at runtime via the HuggingFace Transformers ``from_pretrained`` API.
+This method makes it easy to deploy models quickly, but for very large models the download time can become unreasonable.
+
+For example, you can deploy the EleutherAI gpt-j-6B model like this:
+
+.. code::
+
+    model = DJLModel(
+        "EleutherAI/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
+
+Uncompressed Model Artifacts stored in an S3 bucket
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For models that are larger than 20GB (total checkpoint size), we recommend that you store the model in S3.
+Download times will be much faster compared to downloading from the HuggingFace Hub at runtime.
 DJL Serving Models expect a different model structure than most of the other frameworks in the SageMaker Python SDK.
 Specifically, DJLModels do not support loading models stored in tar.gz format.
-You must provide an Amazon S3 url pointing to uncompressed model artifacts (bucket and prefix).
 This is because DJL Serving is optimized for large models, and it implements a fast downloading mechanism for large models that requires the artifacts to be uncompressed.
 
 For example, let's say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub.
@@ -107,7 +135,18 @@ You can download the model and upload to S3 like this:
     # Upload to S3
     aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B
 
-You would then pass "s3://my_bucket/gpt-j-6B" as ``model_s3_uri`` to the ``DJLModel``.
+You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this:
+
+.. code::
+
+    model = DJLModel(
+        "s3://my_bucket/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
 
 For language models we expect that the model weights, model config, and tokenizer config are provided in S3. The model
 should be loadable from the HuggingFace Transformers AutoModelFor<Task>.from_pretrained API, where task
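Note: the doc examples above stop at ``model.deploy``. A hedged follow-on showing an invocation — the payload shape is an assumption (DJL text-generation handlers commonly accept an ``inputs`` field), not something this commit specifies:

    # Assumption: the default predictor sends JSON and the text-generation
    # handler accepts an "inputs" field; adjust to your container's contract.
    result = predictor.predict({"inputs": "The meaning of life is"})
    print(result)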

src/sagemaker/djl_inference/defaults.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 
 STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion"
 
+VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"]
+
 DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
     "bloom",
     "opt",

src/sagemaker/djl_inference/model.py

Lines changed: 70 additions & 26 deletions
@@ -16,6 +16,7 @@
 import json
 import logging
 import os.path
+import subprocess
 from enum import Enum
 from typing import Optional, Union, Dict, Any
 
@@ -134,10 +135,10 @@ def _read_existing_serving_properties(directory: str):
 
 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +152,65 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config
 
 
+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
+    model_config = None
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
+        config_file_url = config_url_prefix + config
+        try:
+            output = subprocess.run(["curl", config_file_url], capture_output=True, check=True)
+            model_config = json.loads(output.stdout.decode("utf-8"))
+            break
+        except FileNotFoundError as e:
+            logger.error(
+                "Unable to download config file for %s from huggingface hub. "
+                "Make sure you have curl installed.",
+                model_id,
+            )
+            raise e
+        except subprocess.CalledProcessError as e:
+            logger.error("curl %s process failed.", config_file_url)
+            raise e
+        except (ValueError, TypeError, IOError) as e:
+            logger.debug(
+                "Did not find config file %s for model %s. Error details: %s, %s",
+                config_file_url,
+                model_id,
+                type(e),
+                e,
+            )
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
 
     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""
 
-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
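An aside on ``_get_model_config_properties_from_hf`` above: it shells out to ``curl``, which is why a missing binary surfaces as ``FileNotFoundError``. For comparison, a dependency-free sketch of the same fetch in pure Python — not what the commit does, just an equivalent under the same URL convention:

    import json
    import urllib.request

    def fetch_hf_config(model_id: str, filename: str = "config.json") -> dict:
        """Sketch: fetch a raw config file from the HuggingFace Hub without curl."""
        url = f"https://huggingface.co/{model_id}/raw/main/{filename}"
        with urllib.request.urlopen(url) as response:  # raises URLError/HTTPError on failure
            return json.loads(response.read().decode("utf-8"))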
@@ -196,7 +236,7 @@ def __new__(
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +256,9 @@ def __init__(
         """Initialize a DJLModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +326,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +570,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
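The effect of this branch is that exactly one location key reaches DJL Serving. A sketch of the two possible dicts ``generate_serving_properties`` would return — the engine and entryPoint values are illustrative placeholders, since the real values come from ``self.engine``:

    # Hub model id case (illustrative values):
    hub_properties = {"engine": "DeepSpeed", "option.entryPoint": "djl_python.deepspeed",
                      "option.model_id": "EleutherAI/gpt-j-6B", "option.tensor_parallel_degree": 2}

    # S3 case (illustrative values):
    s3_properties = {"engine": "DeepSpeed", "option.entryPoint": "djl_python.deepspeed",
                     "option.s3url": "s3://my_bucket/gpt-j-6B", "option.tensor_parallel_degree": 2}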
@@ -593,7 +637,7 @@ class DeepSpeedModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +650,11 @@ def __init__(
         """Initialize a DeepSpeedModel
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +691,7 @@ def __init__(
         """
 
         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +754,7 @@ class HuggingFaceAccelerateModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +766,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
760804
"""
761805

762806
super(HuggingFaceAccelerateModel, self).__init__(
763-
model_s3_uri,
807+
model_id,
764808
role,
765809
number_of_partitions=number_of_partitions,
766810
**kwargs,

tests/unit/test_djl_inference.py

Lines changed: 48 additions & 3 deletions
@@ -15,6 +15,8 @@
 import logging
 
 import json
+from json import JSONDecodeError
+
 import pytest
 from mock import Mock
 from mock import patch, mock_open
@@ -31,6 +33,7 @@
 
 VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model"
 INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz"
+HF_MODEL_ID = "hf_hub_model_id"
 ENTRY_POINT = "entrypoint.py"
 SOURCE_DIR = "source_dir/"
 ENV = {"ENV_VAR": "env_value"}
@@ -70,12 +73,54 @@ def test_create_model_invalid_s3_uri():
         "DJLModel does not support model artifacts in tar.gz"
     )
 
-    with pytest.raises(ValueError) as invalid_s3_data:
+
+@patch("json.loads")
+@patch("subprocess.run")
+def test_create_model_valid_hf_hub_model_id(
+    subprocess_run,
+    json_loads,
+    sagemaker_session,
+):
+    model_config = {
+        "model_type": "opt",
+        "num_attention_heads": 4,
+    }
+    json_loads.return_value = model_config
+    model = DJLModel(
+        HF_MODEL_ID,
+        ROLE,
+        sagemaker_session=sagemaker_session,
+        number_of_partitions=4,
+    )
+    assert model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED
+    expected_git_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    subprocess_run.assert_any_call(["curl", expected_git_url], capture_output=True, check=True)
+
+    serving_properties = model.generate_serving_properties()
+    assert serving_properties["option.model_id"] == HF_MODEL_ID
+    assert "option.s3url" not in serving_properties
+
+
+@patch("json.loads")
+@patch("subprocess.run")
+def test_create_model_invalid_hf_hub_model_id(
+    subprocess_run,
+    json_loads,
+    sagemaker_session,
+):
+    expected_git_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json"
+    with pytest.raises(ValueError) as invalid_model_id:
+        json_loads.side_effect = JSONDecodeError("", "", 0)
         _ = DJLModel(
-            SOURCE_DIR,
+            HF_MODEL_ID,
             ROLE,
+            sagemaker_session=sagemaker_session,
+            number_of_partitions=4,
         )
-    assert str(invalid_s3_data.value).startswith("DJLModel only supports loading model artifacts")
+    subprocess_run.assert_any_call(["curl", expected_git_url], capture_output=True, check=True)
+    assert str(invalid_model_id.value).startswith(
+        "Did not find a config.json or model_index.json file in huggingface hub"
+    )
 
 
 @patch("sagemaker.s3.S3Downloader.read_file")
