47
47
from sagemaker .serve .model_server .djl_serving .prepare import (
48
48
_create_dir_structure ,
49
49
)
50
- from sagemaker .serve .utils .predictors import DjlLocalModePredictor
50
+ from sagemaker .serve .utils .predictors import InProcessModePredictor , DjlLocalModePredictor
51
51
from sagemaker .serve .utils .types import ModelServer
52
52
from sagemaker .serve .mode .function_pointers import Mode
53
53
from sagemaker .serve .utils .telemetry_logger import _capture_telemetry
54
54
from sagemaker .djl_inference .model import DJLModel
55
55
from sagemaker .base_predictor import PredictorBase
56
56
57
57
logger = logging .getLogger (__name__ )
58
+ LOCAL_MODES = [Mode .LOCAL_CONTAINER , Mode .IN_PROCESS ]
58
59
59
60
# Match JumpStart DJL entrypoint format
60
61
_CODE_FOLDER = "code"
@@ -77,6 +78,7 @@ def __init__(self):
77
78
self .mode = None
78
79
self .model_server = None
79
80
self .image_uri = None
81
+ self .inference_spec = None
80
82
self ._is_custom_image_uri = False
81
83
self .image_config = None
82
84
self .vpc_config = None
@@ -96,11 +98,11 @@ def __init__(self):
96
98
97
99
@abstractmethod
98
100
def _prepare_for_mode (self ):
99
- """Placeholder docstring """
101
+ """Abstract method """
100
102
101
103
@abstractmethod
102
104
def _get_client_translators (self ):
103
- """Placeholder docstring """
105
+ """Abstract method """
104
106
105
107
def _is_djl (self ):
106
108
"""Placeholder docstring"""
@@ -146,7 +148,7 @@ def _create_djl_model(self) -> Type[Model]:
146
148
147
149
@_capture_telemetry ("djl.deploy" )
148
150
def _djl_model_builder_deploy_wrapper (self , * args , ** kwargs ) -> Type [PredictorBase ]:
149
- """Placeholder docstring """
151
+ """Returns predictor depending on local mode or endpoint mode """
150
152
timeout = kwargs .get ("model_data_download_timeout" )
151
153
if timeout :
152
154
self .env_vars .update ({"MODEL_LOADING_TIMEOUT" : str (timeout )})
@@ -189,6 +191,18 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
189
191
190
192
serializer = self .schema_builder .input_serializer
191
193
deserializer = self .schema_builder ._output_deserializer
194
+
195
+ if self .mode == Mode .IN_PROCESS :
196
+
197
+ predictor = InProcessModePredictor (
198
+ self .modes [str (Mode .IN_PROCESS )], serializer , deserializer
199
+ )
200
+
201
+ self .modes [str (Mode .IN_PROCESS )].create_server (
202
+ predictor ,
203
+ )
204
+ return predictor
205
+
192
206
if self .mode == Mode .LOCAL_CONTAINER :
193
207
timeout = kwargs .get ("model_data_download_timeout" )
194
208
@@ -249,9 +263,15 @@ def _build_for_hf_djl(self):
249
263
250
264
_create_dir_structure (self .model_path )
251
265
if not hasattr (self , "pysdk_model" ):
252
- self .env_vars .update ({"HF_MODEL_ID" : self .model })
266
+ if self .inference_spec is not None :
267
+ self .env_vars .update ({"HF_MODEL_ID" : self .inference_spec .get_model ()})
268
+ else :
269
+ self .env_vars .update ({"HF_MODEL_ID" : self .model })
270
+
271
+ logger .info (self .env_vars )
272
+
253
273
self .hf_model_config = _get_model_config_properties_from_hf (
254
- self .model , self .env_vars .get ("HF_TOKEN" )
274
+ self .env_vars . get ( "HF_MODEL_ID" ) , self .env_vars .get ("HF_TOKEN" )
255
275
)
256
276
default_djl_configurations , _default_max_new_tokens = _get_default_djl_configurations (
257
277
self .model , self .hf_model_config , self .schema_builder
@@ -260,9 +280,10 @@ def _build_for_hf_djl(self):
260
280
self .schema_builder .sample_input ["parameters" ][
261
281
"max_new_tokens"
262
282
] = _default_max_new_tokens
283
+
263
284
self .pysdk_model = self ._create_djl_model ()
264
285
265
- if self .mode == Mode . LOCAL_CONTAINER :
286
+ if self .mode in LOCAL_MODES :
266
287
self ._prepare_for_mode ()
267
288
268
289
return self .pysdk_model
@@ -451,7 +472,6 @@ def _build_for_djl(self):
451
472
"""Placeholder docstring"""
452
473
self ._validate_djl_serving_sample_data ()
453
474
self .secret_key = None
454
-
455
475
self .pysdk_model = self ._build_for_hf_djl ()
456
476
self .pysdk_model .tune = self ._tune_for_hf_djl
457
477
if self .role_arn :
0 commit comments