Skip to content

Commit 8f2c921

Browse files
committed
update model builder for new DJLModel implementation
1 parent e816567 commit 8f2c921

File tree

17 files changed

+183
-649
lines changed

17 files changed

+183
-649
lines changed

src/sagemaker/djl_inference/model.py

-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ def _infer_image_uri(self):
161161
version=self.djl_version,
162162
)
163163

164-
165164
def _configure_environment_variables(self) -> Dict[str, str]:
166165
env = self.env.copy() if self.env else {}
167166
env = _set_env_var_from_property(self.model_id, "HF_MODEL_ID", env)

src/sagemaker/serve/builder/djl_builder.py

+47-96
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import logging
1616
from typing import Type
1717
from abc import ABC, abstractmethod
18-
from pathlib import Path
1918
from datetime import datetime, timedelta
2019

2120
from sagemaker.model import Model
@@ -31,12 +30,12 @@
3130
_more_performant,
3231
_pretty_print_results,
3332
)
33+
from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf
3434
from sagemaker.serve.model_server.djl_serving.utils import (
35-
_auto_detect_engine,
36-
_set_serve_properties,
3735
_get_admissible_tensor_parallel_degrees,
3836
_get_admissible_dtypes,
3937
_get_default_tensor_parallel_degree,
38+
_get_default_djl_configurations,
4039
)
4140
from sagemaker.serve.utils.local_hardware import (
4241
_get_nb_instance,
@@ -45,24 +44,18 @@
4544
_get_gpu_info_fallback,
4645
)
4746
from sagemaker.serve.model_server.djl_serving.prepare import (
48-
prepare_for_djl_serving,
4947
_create_dir_structure,
5048
)
5149
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
52-
from sagemaker.serve.utils.types import ModelServer, _DjlEngine
50+
from sagemaker.serve.utils.types import ModelServer
5351
from sagemaker.serve.mode.function_pointers import Mode
5452
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
55-
from sagemaker.djl_inference.model import (
56-
DeepSpeedModel,
57-
FasterTransformerModel,
58-
HuggingFaceAccelerateModel,
59-
)
53+
from sagemaker.djl_inference.model import DJLModel
6054
from sagemaker.base_predictor import PredictorBase
6155

6256
logger = logging.getLogger(__name__)
6357

