|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 | 16 | import json
|
| 17 | +import logging |
17 | 18 | import os.path
|
18 | 19 | from enum import Enum
|
19 | 20 | from typing import Optional, Union, Dict
|
|
29 | 30 | from sagemaker.workflow.entities import PipelineVariable
|
30 | 31 |
|
31 | 32 |
|
| 33 | +logger = logging.getLogger("sagemaker") |
| 34 | + |
| 35 | + |
| 36 | +LOG_LEVEL_MAP = { |
| 37 | + logging.INFO: "info", |
| 38 | + logging.DEBUG: "debug", |
| 39 | + logging.WARNING: "warn", |
| 40 | + logging.ERROR: "error", |
| 41 | + logging.FATAL: "fatal", |
| 42 | + logging.CRITICAL: "fatal", |
| 43 | + logging.NOTSET: "off", |
| 44 | +} |
| 45 | + |
32 | 46 | class DJLEngine(Enum):
|
33 | 47 | DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed")
|
34 | 48 | FASTER_TRANSFORMERS = ("FasterTransformers", "djl_python.faster_transformers")
|
@@ -220,6 +234,8 @@ def prepare_container_def(
|
220 | 234 | region_name = self.sagemaker_session.boto_session.region_name
|
221 | 235 | self.image_uri = self.serving_image_uri(region_name)
|
222 | 236 |
|
| 237 | + environment = self._get_container_env() |
| 238 | + |
223 | 239 | local_download_dir = (
|
224 | 240 | None
|
225 | 241 | if self.sagemaker_session.settings is None
|
@@ -258,7 +274,7 @@ def prepare_container_def(
|
258 | 274 | kms_key=self.model_kms_key,
|
259 | 275 | )
|
260 | 276 | return sagemaker.container_def(
|
261 |
| - self.image_uri, model_data_url=uploaded_code.s3_prefix, env=self.env |
| 277 | + self.image_uri, model_data_url=uploaded_code.s3_prefix, env=environment |
262 | 278 | )
|
263 | 279 |
|
264 | 280 | def generate_serving_properties(self, serving_properties={}) -> Dict[str, str]:
|
@@ -297,6 +313,18 @@ def serving_image_uri(self, region_name):
|
297 | 313 | version=self.djl_version,
|
298 | 314 | )
|
299 | 315 |
|
| 316 | + def _get_container_env(self): |
| 317 | + if not self.container_log_level: |
| 318 | + return self.env |
| 319 | + |
| 320 | + if self.container_log_level not in LOG_LEVEL_MAP: |
| 321 | + logger.warning(f"Ignoring invalid container log level: {self.container_log_level}") |
| 322 | + return self.env |
| 323 | + |
| 324 | + self.env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={LOG_LEVEL_MAP[self.container_log_level]}"' |
| 325 | + return self.env |
| 326 | + |
| 327 | + |
300 | 328 |
|
301 | 329 | def _determine_engine_for_model_type(model_type: str):
|
302 | 330 | if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES:
|
|
0 commit comments