
Commit f36644e

gwang111 authored and knikure committed

fix: Address SA feedback regarding deployment straight to Endpoint Mode - Galactus (#1405)

1 parent c5c8f3f commit f36644e

8 files changed: +198 -33 lines changed

src/sagemaker/serve/builder/djl_builder.py (+36 -10)

@@ -36,8 +36,14 @@
     _set_serve_properties,
     _get_admissible_tensor_parallel_degrees,
     _get_admissible_dtypes,
+    _get_default_tensor_parallel_degree,
+)
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+    _get_gpu_info,
+    _get_gpu_info_fallback,
 )
-from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
 from sagemaker.serve.model_server.djl_serving.prepare import (
     prepare_for_djl_serving,
     _create_dir_structure,

@@ -164,13 +170,6 @@ def _create_djl_model(self) -> Type[Model]:
     @_capture_telemetry("djl.deploy")
     def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
         """Placeholder docstring"""
-        prepare_for_djl_serving(
-            model_path=self.model_path,
-            model=self.pysdk_model,
-            dependencies=self.dependencies,
-            overwrite_props_from_file=self.overwrite_props_from_file,
-        )
-
         timeout = kwargs.get("model_data_download_timeout")
         if timeout:
             self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

@@ -192,6 +191,34 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
         else:
             raise ValueError("Mode %s is not supported!" % overwrite_mode)

+        manual_set_props = None
+        if self.mode == Mode.SAGEMAKER_ENDPOINT:
+            if self.nb_instance_type and "instance_type" not in kwargs:
+                kwargs.update({"instance_type": self.nb_instance_type})
+            elif not self.nb_instance_type and "instance_type" not in kwargs:
+                raise ValueError(
+                    "Instance type must be provided when deploying to SageMaker Endpoint mode."
+                )
+            else:
+                try:
+                    tot_gpus = _get_gpu_info(kwargs.get("instance_type"), self.sagemaker_session)
+                except Exception:  # pylint: disable=W0703
+                    tot_gpus = _get_gpu_info_fallback(kwargs.get("instance_type"))
+                default_tensor_parallel_degree = _get_default_tensor_parallel_degree(
+                    self.hf_model_config, tot_gpus
+                )
+                manual_set_props = {
+                    "option.tensor_parallel_degree": str(default_tensor_parallel_degree) + "\n"
+                }
+
+        prepare_for_djl_serving(
+            model_path=self.model_path,
+            model=self.pysdk_model,
+            dependencies=self.dependencies,
+            overwrite_props_from_file=self.overwrite_props_from_file,
+            manual_set_props=manual_set_props,
+        )
+
         serializer = self.schema_builder.input_serializer
         deserializer = self.schema_builder._output_deserializer
         if self.mode == Mode.LOCAL_CONTAINER:

@@ -237,8 +264,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:

         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
-        if self.nb_instance_type and "instance_type" not in kwargs:
-            kwargs.update({"instance_type": self.nb_instance_type})

         predictor = self._original_deploy(*args, **kwargs)

@@ -252,6 +277,7 @@ def _build_for_hf_djl(self):
         """Placeholder docstring"""
         self.overwrite_props_from_file = True
         self.nb_instance_type = _get_nb_instance()
+
         _create_dir_structure(self.model_path)
         self.engine, self.hf_model_config = _auto_detect_engine(
             self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
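
Taken together, the new endpoint-mode branch resolves the instance's GPU count (EC2 lookup with a static-table fallback) and pins option.tensor_parallel_degree before the properties file is written. A rough standalone sketch of that derivation, not SDK-verbatim, using a hypothetical model config and assuming "num_attention_heads" is among the recognized attention-head keys:

    # Illustrative stand-in for the new deploy-wrapper logic.
    from sagemaker.serve.model_server.djl_serving.utils import (
        _get_default_tensor_parallel_degree,
    )
    from sagemaker.serve.utils.local_hardware import _get_gpu_info_fallback

    hf_model_config = {"num_attention_heads": 32}  # hypothetical HF config
    tot_gpus = _get_gpu_info_fallback("ml.g5.12xlarge")  # 4 GPUs per the static table
    degree = _get_default_tensor_parallel_degree(hf_model_config, tot_gpus)  # 32 % 4 == 0 -> 4
    manual_set_props = {"option.tensor_parallel_degree": str(degree) + "\n"}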

src/sagemaker/serve/builder/tgi_builder.py (+29 -2)

@@ -32,13 +32,21 @@
     _pretty_print_results_tgi,
 )
 from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
-from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
+from sagemaker.serve.model_server.djl_serving.utils import (
+    _get_admissible_tensor_parallel_degrees,
+    _get_default_tensor_parallel_degree,
+)
 from sagemaker.serve.model_server.tgi.utils import (
     _get_default_tgi_configurations,
     _get_admissible_dtypes,
 )
 from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
-from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+    _get_gpu_info,
+    _get_gpu_info_fallback,
+)
 from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
 from sagemaker.serve.utils.predictors import TgiLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer

@@ -202,8 +210,26 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:

         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
+
         if self.nb_instance_type and "instance_type" not in kwargs:
             kwargs.update({"instance_type": self.nb_instance_type})
+        elif not self.nb_instance_type and "instance_type" not in kwargs:
+            raise ValueError(
+                "Instance type must be provided when deploying to SageMaker Endpoint mode."
+            )
+        else:
+            try:
+                tot_gpus = _get_gpu_info(kwargs.get("instance_type"), self.sagemaker_session)
+            except Exception:  # pylint: disable=W0703
+                tot_gpus = _get_gpu_info_fallback(kwargs.get("instance_type"))
+            default_num_shard = _get_default_tensor_parallel_degree(self.hf_model_config, tot_gpus)
+            self.pysdk_model.env.update(
+                {
+                    "NUM_SHARD": str(default_num_shard),
+                    "SHARDED": "true" if default_num_shard > 1 else "false",
+                }
+            )
+
         if "initial_instance_count" not in kwargs:
             kwargs.update({"initial_instance_count": 1})

@@ -218,6 +244,7 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
     def _build_for_hf_tgi(self):
         """Placeholder docstring"""
         self.nb_instance_type = _get_nb_instance()
+
         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
             self.env_vars.update({"HF_MODEL_ID": self.model})

src/sagemaker/serve/mode/local_container_mode.py (+2 -2)

@@ -190,5 +190,5 @@ def _pull_image(self, image: str):
         try:
             logger.info("Pulling image %s from repository...", image)
             self.client.images.pull(image)
-        except docker.errors.NotFound:
-            logger.warning("Could not find remote image to pull")
+        except docker.errors.NotFound as e:
+            raise ValueError("Could not find remote image to pull") from e

src/sagemaker/serve/model_server/djl_serving/prepare.py (+24 -2)

@@ -1,3 +1,15 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
 """Prepare DjlModel for Deployment"""

 from __future__ import absolute_import

@@ -55,7 +67,9 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
     return (existing_properties, hf_model_config, True)


-def _generate_properties_file(model: DJLModel, code_dir: Path, overwrite_props_from_file: bool):
+def _generate_properties_file(
+    model: DJLModel, code_dir: Path, overwrite_props_from_file: bool, manual_set_props: dict
+):
     """Placeholder Docstring"""
     if _has_serving_properties_file(code_dir):
         existing_properties = _read_existing_serving_properties(code_dir)

@@ -67,6 +81,13 @@ def _generate_properties_file(model: DJLModel, code_dir: Path, overwrite_props_from_file: bool):

     with open(serving_properties_file, mode="w+") as file:
         covered_keys = set()
+
+        if manual_set_props:
+            for key, value in manual_set_props.items():
+                logger.info(_SETTING_PROPERTY_STMT, key, value.strip())
+                covered_keys.add(key)
+                file.write(f"{key}={value}")
+
         for key, value in serving_properties_dict.items():
             if not overwrite_props_from_file:
                 logger.info(_SETTING_PROPERTY_STMT, key, value)

@@ -129,6 +150,7 @@ def prepare_for_djl_serving(
     shared_libs: List[str] = None,
     dependencies: str = None,
     overwrite_props_from_file: bool = True,
+    manual_set_props: dict = None,
 ):
     """Prepare serving when a HF model id is given

@@ -149,7 +171,7 @@

     _copy_inference_script(code_dir)

-    _generate_properties_file(model, code_dir, overwrite_props_from_file)
+    _generate_properties_file(model, code_dir, overwrite_props_from_file, manual_set_props)


 def prepare_djl_js_resources(
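
manual_set_props entries are written first and their keys added to covered_keys, which the rest of the merge presumably consults to avoid duplicate properties (that part is outside this hunk). A standalone illustration of the intended precedence, under that assumption and with made-up dict contents:

    # Standalone illustration only; not SDK code.
    manual_set_props = {"option.tensor_parallel_degree": "4\n"}  # wins
    serving_properties_dict = {
        "option.tensor_parallel_degree": "1\n",
        "engine": "Python\n",
    }

    covered_keys = set()
    lines = []
    for key, value in manual_set_props.items():
        covered_keys.add(key)
        lines.append(f"{key}={value}")
    for key, value in serving_properties_dict.items():
        if key not in covered_keys:  # assumed de-duplication
            lines.append(f"{key}={value}")
    print("".join(lines), end="")
    # option.tensor_parallel_degree=4
    # engine=Python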

src/sagemaker/serve/model_server/djl_serving/utils.py (+5 -2)

@@ -60,9 +60,11 @@ def _auto_detect_engine(model_id: str, hf_hub_token: str) -> tuple:
     return (engine, hf_model_config)


-def _get_default_tensor_parallel_degree(hf_model_config: dict) -> int:
+def _get_default_tensor_parallel_degree(hf_model_config: dict, gpu_count: int = None) -> int:
     """Placeholder docstring"""
     available_gpus = _get_available_gpus()
+    if not available_gpus and not gpu_count:
+        return None

     attention_heads = None
     for variant in ATTENTION_HEAD_NAME_VARIENTS:

@@ -73,7 +75,8 @@ def _get_default_tensor_parallel_degree(hf_model_config: dict) -> int:
     if not attention_heads:
         return 1

-    for i in (n + 1 for n in reversed(range(len(available_gpus)))):
+    tot_gpus = len(available_gpus) if available_gpus else gpu_count
+    for i in (n + 1 for n in reversed(range(tot_gpus))):
         if attention_heads % i == 0:
             logger.info(
                 "Max GPU parallelism of %s is allowed. Total attention heads %s", i, attention_heads

src/sagemaker/serve/model_server/tgi/utils.py (+11 -2)

@@ -23,10 +23,19 @@ def _get_default_tgi_configurations(
         schema_builder.sample_input, schema_builder.sample_output
     )

+    if default_num_shard:
+        return (
+            {
+                "SHARDED": "true" if default_num_shard > 1 else "false",
+                "NUM_SHARD": str(default_num_shard),
+                "DTYPE": _get_default_dtype(),
+            },
+            default_max_new_tokens,
+        )
     return (
         {
-            "SHARDED": "true" if default_num_shard > 1 else "false",
-            "NUM_SHARD": str(default_num_shard),
+            "SHARDED": None,
+            "NUM_SHARD": None,
             "DTYPE": _get_default_dtype(),
         },
         default_max_new_tokens,

src/sagemaker/serve/utils/local_hardware.py (+79 -3)

@@ -1,3 +1,15 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
 """Utilites for identifying and analyzing local gpu hardware"""
 from __future__ import absolute_import

@@ -10,6 +22,8 @@
 from pathlib import Path
 import psutil

+from sagemaker import Session
+
 logger = logging.getLogger(__name__)

 # key = vCPUs

@@ -54,6 +68,34 @@
 }


+fallback_gpu_resource_mapping = {
+    "ml.p5.48xlarge": 8,
+    "ml.p4d.24xlarge": 8,
+    "ml.p4de.24xlarge": 8,
+    "ml.p3.2xlarge": 1,
+    "ml.p3.8xlarge": 4,
+    "ml.p3.16xlarge": 8,
+    "ml.p3dn.24xlarge": 8,
+    "ml.p2.xlarge": 1,
+    "ml.p2.8xlarge": 8,
+    "ml.p2.16xlarge": 16,
+    "ml.g4dn.xlarge": 1,
+    "ml.g4dn.2xlarge": 1,
+    "ml.g4dn.4xlarge": 1,
+    "ml.g4dn.8xlarge": 1,
+    "ml.g4dn.16xlarge": 1,
+    "ml.g4dn.12xlarge": 4,
+    "ml.g5n.xlarge": 1,
+    "ml.g5.2xlarge": 1,
+    "ml.g5.4xlarge": 1,
+    "ml.g5.8xlarge": 1,
+    "ml.g5.16xlarge": 1,
+    "ml.g5.12xlarge": 4,
+    "ml.g5.24xlarge": 4,
+    "ml.g5.48xlarge": 8,
+}
+
+
 def _get_available_gpus(log=True):
     """Detect the GPUs available on the device and their available resources"""
     try:

@@ -63,16 +105,24 @@ def _get_available_gpus(log=True):

         if log:
             logger.info("CUDA enabled hardware on the device: %s", gpu_info)
-
         return gpu_info
-    except Exception as e:
+    except Exception as e:  # pylint: disable=W0703
         # for nvidia-smi to run, a cuda driver must be present
-        raise ValueError("CUDA is not enabled on your device. %s" % str(e))
+        logger.warning(
+            "CUDA is not enabled on your device. %s. "
+            "Please run ModelBuilder on CUDA enabled hardware "
+            "to deploy locally.",
+            str(e),
+        )
+        return None


 def _get_nb_instance():
     """Placeholder docstring"""
     gpu_info = _get_available_gpus(False)
+    if not gpu_info:
+        return None
+
     gpu_name, gpu_mem = gpu_info[0].split(", ")
     cpu_count = multiprocessing.cpu_count()

@@ -156,3 +206,29 @@ def _check_docker_disk_usage():
             docker_path,
             str(e),
         )
+
+
+def _get_gpu_info(instance_type: str, session: Session) -> int:
+    """Get GPU info for the provided instance"""
+    ec2_client = session.boto_session.client("ec2")
+
+    split_instance = instance_type.split(".")
+    split_instance.pop(0)
+
+    ec2_instance = ".".join(split_instance)
+
+    instance_info = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance])
+
+    gpus_info = instance_info.get("InstanceTypes")[0].get("GpuInfo")
+
+    if gpus_info:
+        return gpus_info.get("Gpus")[0].get("Count")
+    raise ValueError("Provided instance_type is not GPU enabled.")
+
+
+def _get_gpu_info_fallback(instance_type: str) -> int:
+    """Get GPU info for the provided instance fallback"""
+    available_gpus = fallback_gpu_resource_mapping.get(instance_type)
+    if not available_gpus:
+        raise ValueError("Provided instance_type is not GPU enabled.")
+    return available_gpus
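
Together the two helpers form a lookup-with-fallback for an instance's GPU count. A hedged usage sketch; the instance type is illustrative and Session() assumes configured AWS credentials:

    from sagemaker import Session
    from sagemaker.serve.utils.local_hardware import _get_gpu_info, _get_gpu_info_fallback

    instance_type = "ml.g5.12xlarge"  # illustrative
    try:
        # EC2 DescribeInstanceTypes on "g5.12xlarge" (the "ml." prefix is stripped)
        tot_gpus = _get_gpu_info(instance_type, Session())
    except Exception:  # e.g. no AWS credentials or an EC2 API failure
        tot_gpus = _get_gpu_info_fallback(instance_type)  # static table lookup
    print(tot_gpus)  # 4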

src/sagemaker/serve/utils/predictors.py (+12 -10)

@@ -137,16 +137,18 @@ def __init__(

     def predict(self, data):
         """Placeholder docstring"""
-        return self.deserializer.deserialize(
-            io.BytesIO(
-                self._mode_obj._invoke_tgi_serving(
-                    self.serializer.serialize(data),
-                    self.content_type,
-                    self.deserializer.ACCEPT[0],
-                )
-            ),
-            self.content_type,
-        )
+        return [
+            self.deserializer.deserialize(
+                io.BytesIO(
+                    self._mode_obj._invoke_tgi_serving(
+                        self.serializer.serialize(data),
+                        self.content_type,
+                        self.deserializer.ACCEPT[0],
+                    )
+                ),
+                self.content_type,
+            )
+        ]

     @property
     def content_type(self):
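
Behavior note: the local-mode TGI predictor now wraps the deserialized response in a single-element list, presumably to line up with the hosted TGI response shape. A tiny before/after sketch with a hypothetical payload:

    # Illustrative payload; only the outer wrapping changed.
    payload = {"generated_text": "Hello, world"}  # hypothetical TGI output
    old_result = payload    # pre-change return value of predict()
    new_result = [payload]  # post-change: a single-element list
    assert new_result[0] == old_result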
