@@ -16,6 +16,9 @@
 import json
 import logging
 import os.path
+import urllib.request
+from json import JSONDecodeError
+from urllib.error import HTTPError, URLError
 from enum import Enum
 from typing import Optional, Union, Dict, Any
 
@@ -134,10 +137,10 @@ def _read_existing_serving_properties(directory: str):
 
 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +154,53 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config
 
 
+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
+    model_config = None
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
+        config_file_url = config_url_prefix + config
+        try:
+            with urllib.request.urlopen(config_file_url) as response:
+                model_config = json.load(response)
+                break
+        except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
+            logger.warning(
+                "Exception encountered while trying to read config file %s. " "Details: %s",
+                config_file_url,
+                e,
+            )
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
 
     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""
 
-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
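
The new helper above simply probes the raw config files that the Hub serves for a model. A minimal standalone sketch of the same lookup, assuming defaults.VALID_MODEL_CONFIG_FILES holds config.json and model_index.json (matching the list removed from _get_model_config_properties_from_s3); the candidate names are inlined here and the model id in the comment is only an example:

import json
import urllib.request
from urllib.error import HTTPError, URLError


def fetch_hub_config(model_id, candidates=("config.json", "model_index.json")):
    # Try each candidate file under the model's raw/main tree on huggingface.co.
    for name in candidates:
        url = f"https://huggingface.co/{model_id}/raw/main/{name}"
        try:
            with urllib.request.urlopen(url) as response:
                return json.load(response)
        except (HTTPError, URLError, json.JSONDecodeError):
            continue  # missing or unparseable file: try the next candidate
    raise ValueError(f"No config.json or model_index.json found for {model_id}")


# e.g. fetch_hub_config("gpt2") returns the parsed config.json as a dict
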
@@ -196,7 +226,7 @@ def __new__(
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +246,9 @@ def __init__(
         """Initialize a DJLModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +316,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +560,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
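
Net effect on the generated serving configuration: an S3 prefix keeps the existing option.s3url entry, while anything else is passed through as option.model_id for the serving container to resolve against the Hub. A small illustration of that dispatch (the model ids are placeholders, and only the affected key is shown):

def artifact_property(model_id):
    # Mirrors the branch added above: S3 prefixes map to option.s3url,
    # everything else is treated as a HuggingFace Hub model id.
    if model_id.startswith("s3://"):
        return {"option.s3url": model_id}
    return {"option.model_id": model_id}


artifact_property("s3://my-bucket/gpt-j-6B/")  # {'option.s3url': 's3://my-bucket/gpt-j-6B/'}
artifact_property("EleutherAI/gpt-j-6B")       # {'option.model_id': 'EleutherAI/gpt-j-6B'}
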
@@ -593,7 +627,7 @@ class DeepSpeedModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +640,11 @@ def __init__(
         """Initialize a DeepSpeedModel
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +681,7 @@ def __init__(
         """
 
         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +744,7 @@ class HuggingFaceAccelerateModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +756,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +794,7 @@ def __init__(
         """
 
         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,
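
Taken together, the constructors now accept either form of model_id. A hedged usage sketch under the signatures shown in this diff, assuming the module is importable as sagemaker.djl_inference (the role ARN, bucket, model ids, and instance type are placeholders):

from sagemaker.djl_inference import DJLModel

role = "arn:aws:iam::123456789012:role/MySageMakerRole"  # placeholder role ARN

# HuggingFace Hub model id: the config is fetched from huggingface.co and
# the generated serving.properties carries option.model_id.
hub_model = DJLModel("EleutherAI/gpt-j-6B", role, data_type="fp16", number_of_partitions=2)

# Uncompressed artifacts in S3: the config is read via S3Downloader and
# the generated serving.properties carries option.s3url.
s3_model = DJLModel("s3://my-bucket/gpt-j-6B/", role, data_type="fp16", number_of_partitions=2)

# predictor = hub_model.deploy("ml.g5.12xlarge")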