Skip to content

Commit f55e3c9

Browse files
Authored by makungaj1 (Jonathan Makunga), with co-author Jonathan Makunga
Feat: Add optimize to ModelBuilder JS (aws#1474)
* QS JS vanilla model * Use Alt config for Optimization * JS Optimize * Resolve config * inject additional tags * Inject tags * Refactoring * Refactoring * Filter Deployment config * Refactoring * Refactoring * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 3c7b966 commit f55e3c9

File tree

9 files changed

+568
-68
lines changed

9 files changed

+568
-68
lines changed

src/sagemaker/enums.py

+6
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,9 @@ class RoutingStrategy(Enum):
4040
"""The endpoint routes requests to the specific instances that have
4141
more capacity to process them.
4242
"""
43+
44+
45+
class Tag(str, Enum):
46+
"""Enum class for tag keys to apply to models."""
47+
48+
OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"

src/sagemaker/jumpstart/utils.py

+17
Original file line numberDiff line numberDiff line change
@@ -1336,3 +1336,20 @@ def wrapped_f(*args, **kwargs):
13361336
if _func is None:
13371337
return wrapper_cache
13381338
return wrapper_cache(_func)
1339+
1340+
1341+
def _extract_image_tag_and_version(image_uri: str) -> Tuple[Optional[str], Optional[str]]:
1342+
"""Extract Image tag and version from image URI.
1343+
1344+
Args:
1345+
image_uri (str): Image URI.
1346+
1347+
Returns:
1348+
Tuple[Optional[str], Optional[str]]: The tag and version of the image.
1349+
"""
1350+
if image_uri is None:
1351+
return None, None
1352+
1353+
tag = image_uri.split(":")[-1]
1354+
1355+
return tag, tag.split("-")[0]

src/sagemaker/model.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,18 @@ def __init__(
404404
self.content_types = None
405405
self.response_types = None
406406
self.accept_eula = None
407+
self._tags: Optional[Tags] = None
408+
409+
def add_tags(self, tags: Tags) -> None:
410+
"""Add tags to this ``Model``
411+
412+
Args:
413+
tags (Tags): Tags to add.
414+
"""
415+
if self._tags and tags:
416+
self._tags.update(tags)
417+
else:
418+
self._tags = tags
407419

408420
@runnable_by_pipeline
409421
def register(
@@ -1457,7 +1469,8 @@ def deploy(
14571469
sagemaker_session=self.sagemaker_session,
14581470
)
14591471

1460-
tags = format_tags(tags)
1472+
self.add_tags(tags)
1473+
tags = format_tags(self._tags)
14611474

14621475
if (
14631476
getattr(self.sagemaker_session, "settings", None) is not None

src/sagemaker/serve/builder/jumpstart_builder.py

+150
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import Type, Any, List, Dict, Optional
2020
import logging
2121

22+
from sagemaker.jumpstart import enums
23+
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message
2224
from sagemaker.model import Model
2325
from sagemaker import model_uris
2426
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
@@ -33,6 +35,11 @@
3335
LocalModelLoadException,
3436
SkipTuningComboException,
3537
)
38+
from sagemaker.serve.utils.optimize_utils import (
39+
_extract_supported_deployment_config,
40+
_is_speculation_enabled,
41+
_is_compatible_with_optimization_job,
42+
)
3643
from sagemaker.serve.utils.predictors import (
3744
DjlLocalModePredictor,
3845
TgiLocalModePredictor,
@@ -53,6 +60,7 @@
5360
from sagemaker.serve.utils.types import ModelServer
5461
from sagemaker.base_predictor import PredictorBase
5562
from sagemaker.jumpstart.model import JumpStartModel
63+
from sagemaker.utils import Tags
5664

5765
_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
5866
_NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID."
@@ -564,6 +572,148 @@ def _build_for_jumpstart(self):
564572

565573
return self.pysdk_model
566574

575+
def _optimize_for_jumpstart(
576+
self,
577+
output_path: str,
578+
instance_type: Optional[str] = None,
579+
role: Optional[str] = None,
580+
tags: Optional[Tags] = None,
581+
job_name: Optional[str] = None,
582+
accept_eula: Optional[bool] = None,
583+
quantization_config: Optional[Dict] = None,
584+
compilation_config: Optional[Dict] = None,
585+
speculative_decoding_config: Optional[Dict] = None,
586+
env_vars: Optional[Dict] = None,
587+
vpc_config: Optional[Dict] = None,
588+
kms_key: Optional[str] = None,
589+
max_runtime_in_sec: Optional[int] = None,
590+
) -> None:
591+
"""Runs a model optimization job.
592+
593+
Args:
594+
output_path (str): Specifies where to store the compiled/quantized model.
595+
instance_type (Optional[str]): Target deployment instance type that
596+
the model is optimized for.
597+
role (Optional[str]): Execution role. Defaults to ``None``.
598+
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
599+
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
600+
accept_eula (bool): For models that require a Model Access Config, specify True or
601+
False to indicate whether model terms of use have been accepted.
602+
The `accept_eula` value must be explicitly defined as `True` in order to
603+
accept the end-user license agreement (EULA) that some
604+
models require. (Default: None).
605+
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
606+
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
607+
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
608+
Defaults to ``None``
609+
env_vars (Optional[Dict]): Additional environment variables to run the optimization
610+
container. Defaults to ``None``.
611+
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
612+
kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
613+
to S3. Defaults to ``None``.
614+
max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
615+
``None``.
616+
"""
617+
model_specs = verify_model_region_and_return_specs(
618+
region=self.sagemaker_session.boto_region_name,
619+
model_id=self.pysdk_model.model_id,
620+
version=self.pysdk_model.model_version,
621+
sagemaker_session=self.sagemaker_session,
622+
scope=enums.JumpStartScriptScope.INFERENCE,
623+
model_type=self.pysdk_model.model_type,
624+
)
625+
626+
if model_specs.is_gated_model() and accept_eula is not True:
627+
raise ValueError(get_eula_message(model_specs, self.sagemaker_session.boto_region_name))
628+
629+
if not (self.pysdk_model.model_data and self.pysdk_model.model_data.get("S3DataSource")):
630+
raise ValueError("Model Optimization Job only supports model backed by S3.")
631+
632+
has_alternative_config = self.pysdk_model.deployment_config is not None
633+
merged_env_vars = None
634+
# TODO: Match Optimization Input Schema
635+
model_source = {
636+
"S3": {"S3Uri": self.pysdk_model.model_data.get("S3DataSource").get("S3Uri")},
637+
"SageMakerModel": {"ModelName": self.model},
638+
}
639+
640+
if has_alternative_config:
641+
image_uri = self.pysdk_model.deployment_config.get("DeploymentArgs").get("ImageUri")
642+
instance_type = self.pysdk_model.deployment_config.get("InstanceType")
643+
else:
644+
image_uri = self.pysdk_model.image_uri
645+
646+
if not _is_compatible_with_optimization_job(instance_type, image_uri) or (
647+
speculative_decoding_config
648+
and not _is_speculation_enabled(self.pysdk_model.deployment_config)
649+
):
650+
deployment_config = _extract_supported_deployment_config(
651+
self.pysdk_model.list_deployment_configs(), speculative_decoding_config is None
652+
)
653+
654+
if deployment_config:
655+
self.pysdk_model.set_deployment_config(
656+
config_name=deployment_config.get("DeploymentConfigName"),
657+
instance_type=deployment_config.get("InstanceType"),
658+
)
659+
merged_env_vars = self.pysdk_model.deployment_config.get("Environment")
660+
661+
if speculative_decoding_config:
662+
# TODO: Match Optimization Input Schema
663+
s3 = {
664+
"S3Uri": self.pysdk_model.additional_model_data_sources[
665+
"SpeculativeDecoding"
666+
][0]["S3DataSource"]["S3Uri"]
667+
}
668+
model_source["S3"].update(s3)
669+
elif speculative_decoding_config:
670+
raise ValueError("Can't find deployment config for model optimization job.")
671+
672+
optimization_config = {}
673+
if env_vars:
674+
if merged_env_vars:
675+
merged_env_vars.update(env_vars)
676+
else:
677+
merged_env_vars = env_vars
678+
if quantization_config:
679+
optimization_config["ModelQuantizationConfig"] = quantization_config
680+
if compilation_config:
681+
optimization_config["ModelCompilationConfig"] = compilation_config
682+
683+
if accept_eula:
684+
self.pysdk_model.accept_eula = accept_eula
685+
self.pysdk_model.model_data["S3DataSource"].update(
686+
{"ModelAccessConfig": {"AcceptEula": accept_eula}}
687+
)
688+
model_source["S3"].update({"ModelAccessConfig": {"AcceptEula": accept_eula}})
689+
690+
output_config = {"S3OutputLocation": output_path}
691+
if kms_key:
692+
output_config["KmsKeyId"] = kms_key
693+
694+
create_optimization_job_args = {
695+
"OptimizationJobName": job_name,
696+
"ModelSource": model_source,
697+
"DeploymentInstanceType": instance_type,
698+
"Environment": merged_env_vars,
699+
"OptimizationConfigs": [optimization_config],
700+
"OutputConfig": output_config,
701+
"RoleArn": role,
702+
}
703+
704+
if max_runtime_in_sec:
705+
create_optimization_job_args["StoppingCondition"] = {
706+
"MaxRuntimeInSeconds": max_runtime_in_sec
707+
}
708+
if tags:
709+
create_optimization_job_args["Tags"] = tags
710+
if vpc_config:
711+
create_optimization_job_args["VpcConfig"] = vpc_config
712+
713+
self.sagemaker_session.sagemaker_client.create_optimization_job(
714+
**create_optimization_job_args
715+
)
716+
567717
def _is_gated_model(self, model) -> bool:
568718
"""Determine if ``this`` Model is Gated
569719

src/sagemaker/serve/builder/model_builder.py

+42-54
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,7 @@
6262
from sagemaker.serve.utils import task
6363
from sagemaker.serve.utils.exceptions import TaskNotFoundException
6464
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
65-
from sagemaker.serve.utils.optimize_utils import (
66-
_is_compatible_with_compilation,
67-
_poll_optimization_job,
68-
)
65+
from sagemaker.serve.utils.optimize_utils import _poll_optimization_job, _generate_optimized_model
6966
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
7067
from sagemaker.serve.utils.hardware_detector import (
7168
_get_gpu_info,
@@ -961,13 +958,15 @@ def optimize(self, *args, **kwargs) -> Type[Model]:
961958
@_capture_telemetry("optimize")
962959
def _model_builder_optimize_wrapper(
963960
self,
964-
instance_type: str,
965961
output_path: str,
962+
instance_type: Optional[str] = None,
966963
role: Optional[str] = None,
967964
tags: Optional[Tags] = None,
968965
job_name: Optional[str] = None,
966+
accept_eula: Optional[bool] = None,
969967
quantization_config: Optional[Dict] = None,
970968
compilation_config: Optional[Dict] = None,
969+
speculative_decoding_config: Optional[Dict] = None,
971970
env_vars: Optional[Dict] = None,
972971
vpc_config: Optional[Dict] = None,
973972
kms_key: Optional[str] = None,
@@ -977,13 +976,20 @@ def _model_builder_optimize_wrapper(
977976
"""Runs a model optimization job.
978977
979978
Args:
980-
instance_type (str): Target deployment instance type that the model is optimized for.
981979
output_path (str): Specifies where to store the compiled/quantized model.
980+
instance_type (str): Target deployment instance type that the model is optimized for.
982981
role (Optional[str]): Execution role. Defaults to ``None``.
983982
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
984983
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
984+
accept_eula (bool): For models that require a Model Access Config, specify True or
985+
False to indicate whether model terms of use have been accepted.
986+
The `accept_eula` value must be explicitly defined as `True` in order to
987+
accept the end-user license agreement (EULA) that some
988+
models require. (Default: None).
985989
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
986990
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
991+
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
992+
Defaults to ``None``
987993
env_vars (Optional[Dict]): Additional environment variables to run the optimization
988994
container. Defaults to ``None``.
989995
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -999,57 +1005,39 @@ def _model_builder_optimize_wrapper(
9991005
Type[Model]: A deployable ``Model`` object.
10001006
"""
10011007
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
1008+
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
1009+
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
10021010

1003-
# TODO: inject actual model source location based on different scenarios
1004-
model_source = {"S3": {"S3Uri": self.model_path, "ModelAccessConfig": {"AcceptEula": True}}}
1005-
1006-
optimization_configs = []
1007-
if quantization_config:
1008-
optimization_configs.append({"ModelQuantizationConfig": quantization_config})
1009-
if compilation_config:
1010-
if _is_compatible_with_compilation(instance_type):
1011-
optimization_configs.append({"ModelCompilationConfig": compilation_config})
1012-
else:
1013-
logger.warning(
1014-
"Model compilation is currently only supported for Inferentia and Trainium"
1015-
"instances, ignoring `compilation_config'."
1016-
)
1011+
if self._is_jumpstart_model_id():
1012+
self._optimize_for_jumpstart(
1013+
output_path=output_path,
1014+
instance_type=instance_type,
1015+
role=role if role else self.role_arn,
1016+
tags=tags,
1017+
job_name=job_name,
1018+
accept_eula=accept_eula,
1019+
quantization_config=quantization_config,
1020+
compilation_config=compilation_config,
1021+
speculative_decoding_config=speculative_decoding_config,
1022+
env_vars=env_vars,
1023+
vpc_config=vpc_config,
1024+
kms_key=kms_key,
1025+
max_runtime_in_sec=max_runtime_in_sec,
1026+
)
10171027

1018-
output_config = {"S3OutputLocation": output_path}
1019-
if kms_key:
1020-
output_config["KmsKeyId"] = kms_key
1028+
# TODO: use the wait for job pattern similar to
1029+
# https://quip-amazon.com/TKaPAhJck5sD/PySDK-Model-Optimization#temp:C:YcX3f2b103dabb4431090568bca2
1030+
if not _poll_optimization_job(job_name, self.sagemaker_session):
1031+
raise Exception("Optimization job timed out.")
10211032

1022-
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
1023-
create_optimization_job_args = {
1024-
"OptimizationJobName": job_name,
1025-
"ModelSource": model_source,
1026-
"DeploymentInstanceType": instance_type,
1027-
"OptimizationConfigs": optimization_configs,
1028-
"OutputConfig": output_config,
1029-
"RoleArn": role or self.role_arn,
1030-
}
1031-
1032-
if env_vars:
1033-
create_optimization_job_args["OptimizationEnvironment"] = env_vars
1034-
1035-
if max_runtime_in_sec:
1036-
create_optimization_job_args["StoppingCondition"] = {
1037-
"MaxRuntimeInSeconds": max_runtime_in_sec
1038-
}
1039-
1040-
# TODO: tag injection if it is a JumpStart model
1041-
if tags:
1042-
create_optimization_job_args["Tags"] = tags
1043-
1044-
if vpc_config:
1045-
create_optimization_job_args["VpcConfig"] = vpc_config
1046-
1047-
response = self.sagemaker_session.sagemaker_client.create_optimization_job(
1048-
**create_optimization_job_args
1033+
describe_optimization_job_res = (
1034+
self.sagemaker_session.sagemaker_client.describe_optimization_job(
1035+
OptimizationJobName=job_name
1036+
)
10491037
)
10501038

1051-
if not _poll_optimization_job(job_name, self.sagemaker_session):
1052-
raise Exception("Optimization job timed out.")
1039+
self.pysdk_model = _generate_optimized_model(
1040+
self.pysdk_model, describe_optimization_job_res
1041+
)
10531042

1054-
# TODO: return model created by optimization job
1055-
return response
1043+
return self.pysdk_model

0 commit comments

Comments (0)