6458
# Match JumpStart DJL entrypoint format
65-
_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
6659
_CODE_FOLDER = "code"
6760
_INVALID_SAMPLE_DATA_EX = (
6861
'For djl-serving, sample input must be of {"inputs": str, "parameters": dict}, '
@@ -88,14 +81,11 @@ def __init__(self):
8881
self.vpc_config = None
8982
self._original_deploy = None
9083
self.secret_key = None
91-
self.engine = None
9284
self.hf_model_config = None
9385
self._default_tensor_parallel_degree = None
9486
self._default_data_type = None
9587
self._default_max_tokens = None
96-
self._default_max_new_tokens = None
9788
self.pysdk_model = None
98-
self.overwrite_props_from_file = None
9989
self.schema_builder = None
10090
self.env_vars = None
10191
self.nb_instance_type = None
@@ -117,6 +107,7 @@ def _validate_djl_serving_sample_data(self):
117107
"""Placeholder docstring"""
118108
sample_input = self.schema_builder.sample_input
119109
sample_output = self.schema_builder.sample_output
110+
logger.info(f"sample input is {sample_input}, sample output is {sample_output}")
120111

121112
if ( # pylint: disable=R0916
122113
not isinstance(sample_input, dict)
@@ -130,37 +121,15 @@ def _validate_djl_serving_sample_data(self):
130121

131122
def _create_djl_model(self) -> Type[Model]:
132123
"""Placeholder docstring"""
133-
code_dir = str(Path(self.model_path).joinpath(_CODE_FOLDER))
134-
135-
kwargs = {
136-
"model_id": self.model,
137-
"role": self.serve_settings.role_arn,
138-
"entry_point": _DJL_MODEL_BUILDER_ENTRY_POINT,
139-
"dtype": self._default_data_type,
140-
"sagemaker_session": self.sagemaker_session,
141-
"source_dir": code_dir,
142-
"env": self.env_vars,
143-
"hf_hub_token": self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
144-
"image_config": self.image_config,
145-
"vpc_config": self.vpc_config,
146-
}
147-
148-
if self.engine == _DjlEngine.DEEPSPEED:
149-
pysdk_model = DeepSpeedModel(
150-
tensor_parallel_degree=self._default_tensor_parallel_degree,
151-
max_tokens=self._default_max_tokens,
152-
**kwargs,
153-
)
154-
elif self.engine == _DjlEngine.FASTER_TRANSFORMER:
155-
pysdk_model = FasterTransformerModel(
156-
tensor_parallel_degree=self._default_tensor_parallel_degree,
157-
**kwargs,
158-
)
159-
else:
160-
pysdk_model = HuggingFaceAccelerateModel(
161-
number_of_partitions=self._default_tensor_parallel_degree,
162-
**kwargs,
163-
)
124+
pysdk_model = DJLModel(
125+
model_id=self.model,
126+
role=self.serve_settings.role_arn,
127+
sagemaker_session=self.sagemaker_session,
128+
env=self.env_vars,
129+
huggingface_hub_token=self.env_vars.get("HF_TOKEN"),
130+
image_config=self.image_config,
131+
vpc_config=self.vpc_config,
132+
)
164133

165134
if not self.image_uri:
166135
self.image_uri = pysdk_model.serving_image_uri(self.sagemaker_session.boto_region_name)
@@ -196,7 +165,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
196165
else:
197166
raise ValueError("Mode %s is not supported!" % overwrite_mode)
198167

199-
manual_set_props = None
200168
if self.mode == Mode.SAGEMAKER_ENDPOINT:
201169
if self.nb_instance_type and "instance_type" not in kwargs:
202170
kwargs.update({"instance_type": self.nb_instance_type})
@@ -212,17 +180,9 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
212180
default_tensor_parallel_degree = _get_default_tensor_parallel_degree(
213181
self.hf_model_config, tot_gpus
214182
)
215-
manual_set_props = {
216-
"option.tensor_parallel_degree": str(default_tensor_parallel_degree) + "\n"
217-
}
218-
219-
prepare_for_djl_serving(
220-
model_path=self.model_path,
221-
model=self.pysdk_model,
222-
dependencies=self.dependencies,
223-
overwrite_props_from_file=self.overwrite_props_from_file,
224-
manual_set_props=manual_set_props,
225-
)
183+
self.pysdk_model.env.update(
184+
{"TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)}
185+
)
226186

227187
serializer = self.schema_builder.input_serializer
228188
deserializer = self.schema_builder._output_deserializer
@@ -239,7 +199,7 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
239199
timeout if timeout else 1800,
240200
self.secret_key,
241201
predictor,
242-
self.env_vars,
202+
self.pysdk_model.env,
243203
)
244204
ram_usage_after = _get_ram_usage_mb()
245205

@@ -281,25 +241,22 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
281241

282242
def _build_for_hf_djl(self):
283243
"""Placeholder docstring"""
284-
self.overwrite_props_from_file = True
285244
self.nb_instance_type = _get_nb_instance()
286245

287246
_create_dir_structure(self.model_path)
288-
self.engine, self.hf_model_config = _auto_detect_engine(
289-
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
290-
)
291-
292247
if not hasattr(self, "pysdk_model"):
293-
(
294-
self._default_tensor_parallel_degree,
295-
self._default_data_type,
296-
_,
297-
self._default_max_tokens,
298-
self._default_max_new_tokens,
299-
) = _set_serve_properties(self.hf_model_config, self.schema_builder)
248+
self.env_vars.update({"HF_MODEL_ID": self.model})
249+
self.hf_model_config = _get_model_config_properties_from_hf(
250+
self.model, self.env_vars.get("HF_TOKEN")
251+
)
252+
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
253+
self.model, self.hf_model_config, self.schema_builder
254+
)
255+
self.env_vars.update(default_djl_configurations)
300256
self.schema_builder.sample_input["parameters"][
301257
"max_new_tokens"
302-
] = self._default_max_new_tokens
258+
] = _default_max_new_tokens
259+
logger.info(f"env vars are {self.env_vars}")
303260
self.pysdk_model = self._create_djl_model()
304261

305262
if self.mode == Mode.LOCAL_CONTAINER:
@@ -316,8 +273,6 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
316273
)
317274
return self.pysdk_model
318275

319-
self.overwrite_props_from_file = False
320-
321276
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
322277
self.hf_model_config
323278
)
@@ -337,8 +292,9 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
337292
"Trying tensor parallel degree: %s, dtype: %s...", tensor_parallel_degree, dtype
338293
)
339294

340-
self._default_tensor_parallel_degree = tensor_parallel_degree
341-
self._default_data_type = dtype
295+
self.env_vars.update(
296+
{"TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree), "OPTION_DTYPE": dtype}
297+
)
342298
self.pysdk_model = self._create_djl_model()
343299

344300
try:
@@ -353,15 +309,15 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
353309
predictor, self.schema_builder.sample_input
354310
)
355311

