Skip to content

Commit 26c8696

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
Optimize support for hf models (aws#1499)
* HF support * refactoring * Refactoring * Refactoring * HF Refactoring * Refactoring * UT * Fix UT * Resolving PR comments * HF Token * Resolving PR comments * Fix UT * Fix JS ModelServer deploy wrapper override * Fix tests * fix UT * Resolve PR comments * fix doc --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 15e26c4 commit 26c8696

File tree

23 files changed

+804
-225
lines changed

23 files changed

+804
-225
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ accelerate>=0.24.1,<=0.27.0
3838
schema==0.7.5
3939
tensorflow>=2.1,<=2.16
4040
mlflow>=2.12.2,<2.13
41+
huggingface_hub>=0.23.4

src/sagemaker/huggingface/llm_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
1414
from __future__ import absolute_import
1515

16+
import os
1617
from typing import Optional
18+
import importlib.util
1719

1820
import urllib.request
1921
from urllib.error import HTTPError, URLError
@@ -123,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
123125
"Did not find model metadata for the following HuggingFace Model ID %s" % model_id
124126
)
125127
return hf_model_metadata_json
128+
129+
130+
def download_huggingface_model_metadata(
    model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
    """Fetch a HuggingFace Hub model snapshot into a local directory.

    The target directory is created when it does not exist yet, and the
    repository identified by ``model_id`` is then downloaded into it through
    ``huggingface_hub.snapshot_download``.

    Args:
        model_id (str): The HuggingFace Model ID.
        model_local_path (str): Local directory that receives the snapshot.
        hf_hub_token (Optional[str]): HuggingFace Hub token forwarded to the
            download call. Defaults to ``None``.

    Raises:
        ImportError: If huggingface_hub is not installed.
    """
    # Probe for the package before touching the filesystem, so a missing
    # dependency is reported without leaving an empty directory behind.
    hub_spec = importlib.util.find_spec("huggingface_hub")
    if hub_spec is None:
        raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")

    # Imported here rather than at module level because the package may be absent.
    from huggingface_hub import snapshot_download

    os.makedirs(model_local_path, exist_ok=True)
    logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
    snapshot_download(
        repo_id=model_id,
        local_dir=model_local_path,
        token=hf_hub_token,
    )

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
LocalModelOutOfMemoryException,
2525
LocalModelInvocationException,
2626
)
27+
from sagemaker.serve.utils.optimize_utils import _is_optimized
2728
from sagemaker.serve.utils.tuning import (
2829
_serial_benchmark,
2930
_concurrent_benchmark,
@@ -214,9 +215,10 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
214215
del kwargs["role"]
215216

216217
# set model_data to uncompressed s3 dict
217-
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
218-
self.env_vars.update(env_vars)
219-
self.pysdk_model.env.update(self.env_vars)
218+
if not _is_optimized(self.pysdk_model):
219+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
220+
self.env_vars.update(env_vars)
221+
self.pysdk_model.env.update(self.env_vars)
220222

221223
# if the weights have been cached via local container mode -> set to offline
222224
if str(Mode.LOCAL_CONTAINER) in self.modes:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@
4242
_update_environment_variables,
4343
_extract_speculative_draft_model_provider,
4444
_is_image_compatible_with_optimization_job,
45-
_extracts_and_validates_speculative_model_source,
4645
_generate_channel_name,
47-
_generate_additional_model_data_sources,
48-
_is_s3_uri,
46+
_extract_optimization_config_and_env,
47+
_is_optimized,
48+
_custom_speculative_decoding,
49+
SPECULATIVE_DRAFT_MODEL,
4950
)
5051
from sagemaker.serve.utils.predictors import (
5152
DjlLocalModePredictor,
@@ -121,7 +122,7 @@ def __init__(self):
121122
self.speculative_decoding_draft_model_source = None
122123

123124
@abstractmethod
124-
def _prepare_for_mode(self):
125+
def _prepare_for_mode(self, **kwargs):
125126
"""Placeholder docstring"""
126127

127128
@abstractmethod
@@ -130,6 +131,9 @@ def _get_client_translators(self):
130131

131132
def _is_jumpstart_model_id(self) -> bool:
132133
"""Placeholder docstring"""
134+
if self.model is None:
135+
return False
136+
133137
try:
134138
model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE)
135139
except KeyError:
@@ -141,8 +145,9 @@ def _is_jumpstart_model_id(self) -> bool:
141145

142146
def _create_pre_trained_js_model(self) -> Type[Model]:
143147
"""Placeholder docstring"""
144-
pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config)
145-
pysdk_model.sagemaker_session = self.sagemaker_session
148+
pysdk_model = JumpStartModel(
149+
self.model, vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session
150+
)
146151

