 import json
 import logging
 import os.path
+import subprocess
 from enum import Enum
 from typing import Optional, Union, Dict, Any
 
@@ -134,10 +135,10 @@ def _read_existing_serving_properties(directory: str):
 
 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +152,65 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config
 
 
+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/"
+    model_config = None
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
+        config_file_url = config_url_prefix + config
+        try:
+            output = subprocess.run(["curl", config_file_url], capture_output=True, check=True)
+            model_config = json.loads(output.stdout.decode("utf-8"))
+            break
+        except FileNotFoundError as e:
+            logger.error(
+                "Unable to download config file for %s from huggingface hub. "
+                "Make sure you have curl installed.",
+                model_id,
+            )
+            raise e
+        except subprocess.CalledProcessError as e:
+            logger.error("curl %s process failed.", config_file_url)
+            raise e
+        except (ValueError, TypeError, IOError) as e:
+            logger.debug(
+                "Did not find config file %s for model %s. Error details: %s, %s",
+                config_file_url,
+                model_id,
+                type(e),
+                e,
+            )
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
 
     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""
 
-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
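
Note (not part of the diff): the new `_get_model_config_properties_from_hf` helper shells out to `curl` to read a model's config straight from the HuggingFace Hub. Below is a minimal standard-library sketch of the same lookup; the function and constant names are illustrative, and `CANDIDATE_CONFIG_FILES` is assumed to mirror `defaults.VALID_MODEL_CONFIG_FILES` (the same `config.json` / `model_index.json` pair the S3 path checks).

```python
import json
import urllib.error
import urllib.request

# Assumed to mirror defaults.VALID_MODEL_CONFIG_FILES used in the diff above.
CANDIDATE_CONFIG_FILES = ["config.json", "model_index.json"]


def fetch_hub_model_config(model_id: str) -> dict:
    """Return the first config file that can be fetched and parsed for model_id."""
    for config in CANDIDATE_CONFIG_FILES:
        url = f"https://huggingface.co/{model_id}/raw/main/{config}"
        try:
            with urllib.request.urlopen(url) as response:
                return json.loads(response.read().decode("utf-8"))
        except (urllib.error.URLError, ValueError):
            continue  # file missing or unparseable; try the next candidate
    raise ValueError(f"No config.json or model_index.json found on the Hub for {model_id}")


# Example: a transformers config exposes the fields __new__ inspects,
# e.g. fetch_hub_model_config("gpt2")["model_type"] -> "gpt2"
```

Using the standard library instead of a `curl` subprocess would avoid the missing-binary failure mode that the `FileNotFoundError` branch above has to guard against.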
@@ -196,7 +236,7 @@ def __new__(
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +256,9 @@ def __init__(
         """Initialize a DJLModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +326,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +570,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
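
As a sanity check on the new branch (an illustrative sketch, not library code): an `s3://` prefix keeps the existing `option.s3url` serving property, and anything else is passed through as `option.model_id`.

```python
def model_location_properties(model_id: str) -> dict:
    """Sketch of the dispatch added above to generate_serving_properties."""
    if model_id.startswith("s3://"):
        return {"option.s3url": model_id}
    return {"option.model_id": model_id}


assert model_location_properties("s3://my-bucket/gpt-j-6b/") == {
    "option.s3url": "s3://my-bucket/gpt-j-6b/"
}
assert model_location_properties("EleutherAI/gpt-j-6b") == {
    "option.model_id": "EleutherAI/gpt-j-6b"
}
```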
@@ -593,7 +637,7 @@ class DeepSpeedModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +650,11 @@ def __init__(
         """Initialize a DeepSpeedModel
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
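
Hypothetical usage (not part of the diff): the DeepSpeed subclass takes the same `model_id` plus the partitioning arguments shown in its signature above; import path, role ARN, and values are placeholders.

```python
from sagemaker.djl_inference import DeepSpeedModel  # assumed import path

deepspeed_model = DeepSpeedModel(
    "EleutherAI/gpt-j-6b",                           # HuggingFace Hub model_id
    "arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder IAM role
    tensor_parallel_degree=2,                        # shard the model across 2 GPUs
    max_tokens=1024,
)
```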
@@ -647,7 +691,7 @@ def __init__(
         """
 
         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +754,7 @@ class HuggingFaceAccelerateModel(DJLModel):
 
     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +766,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.
 
         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +804,7 @@ def __init__(
         """
 
         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,