Commit 886bb82

Fix issue with bad path mount in container when user supplies code
1 parent c4f8090 commit 886bb82

1 file changed: +17 -22 lines

src/sagemaker/djl_inference/model.py (+17 -22)
@@ -10,7 +10,7 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""Placeholder docstring"""
+"""Placeholder docstring"""
 from __future__ import absolute_import

 import json
@@ -19,8 +19,8 @@
 from enum import Enum
 from typing import Optional, Union, Dict

-from sagemaker.djl_inference import defaults
 import sagemaker
+from sagemaker.djl_inference import defaults
 from sagemaker import s3, Predictor, image_uris, fw_utils
 from sagemaker.deserializers import JSONDeserializer
 from sagemaker.model import FrameworkModel
@@ -32,8 +32,8 @@

 logger = logging.getLogger("sagemaker")

-# DJL Serving uses log4j, so we convert python logging name to log4j equivalent
-LOG_LEVEL_MAP = {
+# DJL Serving uses log4j, so we convert python logging level to log4j equivalent
+_LOG_LEVEL_MAP = {
     logging.INFO: "info",
     logging.DEBUG: "debug",
     logging.WARNING: "warn",
@@ -107,12 +107,9 @@ def __init__(
         predictor_cls: callable = DJLLargeModelPredictor,
         **kwargs,
     ):
-        if kwargs.get("model_data") is not None:
-            raise ValueError(
-                "DJLLargeModels do not support the model_data parameter. Please use"
-                "uncompressed_model_data and ensure the s3 uri points to a folder containing"
-                "all model artifacts, not a tar.gz file"
-            )
+        if kwargs.get("model_data"):
+            logger.warning("DJLLargeModels do not use the model_data parameter. You only need to specify the "
+                           "uncompressed_model_data parameter.")
         super(DJLLargeModel, self).__init__(
             None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
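With this change, supplying model_data only triggers the warning instead of raising; uncompressed_model_data, pointing at an S3 prefix that holds the unpacked artifacts rather than a tar.gz, remains the intended input. A minimal usage sketch, assuming DJLLargeModel is constructed directly and that uncompressed_model_data and role are accepted as keyword arguments; the S3 URIs and role ARN are placeholders:

from sagemaker.djl_inference.model import DJLLargeModel

# Placeholder URIs/ARN for illustration only.
model = DJLLargeModel(
    uncompressed_model_data="s3://my-bucket/my-model/",  # prefix with unpacked artifacts, not a tar.gz
    role="arn:aws:iam::111122223333:role/SageMakerRole",
)

# Passing model_data now only logs the warning above instead of raising ValueError.
model_ignoring_model_data = DJLLargeModel(
    uncompressed_model_data="s3://my-bucket/my-model/",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    model_data="s3://my-bucket/model.tar.gz",  # unused; emits the warning
)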
@@ -189,7 +186,7 @@ def deploy(
         if instance_type is None and not self.inference_recommender_job_results:
             raise ValueError(
                 f"instance_type must be specified, or inference recommendation from right_size()"
-                "must be run to deploy the model. Supported instance type families are :"
+                "must be run to deploy the model. Supported instance type families are:"
                 f"{defaults.ALLOWED_INSTANCE_FAMILIES}"
             )
         if instance_type:
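Given the check above, deploy() needs either an explicit instance_type from one of the allowed families or a prior right_size() recommendation. A hypothetical call for illustration; the instance type is a guess, so consult defaults.ALLOWED_INSTANCE_FAMILIES for what is actually permitted:

# `model` is the DJLLargeModel from the earlier sketch; values are illustrative.
predictor = model.deploy(
    instance_type="ml.g5.12xlarge",   # must fall in defaults.ALLOWED_INSTANCE_FAMILIES
    initial_instance_count=1,
)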
@@ -225,11 +222,6 @@ def prepare_container_def(
         accelerator_type=None,
         serverless_inference_config=None,
     ):
-        if serverless_inference_config is not None:
-            raise ValueError("DJLLargeModel does not support serverless deployment")
-        if accelerator_type is not None:
-            raise ValueError("DJLLargeModel does not support Elastic Inference accelerator")
-
         if not self.image_uri:
             region_name = self.sagemaker_session.boto_session.region_name
             self.image_uri = self.serving_image_uri(region_name)
@@ -243,8 +235,8 @@ def prepare_container_def(
                 else self.sagemaker_session.settings.local_download_dir
             )
         with _tmpdir(directory=local_download_dir) as tmp:
-            logger.info(f"Using tmp dir {tmp}")
             if self.source_dir or self.entry_point:
+                # Below method downloads from s3, or moves local files to tmp/code
                 _create_or_update_code_dir(
                     tmp,
                     self.entry_point,
@@ -253,11 +245,14 @@ def prepare_container_def(
                     self.sagemaker_session,
                     tmp,
                 )
-            existing_serving_properties = _read_existing_serving_properties(tmp)
+            tmp_code_dir = os.path.join(tmp, "code")
+            existing_serving_properties = _read_existing_serving_properties(tmp_code_dir)
             kwargs_serving_properties = self.generate_serving_properties()
             existing_serving_properties.update(kwargs_serving_properties)

-            with open(os.path.join(tmp, "serving.properties"), "w+") as f:
+            if not os.path.exists(tmp_code_dir):
+                os.mkdir(tmp_code_dir)
+            with open(os.path.join(tmp_code_dir, "serving.properties"), "w+") as f:
                 for key, val in existing_serving_properties.items():
                     f.write(f"{key}={val}\n")

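The key part of the fix: properties already present in the user's code directory are read first, then overridden by whatever generate_serving_properties() derives from the constructor kwargs, and the merged result is written to tmp/code/serving.properties next to the user's code. A standalone sketch of that merge-and-write; the property names and values are hypothetical examples:

import os
import tempfile

# Hypothetical values; the real ones come from the user's serving.properties
# and from generate_serving_properties().
existing_serving_properties = {"engine": "DeepSpeed", "option.tensor_parallel_degree": "1"}
kwargs_serving_properties = {"option.tensor_parallel_degree": "2"}

# Properties derived from constructor kwargs win over what was already in the code dir.
existing_serving_properties.update(kwargs_serving_properties)

with tempfile.TemporaryDirectory() as tmp:
    tmp_code_dir = os.path.join(tmp, "code")
    if not os.path.exists(tmp_code_dir):
        os.mkdir(tmp_code_dir)
    # serving.properties now lands next to the user's code instead of one level up.
    with open(os.path.join(tmp_code_dir, "serving.properties"), "w+") as f:
        for key, val in existing_serving_properties.items():
            f.write(f"{key}={val}\n")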
@@ -270,7 +265,7 @@ def prepare_container_def(
                 bucket,
                 deploy_key_prefix,
                 self.entry_point,
-                directory=tmp,
+                directory=tmp_code_dir,
                 dependencies=self.dependencies,
                 kms_key=self.model_kms_key,
             )
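Pointing the upload step at tmp_code_dir rather than tmp keeps serving.properties and the user's entry point at the root of what gets packaged, instead of nested under an extra code/ level, which appears to be the bad mount path this commit fixes. A small, self-contained illustration with hypothetical file names:

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    tmp_code_dir = os.path.join(tmp, "code")
    os.makedirs(tmp_code_dir)
    open(os.path.join(tmp_code_dir, "serving.properties"), "w").close()
    open(os.path.join(tmp_code_dir, "inference.py"), "w").close()  # stand-in for user code

    # Old behavior: directory=tmp, so the packaged root held only a nested "code/" folder.
    print(sorted(os.listdir(tmp)))           # ['code']
    # New behavior: directory=tmp_code_dir, so the artifacts sit at the packaged root.
    print(sorted(os.listdir(tmp_code_dir)))  # ['inference.py', 'serving.properties']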
@@ -318,11 +313,11 @@ def _get_container_env(self):
         if not self.container_log_level:
             return self.env

-        if self.container_log_level not in LOG_LEVEL_MAP:
+        if self.container_log_level not in _LOG_LEVEL_MAP:
             logger.warning(f"Ignoring invalid container log level: {self.container_log_level}")
             return self.env

-        self.env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={LOG_LEVEL_MAP[self.container_log_level]}"'
+        self.env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"'
         return self.env

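The rename to _LOG_LEVEL_MAP only marks the dict as module-private; the mapping itself is unchanged. A quick sketch of the resulting environment variable, recreating the mapping from the diff above:

import logging

# Recreated here for illustration; mirrors the mapping defined earlier in the diff.
_LOG_LEVEL_MAP = {
    logging.INFO: "info",
    logging.DEBUG: "debug",
    logging.WARNING: "warn",
}

container_log_level = logging.DEBUG
env = {}
if container_log_level in _LOG_LEVEL_MAP:
    env["SERVING_OPTS"] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[container_log_level]}"'

print(env)  # {'SERVING_OPTS': '"-Dai.djl.logging.level=debug"'}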