
Commit 914a656

[change]: Support huggingface hub model_id for DJL Models
1 parent 04e3f60 commit 914a656

File tree

4 files changed: +185 -35 lines changed

doc/frameworks/djl/using_djl.rst

Lines changed: 42 additions & 5 deletions
@@ -29,7 +29,7 @@ You can either deploy your model using DeepSpeed or HuggingFace Accelerate, or l

     # Create a DJL Model, backend is chosen automatically
     djl_model = DJLModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -46,7 +46,7 @@ If you want to use a specific backend, then you can create an instance of the co

     # Create a model using the DeepSpeed backend
     deepspeed_model = DeepSpeedModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="bf16",
         task="text-generation",
@@ -56,7 +56,7 @@ If you want to use a specific backend, then you can create an instance of the co
     # Create a model using the HuggingFace Accelerate backend

     hf_accelerate_model = HuggingFaceAccelerateModel(
-        "s3://my_bucket/my_saved_model_artifacts/",
+        "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id
         "my_sagemaker_role",
         data_type="fp16",
         task="text-generation",
@@ -91,9 +91,36 @@ model server configuration.
 Model Artifacts
 ---------------

+DJL Serving supports two ways to load models for inference:
+1. A HuggingFace Hub model id.
+2. Uncompressed model artifacts stored in an S3 bucket.
+
+HuggingFace Hub model id
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using a HuggingFace Hub model id is the easiest way to get started deploying large models via DJL Serving on SageMaker.
+DJL Serving will use this model id to download the model at runtime via the HuggingFace Transformers ``from_pretrained`` API.
+This method makes it easy to deploy models quickly, but for very large models the download time can become unreasonable.
+
+For example, you can deploy the EleutherAI gpt-j-6B model like this:
+
+.. code::
+    model = DJLModel(
+        "EleutherAI/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")
+
+Uncompressed Model Artifacts stored in an S3 bucket
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For models larger than 20GB (total checkpoint size), we recommend storing the model in S3.
+Download times will be much faster compared to downloading from the HuggingFace Hub at runtime.
 DJL Serving Models expect a different model structure than most of the other frameworks in the SageMaker Python SDK.
 Specifically, DJLModels do not support loading models stored in tar.gz format.
-You must provide an Amazon S3 url pointing to uncompressed model artifacts (bucket and prefix).
 This is because DJL Serving is optimized for large models, and it implements a fast downloading mechanism for large models that requires the artifacts to be uncompressed.

 For example, let's say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub.
@@ -107,7 +134,17 @@ You can download the model and upload to S3 like this:
     # Upload to S3
     aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B

-You would then pass "s3://my_bucket/gpt-j-6B" as ``model_s3_uri`` to the ``DJLModel``.
+You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this:
+
+.. code::
+    model = DJLModel(
+        "s3://my_bucket/gpt-j-6B",
+        "my_sagemaker_role",
+        data_type="fp16",
+        number_of_partitions=2
+    )
+
+    predictor = model.deploy("ml.g5.12xlarge")

 For language models we expect that the model weights, model config, and tokenizer config are provided in S3. The model
 should be loadable from the HuggingFace Transformers AutoModelFor<Task>.from_pretrained API, where task
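
Note that the ``aws s3 sync`` step above assumes the artifacts are already on disk. A minimal Python sketch of the same download-then-upload flow, assuming the ``transformers`` package and the SDK's ``S3Uploader`` (the bucket name is a placeholder):

from sagemaker.s3 import S3Uploader
from transformers import AutoModelForCausalLM, AutoTokenizer

local_dir = "gpt-j-6B"

# Download the weights and tokenizer from the HuggingFace Hub
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Save uncompressed (not tar.gz), as DJLModel requires
model.save_pretrained(local_dir)
tokenizer.save_pretrained(local_dir)

# Upload the folder to S3; the returned URI is what you pass as model_id
model_uri = S3Uploader.upload(local_dir, "s3://my_bucket/gpt-j-6B")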

src/sagemaker/djl_inference/defaults.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@

 STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion"

+VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"]
+
 DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
     "bloom",
     "opt",

src/sagemaker/djl_inference/model.py

Lines changed: 58 additions & 26 deletions
@@ -16,6 +16,8 @@
 import json
 import logging
 import os.path
+import subprocess
+from tempfile import TemporaryDirectory
 from enum import Enum
 from typing import Optional, Union, Dict, Any

@@ -134,10 +136,10 @@ def _read_existing_serving_properties(directory: str):

 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +153,52 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config


+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    git_lfs_model_repo = "https://huggingface.co/" + model_id
+    my_env = os.environ.copy()
+    my_env["GIT_LFS_SKIP_SMUDGE"] = "1"
+    my_env["GIT_TERMINAL_PROMPT"] = "0"
+    model_config = None
+    with TemporaryDirectory() as tmp_dir:
+        subprocess.check_call(["git", "clone", git_lfs_model_repo, tmp_dir], env=my_env)
+        for config in defaults.VALID_MODEL_CONFIG_FILES:
+            config_file = os.path.join(tmp_dir, config)
+            if os.path.exists(config_file):
+                with open(config_file, "r") as f:
+                    model_config = json.load(f)
+                break
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""

-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
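
The new ``_get_model_config_properties_from_hf`` leans on a git detail worth spelling out: with ``GIT_LFS_SKIP_SMUDGE=1``, ``git clone`` checks out only the small LFS pointer files, so the multi-GB weight files are never downloaded and the clone yields just the repo's text files, including ``config.json``. As an alternative sketch, the same config could be fetched file-by-file with the ``huggingface_hub`` client; this is not what the commit does, and it assumes ``huggingface_hub`` is installed:

import json

from huggingface_hub import hf_hub_download  # assumed dependency, not used by this commit

def get_hf_model_config(model_id: str) -> dict:
    """Fetch just config.json (or model_index.json) for a Hub model."""
    for config_name in ("config.json", "model_index.json"):
        try:
            config_file = hf_hub_download(repo_id=model_id, filename=config_name)
        except Exception:
            continue  # this repo does not ship that config file
        with open(config_file, "r") as f:
            return json.load(f)
    raise ValueError(f"No config.json or model_index.json found for {model_id}")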
@@ -196,7 +224,7 @@ def __new__(

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +244,9 @@ def __init__(
         """Initialize a DJLModel.

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +314,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +558,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
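
The effect of this branch on the generated ``serving.properties`` is simple: S3 URIs keep the existing ``option.s3url`` key, while anything else is emitted as ``option.model_id`` for DJL Serving to resolve against the HuggingFace Hub. A standalone sketch of the same dispatch (``model_location_properties`` is a hypothetical helper, not part of the commit):

def model_location_properties(model_id: str) -> dict:
    """Mirror of the branch above: pick the serving.properties key by model_id form."""
    if model_id.startswith("s3://"):
        return {"option.s3url": model_id}
    return {"option.model_id": model_id}

assert model_location_properties("s3://my_bucket/gpt-j-6B") == {
    "option.s3url": "s3://my_bucket/gpt-j-6B"
}
assert model_location_properties("EleutherAI/gpt-j-6B") == {
    "option.model_id": "EleutherAI/gpt-j-6B"
}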
@@ -593,7 +625,7 @@ class DeepSpeedModel(DJLModel):

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +638,11 @@ def __init__(
         """Initialize a DeepSpeedModel

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +679,7 @@ def __init__(
         """

         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +742,7 @@ class HuggingFaceAccelerateModel(DJLModel):

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +754,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +792,7 @@ def __init__(
         """

         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,
