Skip to content

Commit 51b989f

Browse files
committed
format python code
1 parent 802ac5b commit 51b989f

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/sagemaker/djl_inference/model.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
logging.NOTSET: "off",
4444
}
4545

46+
4647
class DJLEngine(Enum):
4748
DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed")
4849
HUGGINGFACE_ACCELERATE = ("Python", "djl_python.huggingface")
@@ -54,6 +55,7 @@ class DJLLargeModelPredictor(Predictor):
5455
This is able to serialize Python lists, dictionaries, and numpy arrays to
5556
multidimensional tensors for DJL inference.
5657
"""
58+
5759
def __init__(
5860
self,
5961
endpoint_name,
@@ -140,7 +142,9 @@ def __new__(
140142
config_file = uncompressed_model_data + "/config.json"
141143

142144
model_type = json.loads(s3.S3Downloader.read_file(config_file)).get("model_type")
143-
cls_to_create = cls if cls is not DJLLargeModel else _determine_engine_for_model_type(model_type)
145+
cls_to_create = (
146+
cls if cls is not DJLLargeModel else _determine_engine_for_model_type(model_type)
147+
)
144148
return super(DJLLargeModel, cls).__new__(cls_to_create)
145149

146150
def __init__(
@@ -227,8 +231,10 @@ def __init__(
227231
model class directly.
228232
"""
229233
if kwargs.get("model_data"):
230-
logger.warning("DJLLargeModels do not use model_data parameter. model_data parameter will be ignored."
231-
"You only need to set uncompressed_model_data")
234+
logger.warning(
235+
"DJLLargeModels do not use model_data parameter. model_data parameter will be ignored."
236+
"You only need to set uncompressed_model_data"
237+
)
232238
super(DJLLargeModel, self).__init__(
233239
None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
234240
)
@@ -580,12 +586,15 @@ def _get_container_env(self):
580586
logger.warning(f"Ignoring invalid container log level: {self.container_log_level}")
581587
return self.env
582588

583-
self.env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"'
589+
self.env[
590+
"SERVING_OPTS"
591+
] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"'
584592
return self.env
585593

586594

587595
class DeepSpeedModel(DJLLargeModel):
588596
"""A DeepSpeed SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``"""
597+
589598
_framework_name = "djl-deepspeed"
590599

591600
def __init__(
@@ -679,8 +688,8 @@ def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
679688

680689

681690
class HuggingFaceAccelerateModel(DJLLargeModel):
682-
"""A Hugging Face SageMaker ``Model`` using HuggingFace Accelerate that can be deployed to a SageMaker ``Endpoint``.
683-
"""
691+
"""A Hugging Face SageMaker ``Model`` using HuggingFace Accelerate that can be deployed to a SageMaker ``Endpoint``."""
692+
684693
_framework_name = "djl-deepspeed"
685694

686695
def __init__(

0 commit comments

Comments
 (0)