10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
-
13
+ """Model and predictor classes for deploying large models on SageMaker with DJL Serving."""
14
14
from __future__ import absolute_import
15
15
16
16
import json
19
19
from enum import Enum
20
20
from typing import Optional , Union , Dict
21
21
22
- from sagemaker .djl_inference import defaults
23
22
import sagemaker
23
+ from sagemaker .djl_inference import defaults
24
24
from sagemaker import s3 , Predictor , image_uris , fw_utils
25
25
from sagemaker .deserializers import JSONDeserializer
26
26
from sagemaker .model import FrameworkModel
32
32
33
33
logger = logging .getLogger ("sagemaker" )
34
34
35
- # DJL Serving uses log4j, so we convert python logging name to log4j equivalent
36
- LOG_LEVEL_MAP = {
35
+ # DJL Serving uses log4j, so we convert the Python logging level to its log4j equivalent
36
+ _LOG_LEVEL_MAP = {
37
37
logging .INFO : "info" ,
38
38
logging .DEBUG : "debug" ,
39
39
logging .WARNING : "warn" ,
@@ -107,12 +107,9 @@ def __init__(
107
107
predictor_cls : callable = DJLLargeModelPredictor ,
108
108
** kwargs ,
109
109
):
110
- if kwargs .get ("model_data" ) is not None :
111
- raise ValueError (
112
- "DJLLargeModels do not support the model_data parameter. Please use"
113
- "uncompressed_model_data and ensure the s3 uri points to a folder containing"
114
- "all model artifacts, not a tar.gz file"
115
- )
110
+ if kwargs .get ("model_data" ):
111
+ logger .warning ("DJLLargeModels do not use model_data parameter. You only need to specify the"
112
+ "uncompressed_model_data parameter." )
116
113
super (DJLLargeModel , self ).__init__ (
117
114
None , image_uri , role , entry_point , predictor_cls = predictor_cls , ** kwargs
118
115
)
@@ -189,7 +186,7 @@ def deploy(
189
186
if instance_type is None and not self .inference_recommender_job_results :
190
187
raise ValueError (
191
188
f"instance_type must be specified, or inference recommendation from right_size()"
192
- "must be run to deploy the model. Supported instance type families are :"
189
+ "must be run to deploy the model. Supported instance type families are:"
193
190
f"{ defaults .ALLOWED_INSTANCE_FAMILIES } "
194
191
)
195
192
if instance_type :
@@ -225,11 +222,6 @@ def prepare_container_def(
225
222
accelerator_type = None ,
226
223
serverless_inference_config = None ,
227
224
):
228
- if serverless_inference_config is not None :
229
- raise ValueError ("DJLLargeModel does not support serverless deployment" )
230
- if accelerator_type is not None :
231
- raise ValueError ("DJLLargeModel does not support Elastic Inference accelerator" )
232
-
233
225
if not self .image_uri :
234
226
region_name = self .sagemaker_session .boto_session .region_name
235
227
self .image_uri = self .serving_image_uri (region_name )
@@ -243,8 +235,8 @@ def prepare_container_def(
243
235
else self .sagemaker_session .settings .local_download_dir
244
236
)
245
237
with _tmpdir (directory = local_download_dir ) as tmp :
246
- logger .info (f"Using tmp dir { tmp } " )
247
238
if self .source_dir or self .entry_point :
239
+ # The call below either downloads the code from S3 or moves the local files into tmp/code
248
240
_create_or_update_code_dir (
249
241
tmp ,
250
242
self .entry_point ,
@@ -253,11 +245,14 @@ def prepare_container_def(
253
245
self .sagemaker_session ,
254
246
tmp ,
255
247
)
256
- existing_serving_properties = _read_existing_serving_properties (tmp )
248
+ tmp_code_dir = os .path .join (tmp , "code" )
249
+ existing_serving_properties = _read_existing_serving_properties (tmp_code_dir )
257
250
kwargs_serving_properties = self .generate_serving_properties ()
258
251
existing_serving_properties .update (kwargs_serving_properties )
259
252
260
- with open (os .path .join (tmp , "serving.properties" ), "w+" ) as f :
253
+ if not os .path .exists (tmp_code_dir ):
254
+ os .mkdir (tmp_code_dir )
255
+ with open (os .path .join (tmp_code_dir , "serving.properties" ), "w+" ) as f :
261
256
for key , val in existing_serving_properties .items ():
262
257
f .write (f"{ key } ={ val } \n " )
263
258
@@ -270,7 +265,7 @@ def prepare_container_def(
270
265
bucket ,
271
266
deploy_key_prefix ,
272
267
self .entry_point ,
273
- directory = tmp ,
268
+ directory = tmp_code_dir ,
274
269
dependencies = self .dependencies ,
275
270
kms_key = self .model_kms_key ,
276
271
)
@@ -318,11 +313,11 @@ def _get_container_env(self):
318
313
if not self .container_log_level :
319
314
return self .env
320
315
321
- if self .container_log_level not in LOG_LEVEL_MAP :
316
+ if self .container_log_level not in _LOG_LEVEL_MAP :
322
317
logger .warning (f"Ignoring invalid container log level: { self .container_log_level } " )
323
318
return self .env
324
319
325
- self .env ["SERVING_OPTS" ] = f'"-Dai.djl.logging.level={ LOG_LEVEL_MAP [self .container_log_level ]} "'
320
+ self .env ["SERVING_OPTS" ] = f'"-Dai.djl.logging.level={ _LOG_LEVEL_MAP [self .container_log_level ]} "'
326
321
return self .env
327
322
328
323
0 commit comments