 import json
 import logging
 import os.path
+import subprocess
+from tempfile import TemporaryDirectory
 from enum import Enum
 from typing import Optional, Union, Dict, Any

@@ -134,10 +136,10 @@ def _read_existing_serving_properties(directory: str):

 def _get_model_config_properties_from_s3(model_s3_uri: str):
     """Placeholder docstring"""
+
     s3_files = s3.S3Downloader.list(model_s3_uri)
-    valid_config_files = ["config.json", "model_index.json"]
     model_config = None
-    for config in valid_config_files:
+    for config in defaults.VALID_MODEL_CONFIG_FILES:
         config_file = os.path.join(model_s3_uri, config)
         if config_file in s3_files:
             model_config = json.loads(s3.S3Downloader.read_file(config_file))
@@ -151,26 +153,52 @@ def _get_model_config_properties_from_s3(model_s3_uri: str):
     return model_config


+def _get_model_config_properties_from_hf(model_id: str):
+    """Placeholder docstring"""
+
+    git_lfs_model_repo = "https://huggingface.co/" + model_id
+    my_env = os.environ.copy()
+    my_env["GIT_LFS_SKIP_SMUDGE"] = "1"
+    my_env["GIT_TERMINAL_PROMPT"] = "0"
+    model_config = None
+    with TemporaryDirectory() as tmp_dir:
+        subprocess.check_call(["git", "clone", git_lfs_model_repo, tmp_dir], env=my_env)
+        for config in defaults.VALID_MODEL_CONFIG_FILES:
+            config_file = os.path.join(tmp_dir, config)
+            if os.path.exists(config_file):
+                with open(config_file, "r") as f:
+                    model_config = json.load(f)
+                break
+    if not model_config:
+        raise ValueError(
+            f"Did not find a config.json or model_index.json file in huggingface hub for "
+            f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
+
+
 class DJLModel(FrameworkModel):
     """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

     def __new__(
         cls,
-        model_s3_uri: str,
+        model_id: str,
         *args,
         **kwargs,
     ):  # pylint: disable=W0613
         """Create a specific subclass of DJLModel for a given engine"""

-        if not model_s3_uri.startswith("s3://"):
-            raise ValueError("DJLModel only supports loading model artifacts from s3")
-        if model_s3_uri.endswith("tar.gz"):
+        if model_id.endswith("tar.gz"):
             raise ValueError(
                 "DJLModel does not support model artifacts in tar.gz format."
                 "Please store the model in uncompressed format and provide the s3 uri of the "
                 "containing folder"
             )
-        model_config = _get_model_config_properties_from_s3(model_s3_uri)
+        if model_id.startswith("s3://"):
+            model_config = _get_model_config_properties_from_s3(model_id)
+        else:
+            model_config = _get_model_config_properties_from_hf(model_id)
         if model_config.get("_class_name") == "StableDiffusionPipeline":
             model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE
             num_heads = 0
@@ -196,7 +224,7 @@ def __new__(

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         djl_version: Optional[str] = None,
         task: Optional[str] = None,
@@ -216,8 +244,9 @@ def __init__(
         """Initialize a DJLModel.

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
                 from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
@@ -285,13 +314,13 @@ def __init__(
         if kwargs.get("model_data"):
             logger.warning(
                 "DJLModels do not use model_data parameter. model_data parameter will be ignored."
-                "You only need to set model_S3_uri and ensure it points to uncompressed model "
-                "artifacts."
+                "You only need to set model_id and ensure it points to uncompressed model "
+                "artifacts in s3, or a valid HuggingFace Hub model_id."
             )
         super(DJLModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
-        self.model_s3_uri = model_s3_uri
+        self.model_id = model_id
         self.djl_version = djl_version
         self.task = task
         self.data_type = data_type
@@ -529,7 +558,10 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]
             serving_properties = {}
         serving_properties["engine"] = self.engine.value[0]  # pylint: disable=E1101
         serving_properties["option.entryPoint"] = self.engine.value[1]  # pylint: disable=E1101
-        serving_properties["option.s3url"] = self.model_s3_uri
+        if self.model_id.startswith("s3://"):
+            serving_properties["option.s3url"] = self.model_id
+        else:
+            serving_properties["option.model_id"] = self.model_id
         if self.number_of_partitions:
             serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions
         if self.entry_point:
@@ -593,7 +625,7 @@ class DeepSpeedModel(DJLModel):

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         tensor_parallel_degree: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -606,11 +638,11 @@ def __init__(
         """Initialize a DeepSpeedModel

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                api, and should also include tokenizer configs if applicable).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -647,7 +679,7 @@ def __init__(
         """

         super(DeepSpeedModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             **kwargs,
         )
@@ -710,7 +742,7 @@ class HuggingFaceAccelerateModel(DJLModel):

     def __init__(
         self,
-        model_s3_uri: str,
+        model_id: str,
         role: str,
         number_of_partitions: Optional[int] = None,
         device_id: Optional[int] = None,
@@ -722,11 +754,11 @@ def __init__(
         """Initialize a HuggingFaceAccelerateModel.

         Args:
-            model_s3_uri (str): The Amazon S3 location containing the uncompressed model
-                artifacts. The model artifacts are expected to be in HuggingFace pre-trained model
+            model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location
+                containing the uncompressed model artifacts (i.e. not a tar.gz file).
+                The model artifacts are expected to be in HuggingFace pre-trained model
                 format (i.e. model should be loadable from the huggingface transformers
-                from_pretrained
-                method).
+                from_pretrained api, and should also include tokenizer configs if applicable).
             role (str): An AWS IAM role specified with either the name or full ARN. The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access model artifacts. After the endpoint is created,
@@ -760,7 +792,7 @@ def __init__(
         """

         super(HuggingFaceAccelerateModel, self).__init__(
-            model_s3_uri,
+            model_id,
             role,
             number_of_partitions=number_of_partitions,
             **kwargs,
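A minimal usage sketch of the reworked model_id handling (illustrative only, not part of this diff): the same argument now accepts either an S3 prefix with uncompressed artifacts or a HuggingFace Hub model id, and generate_serving_properties routes it to option.s3url or option.model_id accordingly. The role ARN, bucket, and Hub model id below are placeholders, and the Hub path assumes git, git-lfs, and network access for the metadata-only clone.

from sagemaker.djl_inference import DJLModel

role = "arn:aws:iam::123456789012:role/SageMakerRole"  # placeholder execution role

# HuggingFace Hub model_id: the config is fetched via a GIT_LFS_SKIP_SMUDGE git clone,
# and the generated serving.properties will contain option.model_id.
hub_model = DJLModel("my-org/my-hub-model", role)  # placeholder Hub id
print(hub_model.generate_serving_properties())

# S3 prefix containing uncompressed artifacts: serving.properties will contain option.s3url instead.
s3_model = DJLModel("s3://my-bucket/my-model/", role)  # placeholder S3 prefix
print(s3_model.generate_serving_properties())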