147152
self._original_deploy = pysdk_model.deploy
148153
pysdk_model.deploy = self._js_builder_deploy_wrapper
@@ -151,6 +156,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]:
151156
@_capture_telemetry("jumpstart.deploy")
152157
def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
153158
"""Placeholder docstring"""
159+
env = {}
154160
if "mode" in kwargs and kwargs.get("mode") != self.mode:
155161
overwrite_mode = kwargs.get("mode")
156162
# mode overwritten by customer during model.deploy()
@@ -167,7 +173,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
167173
or not hasattr(self, "prepared_for_tgi")
168174
or not hasattr(self, "prepared_for_mms")
169175
):
170-
self.pysdk_model.model_data, env = self._prepare_for_mode()
176+
if not _is_optimized(self.pysdk_model):
177+
self.pysdk_model.model_data, env = self._prepare_for_mode()
171178
elif overwrite_mode == Mode.LOCAL_CONTAINER:
172179
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
173180

@@ -198,7 +205,6 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
198205
)
199206

200207
self._prepare_for_mode()
201-
env = {}
202208
else:
203209
raise ValueError("Mode %s is not supported!" % overwrite_mode)
204210

@@ -726,25 +732,17 @@ def _optimize_for_jumpstart(
726732
)
727733

728734
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
729-
730-
optimization_config = {}
731-
if quantization_config:
732-
optimization_config["ModelQuantizationConfig"] = quantization_config
733-
pysdk_model_env_vars = _update_environment_variables(
734-
pysdk_model_env_vars, quantization_config["OverrideEnvironment"]
735-
)
736-
if compilation_config:
737-
optimization_config["ModelCompilationConfig"] = compilation_config
738-
pysdk_model_env_vars = _update_environment_variables(
739-
pysdk_model_env_vars, compilation_config["OverrideEnvironment"]
740-
)
735+
optimization_config, env = _extract_optimization_config_and_env(
736+
quantization_config, compilation_config
737+
)
738+
pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env)
741739

742740
output_config = {"S3OutputLocation": output_path}
743741
if kms_key:
744742
output_config["KmsKeyId"] = kms_key
745743
if not instance_type:
746-
instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
747-
"InstanceType"
744+
instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
745+
"InstanceType", _get_nb_instance()
748746
)
749747

750748
create_optimization_job_args = {
@@ -771,6 +769,10 @@ def _optimize_for_jumpstart(
771769
self.pysdk_model.env.update(pysdk_model_env_vars)
772770
if accept_eula:
773771
self.pysdk_model.accept_eula = accept_eula
772+
if isinstance(self.pysdk_model.model_data, dict):
773+
self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = {
774+
"AcceptEula": True
775+
}
774776

775777
if quantization_config or compilation_config:
776778
return create_optimization_job_args
@@ -806,7 +808,6 @@ def _set_additional_model_source(
806808
if speculative_decoding_config:
807809
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
808810
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
809-
speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}"
810811

811812
if model_provider == "sagemaker":
812813
additional_model_data_sources = self.pysdk_model.deployment_config.get(
@@ -825,32 +826,18 @@ def _set_additional_model_source(
825826
raise ValueError(
826827
"Cannot find deployment config compatible for optimization job."
827828
)
829+
830+
self.pysdk_model.env.update(
831+
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"}
832+
)
833+
self.pysdk_model.add_tags(
834+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
835+
)
828836
else:
829-
model_source = _extracts_and_validates_speculative_model_source(
830-
speculative_decoding_config
837+
self.pysdk_model = _custom_speculative_decoding(
838+
self.pysdk_model, speculative_decoding_config, accept_eula
831839
)
832840

833-
if _is_s3_uri(model_source):
834-
self.pysdk_model.additional_model_data_sources = (
835-
_generate_additional_model_data_sources(
836-
model_source, channel_name, accept_eula
837-
)
838-
)
839-
else:
840-
speculative_draft_model = model_source
841-
842-
self.pysdk_model.env = _update_environment_variables(
843-
self.pysdk_model.env,
844-
{"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model},
845-
)
846-
self.pysdk_model.add_tags(
847-
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider},
848-
)
849-
if accept_eula and isinstance(self.pysdk_model.model_data, dict):
850-
self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = {
851-
"AcceptEula": True
852-
}
853-
854841
def _find_compatible_deployment_config(
855842
self, speculative_decoding_config: Optional[Dict] = None
856843
) -> Optional[Dict[str, Any]]:

0 commit comments

Comments
 (0)