356-
serving_properties = self.pysdk_model.generate_serving_properties()
312+
tested_env = self.pysdk_model.env.copy()
357313
logger.info(
358314
"Average latency: %s, throughput/s: %s for configuration: %s",
359315
avg_latency,
360316
throughput_per_second,
361-
serving_properties,
317+
tested_env,
362318
)
363319
benchmark_results[avg_latency] = [
364-
serving_properties,
320+
tested_env,
365321
p90,
366322
avg_tokens_per_second,
367323
throughput_per_second,
@@ -449,48 +405,43 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
449405
if best_tuned_combination:
450406
self._default_tensor_parallel_degree = best_tuned_combination[1]
451407
self._default_data_type = best_tuned_combination[2]
408+
self.env_vars.update(
409+
{
410+
"TENSOR_PARALLEL_DEGREE": str(self._default_tensor_parallel_degree),
411+
"OPTION_DTYPE": self._default_data_type,
412+
}
413+
)
452414
self.pysdk_model = self._create_djl_model()
453415

454416
_pretty_print_results(benchmark_results)
455417
logger.info(
456418
"Model Configuration: %s was most performant with avg latency: %s, "
457419
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
458420
"standard deviation of request %s",
459-
self.pysdk_model.generate_serving_properties(),
421+
self.pysdk_model.env,
460422
best_tuned_combination[0],
461423
best_tuned_combination[3],
462424
best_tuned_combination[4],
463425
best_tuned_combination[5],
464426
best_tuned_combination[6],
465427
)
466428
else:
467-
(
468-
self._default_tensor_parallel_degree,
469-
self._default_data_type,
470-
_,
471-
self._default_max_tokens,
472-
self._default_max_new_tokens,
473-
) = _set_serve_properties(self.hf_model_config, self.schema_builder)
429+
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
430+
self.model, self.hf_model_config, self.schema_builder
431+
)
432+
self.env_vars.update(default_djl_configurations)
474433
self.schema_builder.sample_input["parameters"][
475434
"max_new_tokens"
476-
] = self._default_max_new_tokens
435+
] = _default_max_new_tokens
477436
self.pysdk_model = self._create_djl_model()
478437

479438
logger.debug(
480439
"Failed to gather any tuning results. "
481440
"Please inspect the stack trace emitted from live logging for more details. "
482441
"Falling back to default serving.properties: %s",
483-
self.pysdk_model.generate_serving_properties(),
442+
self.pysdk_model.env,
484443
)
485444

486-
prepare_for_djl_serving(
487-
model_path=self.model_path,
488-
model=self.pysdk_model,
489-
dependencies=self.dependencies,
490-
overwrite_props_from_file=self.overwrite_props_from_file,
491-
)
492-
self.overwrite_props_from_file = True
493-
494445
return self.pysdk_model
495446

496447
def _build_for_djl(self):

src/sagemaker/serve/builder/model_builder.py

-6
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from sagemaker import Session
2929
from sagemaker.model import Model
3030
from sagemaker.base_predictor import PredictorBase
31-
from sagemaker.djl_inference import defaults
3231
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
3332
from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
3433
from sagemaker.serve.builder.schema_builder import SchemaBuilder
@@ -846,11 +845,6 @@ def build( # pylint: disable=R0911
846845
return self._build_for_tei()
847846
elif self._can_fit_on_single_gpu():
848847
return self._build_for_transformers()
849-
elif (
850-
self.model in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES
851-
or self.model in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES
852-
):
853-
return self._build_for_djl()
854848
else:
855849
return self._build_for_transformers()
856850

src/sagemaker/serve/builder/tei_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker import image_uris
2020
from sagemaker.model import Model
21-
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
21+
from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf
2222

2323
from sagemaker.huggingface import HuggingFaceModel
2424
from sagemaker.serve.utils.local_hardware import (

src/sagemaker/serve/builder/tgi_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
_more_performant,
3232
_pretty_print_results_tgi,
3333
)
34-
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
34+
from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf
3535
from sagemaker.serve.model_server.djl_serving.utils import (
3636
_get_admissible_tensor_parallel_degrees,
3737
_get_default_tensor_parallel_degree,

src/sagemaker/serve/builder/transformers_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sagemaker.serve.utils.local_hardware import (
2323
_get_nb_instance,
2424
)
25-
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
25+
from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf
2626
from sagemaker.huggingface import HuggingFaceModel
2727
from sagemaker.serve.model_server.multi_model_server.prepare import (
2828
_create_dir_structure,

0 commit comments

Comments (0)