43
43
logging .NOTSET : "off" ,
44
44
}
45
45
46
+
46
47
class DJLEngine (Enum ):
47
48
DEEPSPEED = ("DeepSpeed" , "djl_python.deepspeed" )
48
49
HUGGINGFACE_ACCELERATE = ("Python" , "djl_python.huggingface" )
@@ -54,6 +55,7 @@ class DJLLargeModelPredictor(Predictor):
54
55
This is able to serialize Python lists, dictionaries, and numpy arrays to
55
56
multidimensional tensors for DJL inference.
56
57
"""
58
+
57
59
def __init__ (
58
60
self ,
59
61
endpoint_name ,
@@ -140,7 +142,9 @@ def __new__(
140
142
config_file = uncompressed_model_data + "/config.json"
141
143
142
144
model_type = json .loads (s3 .S3Downloader .read_file (config_file )).get ("model_type" )
143
- cls_to_create = cls if cls is not DJLLargeModel else _determine_engine_for_model_type (model_type )
145
+ cls_to_create = (
146
+ cls if cls is not DJLLargeModel else _determine_engine_for_model_type (model_type )
147
+ )
144
148
return super (DJLLargeModel , cls ).__new__ (cls_to_create )
145
149
146
150
def __init__ (
@@ -227,8 +231,10 @@ def __init__(
227
231
model class directly.
228
232
"""
229
233
if kwargs .get ("model_data" ):
230
- logger .warning ("DJLLargeModels do not use model_data parameter. model_data parameter will be ignored."
231
- "You only need to set uncompressed_model_data" )
234
+ logger .warning (
235
+ "DJLLargeModels do not use model_data parameter. model_data parameter will be ignored."
236
+ "You only need to set uncompressed_model_data"
237
+ )
232
238
super (DJLLargeModel , self ).__init__ (
233
239
None , image_uri , role , entry_point , predictor_cls = predictor_cls , ** kwargs
234
240
)
@@ -580,12 +586,15 @@ def _get_container_env(self):
580
586
logger .warning (f"Ignoring invalid container log level: { self .container_log_level } " )
581
587
return self .env
582
588
583
- self .env ["SERVING_OPTS" ] = f'"-Dai.djl.logging.level={ _LOG_LEVEL_MAP [self .container_log_level ]} "'
589
+ self .env [
590
+ "SERVING_OPTS"
591
+ ] = f'"-Dai.djl.logging.level={ _LOG_LEVEL_MAP [self .container_log_level ]} "'
584
592
return self .env
585
593
586
594
587
595
class DeepSpeedModel (DJLLargeModel ):
588
596
"""A DeepSpeed SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``"""
597
+
589
598
_framework_name = "djl-deepspeed"
590
599
591
600
def __init__ (
@@ -679,8 +688,8 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
679
688
680
689
681
690
class HuggingFaceAccelerateModel (DJLLargeModel ):
682
- """A Hugging Face SageMaker ``Model`` using HuggingFace Accelerate that can be deployed to a SageMaker ``Endpoint``.
683
- """
691
+ """A Hugging Face SageMaker ``Model`` using HuggingFace Accelerate that can be deployed to a SageMaker ``Endpoint``."""
692
+
684
693
_framework_name = "djl-deepspeed"
685
694
686
695
def __init__ (
0 commit comments