Skip to content

Commit 3e8b430

Browse files
committed
Set logging env variable for djl-serving
1 parent 4cdd1fe commit 3e8b430

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/sagemaker/djl_inference/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
}
3939

4040
ALLOWED_INSTANCE_FAMILIES = {
41-
"ml.g4",
41+
"ml.g4dn",
4242
"ml.g5",
4343
"ml.p3",
4444
"ml.p4",

src/sagemaker/djl_inference/model.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import json
17+
import logging
1718
import os.path
1819
from enum import Enum
1920
from typing import Optional, Union, Dict
@@ -29,6 +30,19 @@
2930
from sagemaker.workflow.entities import PipelineVariable
3031

3132

33+
logger = logging.getLogger("sagemaker")

# Maps Python ``logging`` levels to the level names written into the
# ``-Dai.djl.logging.level=<name>`` option by ``_get_container_env``.
LOG_LEVEL_MAP = {
    logging.INFO: "info",
    logging.DEBUG: "debug",
    logging.WARNING: "warn",
    logging.ERROR: "error",
    # ``logging.FATAL`` is an alias of ``logging.CRITICAL`` (same integer value),
    # so a single entry serves both names; listing both is a duplicate dict key.
    logging.CRITICAL: "fatal",
    logging.NOTSET: "off",
}
45+
3246
class DJLEngine(Enum):
3347
DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed")
3448
FASTER_TRANSFORMERS = ("FasterTransformers", "djl_python.faster_transformers")
@@ -220,6 +234,8 @@ def prepare_container_def(
220234
region_name = self.sagemaker_session.boto_session.region_name
221235
self.image_uri = self.serving_image_uri(region_name)
222236

237+
environment = self._get_container_env()
238+
223239
local_download_dir = (
224240
None
225241
if self.sagemaker_session.settings is None
@@ -258,7 +274,7 @@ def prepare_container_def(
258274
kms_key=self.model_kms_key,
259275
)
260276
return sagemaker.container_def(
261-
self.image_uri, model_data_url=uploaded_code.s3_prefix, env=self.env
277+
self.image_uri, model_data_url=uploaded_code.s3_prefix, env=environment
262278
)
263279

264280
def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
@@ -297,6 +313,18 @@ def serving_image_uri(self, region_name):
297313
version=self.djl_version,
298314
)
299315

316+
def _get_container_env(self):
    """Return the environment to pass to the container definition.

    When ``self.container_log_level`` is a key of ``LOG_LEVEL_MAP``, the
    ``SERVING_OPTS`` entry of ``self.env`` is set to
    ``"-Dai.djl.logging.level=<name>"`` (the surrounding double quotes are
    part of the value). ``self.env`` is updated in place and returned.

    Returns:
        ``self.env``: unchanged when no log level is configured or the
        configured level is not recognised, otherwise with ``SERVING_OPTS``
        set.
    """
    level = self.container_log_level

    # Only ``None`` means "not configured". ``logging.NOTSET`` is 0 and hence
    # falsy, so a truthiness test would make its "off" mapping unreachable.
    if level is None:
        return self.env

    djl_level = LOG_LEVEL_MAP.get(level)
    if djl_level is None:
        # Lazy %-style arguments: the message is only formatted if it is emitted.
        logger.warning("Ignoring invalid container log level: %s", level)
        return self.env

    # Item assignment on a ``None`` env would raise TypeError; start from an
    # empty mapping instead.
    if self.env is None:
        self.env = {}

    # NOTE(review): this replaces any SERVING_OPTS value already present in
    # the env (e.g. one supplied by the caller) — confirm that is intended.
    self.env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={djl_level}"'
    return self.env
326+
327+
300328

301329
def _determine_engine_for_model_type(model_type: str):
302330
if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES:

0 commit comments

Comments
 (0)