Skip to content

Commit 701b788

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
update: Add optimize to ModelBuilder JS (aws#1480)
* Testing with Notebook * Refactoring * _poll_optimization_job refactoring * Resolve PR Comments * Refactoring * Refactoring * refactoring * Fix conflicts * Notebook testing * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 997e2ce commit 701b788

File tree

6 files changed

+252
-204
lines changed

6 files changed

+252
-204
lines changed

src/sagemaker/enums.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,6 @@ class Tag(str, Enum):
4646
"""Enum class for tag keys to apply to models."""
4747

4848
OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
49+
SPECULATIVE_DRAFT_MODL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
4950
FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
5051
FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"

src/sagemaker/serve/builder/jumpstart_builder.py

+135-71
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from botocore.exceptions import ClientError
2424

2525
from sagemaker.enums import Tag
26-
from sagemaker.jumpstart import enums
27-
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message
2826
from sagemaker.model import Model
2927
from sagemaker import model_uris
3028
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
@@ -40,9 +38,9 @@
4038
SkipTuningComboException,
4139
)
4240
from sagemaker.serve.utils.optimize_utils import (
43-
_extract_supported_deployment_config,
44-
_is_speculation_enabled,
4541
_is_compatible_with_optimization_job,
42+
_extract_model_source,
43+
_update_environment_variables,
4644
)
4745
from sagemaker.serve.utils.predictors import (
4846
DjlLocalModePredictor,
@@ -643,7 +641,7 @@ def _optimize_for_jumpstart(
643641
vpc_config: Optional[Dict] = None,
644642
kms_key: Optional[str] = None,
645643
max_runtime_in_sec: Optional[int] = None,
646-
) -> None:
644+
) -> Dict[str, Any]:
647645
"""Runs a model optimization job.
648646
649647
Args:
@@ -669,79 +667,60 @@ def _optimize_for_jumpstart(
669667
to S3. Defaults to ``None``.
670668
max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
671669
``None``.
672-
"""
673-
model_specs = verify_model_region_and_return_specs(
674-
region=self.sagemaker_session.boto_region_name,
675-
model_id=self.pysdk_model.model_id,
676-
version=self.pysdk_model.model_version,
677-
sagemaker_session=self.sagemaker_session,
678-
scope=enums.JumpStartScriptScope.INFERENCE,
679-
model_type=self.pysdk_model.model_type,
680-
)
681670
682-
if model_specs.is_gated_model() and accept_eula is not True:
683-
raise ValueError(get_eula_message(model_specs, self.sagemaker_session.boto_region_name))
684-
685-
if not (self.pysdk_model.model_data and self.pysdk_model.model_data.get("S3DataSource")):
686-
raise ValueError("Model Optimization Job only supports model backed by S3.")
671+
Returns:
672+
Dict[str, Any]: Model optimization job input arguments.
673+
"""
674+
if self._is_gated_model() and accept_eula is not True:
675+
raise ValueError(
676+
f"ValueError: Model '{self.model}' "
677+
f"requires accepting end-user license agreement (EULA)."
678+
)
687679

688-
has_alternative_config = self.pysdk_model.deployment_config is not None
689-
merged_env_vars = None
690-
# TODO: Match Optimization Input Schema
691-
model_source = {
692-
"S3": {"S3Uri": self.pysdk_model.model_data.get("S3DataSource").get("S3Uri")},
693-
"SageMakerModel": {"ModelName": self.model},
694-
}
680+
optimization_env_vars = None
681+
pysdk_model_env_vars = None
682+
model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula)
695683

696-
if has_alternative_config:
697-
image_uri = self.pysdk_model.deployment_config.get("DeploymentArgs").get("ImageUri")
698-
instance_type = self.pysdk_model.deployment_config.get("InstanceType")
684+
if speculative_decoding_config:
685+
self._set_additional_model_source(speculative_decoding_config)
686+
optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
687+
"Environment"
688+
)
699689
else:
700-
image_uri = self.pysdk_model.image_uri
701-
702-
if not _is_compatible_with_optimization_job(instance_type, image_uri) or (
703-
speculative_decoding_config
704-
and not _is_speculation_enabled(self.pysdk_model.deployment_config)
705-
):
706-
deployment_config = _extract_supported_deployment_config(
707-
self.pysdk_model.list_deployment_configs(), speculative_decoding_config is None
690+
image_uri = None
691+
if quantization_config and quantization_config.get("Image"):
692+
image_uri = quantization_config.get("Image")
693+
elif compilation_config and compilation_config.get("Image"):
694+
image_uri = compilation_config.get("Image")
695+
instance_type = (
696+
instance_type
697+
or self.pysdk_model.deployment_config.get("DeploymentArgs").get("InstanceType")
698+
or _get_nb_instance()
708699
)
700+
if not _is_compatible_with_optimization_job(instance_type, image_uri):
701+
deployment_config = self._find_compatible_deployment_config(None)
702+
if deployment_config:
703+
optimization_env_vars = deployment_config.get("DeploymentArgs").get(
704+
"Environment"
705+
)
706+
self.pysdk_model.set_deployment_config(
707+
config_name=deployment_config.get("DeploymentConfigName"),
708+
instance_type=deployment_config.get("InstanceType"),
709+
)
709710

710-
if deployment_config:
711-
self.pysdk_model.set_deployment_config(
712-
config_name=deployment_config.get("DeploymentConfigName"),
713-
instance_type=deployment_config.get("InstanceType"),
714-
)
715-
merged_env_vars = self.pysdk_model.deployment_config.get("Environment")
716-
717-
if speculative_decoding_config:
718-
# TODO: Match Optimization Input Schema
719-
s3 = {
720-
"S3Uri": self.pysdk_model.additional_model_data_sources[
721-
"SpeculativeDecoding"
722-
][0]["S3DataSource"]["S3Uri"]
723-
}
724-
model_source["S3"].update(s3)
725-
elif speculative_decoding_config:
726-
raise ValueError("Can't find deployment config for model optimization job.")
711+
optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars)
727712

728713
optimization_config = {}
729-
if env_vars:
730-
if merged_env_vars:
731-
merged_env_vars.update(env_vars)
732-
else:
733-
merged_env_vars = env_vars
734714
if quantization_config:
735715
optimization_config["ModelQuantizationConfig"] = quantization_config
716+
pysdk_model_env_vars = _update_environment_variables(
717+
pysdk_model_env_vars, quantization_config["OverrideEnvironment"]
718+
)
736719
if compilation_config:
737720
optimization_config["ModelCompilationConfig"] = compilation_config
738-
739-
if accept_eula:
740-
self.pysdk_model.accept_eula = accept_eula
741-
self.pysdk_model.model_data["S3DataSource"].update(
742-
{"ModelAccessConfig": {"AcceptEula": accept_eula}}
721+
pysdk_model_env_vars = _update_environment_variables(
722+
pysdk_model_env_vars, compilation_config["OverrideEnvironment"]
743723
)
744-
model_source["S3"].update({"ModelAccessConfig": {"AcceptEula": accept_eula}})
745724

746725
output_config = {"S3OutputLocation": output_path}
747726
if kms_key:
@@ -751,12 +730,13 @@ def _optimize_for_jumpstart(
751730
"OptimizationJobName": job_name,
752731
"ModelSource": model_source,
753732
"DeploymentInstanceType": instance_type,
754-
"Environment": merged_env_vars,
755733
"OptimizationConfigs": [optimization_config],
756734
"OutputConfig": output_config,
757735
"RoleArn": role,
758736
}
759737

738+
if optimization_env_vars:
739+
create_optimization_job_args["Environment"] = optimization_env_vars
760740
if max_runtime_in_sec:
761741
create_optimization_job_args["StoppingCondition"] = {
762742
"MaxRuntimeInSeconds": max_runtime_in_sec
@@ -766,22 +746,106 @@ def _optimize_for_jumpstart(
766746
if vpc_config:
767747
create_optimization_job_args["VpcConfig"] = vpc_config
768748

769-
self.sagemaker_session.sagemaker_client.create_optimization_job(
770-
**create_optimization_job_args
771-
)
749+
self.pysdk_model.env.update(pysdk_model_env_vars)
750+
return create_optimization_job_args
772751

773-
def _is_gated_model(self, model: Model) -> bool:
752+
def _is_gated_model(self, model=None) -> bool:
774753
"""Determine if ``this`` Model is Gated
775754
776755
Args:
777756
model (Optional[Model]): JumpStart Model. When omitted, ``self.pysdk_model`` is used.
778757
Returns:
779758
bool: ``True`` if ``this`` Model is Gated
780759
"""
781-
s3_uri = model.model_data
760+
s3_uri = model.model_data if model else self.pysdk_model.model_data
782761
if isinstance(s3_uri, dict):
783762
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
784763

785764
if s3_uri is None:
786765
return False
787766
return "private" in s3_uri
767+
768+
def _set_additional_model_source(
769+
self, speculative_decoding_config: Optional[Dict[str, Any]] = None
770+
) -> None:
771+
"""Set Additional Model Source to ``this`` model.
772+
773+
Args:
774+
speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config.
775+
"""
776+
if speculative_decoding_config:
777+
model_provider: str = speculative_decoding_config["ModelProvider"]
778+
779+
if model_provider.lower() == "sagemaker":
780+
if not self._is_speculation_enabled(self.pysdk_model.deployment_config):
781+
deployment_config = self._find_compatible_deployment_config(
782+
speculative_decoding_config
783+
)
784+
if deployment_config:
785+
self.pysdk_model.set_deployment_config(
786+
config_name=deployment_config.get("DeploymentConfigName"),
787+
instance_type=deployment_config.get("InstanceType"),
788+
)
789+
self.pysdk_model.add_tags(
790+
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"},
791+
)
792+
else:
793+
raise ValueError(
794+
"Cannot find deployment config compatible for optimization job."
795+
)
796+
else:
797+
s3_uri = speculative_decoding_config.get("ModelSource")
798+
if not s3_uri:
799+
raise ValueError("Custom S3 Uri cannot be none.")
800+
801+
self.pysdk_model.additional_model_data_sources["speculative_decoding"][0][
802+
"s3_data_source"
803+
]["s3_uri"] = s3_uri
804+
self.pysdk_model.add_tags(
805+
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"},
806+
)
807+
808+
def _find_compatible_deployment_config(
809+
self, speculative_decoding_config: Optional[Dict] = None
810+
) -> Optional[Dict[str, Any]]:
811+
"""Finds compatible model deployment config for optimization job.
812+
813+
Args:
814+
speculative_decoding_config (Optional[Dict]): Speculative decoding config.
815+
816+
Returns:
817+
Optional[Dict[str, Any]]: A compatible model deployment config for optimization job.
818+
"""
819+
for deployment_config in self.pysdk_model.list_deployment_configs():
820+
instance_type = deployment_config.get("deployment_config").get("InstanceType")
821+
image_uri = deployment_config.get("deployment_config").get("ImageUri")
822+
823+
if _is_compatible_with_optimization_job(instance_type, image_uri):
824+
if not speculative_decoding_config:
825+
return deployment_config
826+
827+
if self._is_speculation_enabled(deployment_config):
828+
return deployment_config
829+
830+
return None
831+
832+
def _is_speculation_enabled(self, deployment_config: Optional[Dict[str, Any]]) -> bool:
833+
"""Checks whether speculative decoding is enabled for the given deployment config.
834+
835+
Args:
836+
deployment_config (Optional[Dict[str, Any]]): A deployment config. ``None`` yields ``False``.
837+
838+
Returns:
839+
bool: Whether speculative decoding is enabled for this deployment config.
840+
"""
841+
if deployment_config is None:
842+
return False
843+
844+
acceleration_configs = deployment_config.get("AccelerationConfigs")
845+
if acceleration_configs:
846+
for acceleration_config in acceleration_configs:
847+
if acceleration_config.get(
848+
"type", "default"
849+
).lower() == "speculative" and acceleration_config.get("enabled"):
850+
return True
851+
return False

src/sagemaker/serve/builder/model_builder.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from sagemaker.serve.utils import task
6464
from sagemaker.serve.utils.exceptions import TaskNotFoundException
6565
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
66-
from sagemaker.serve.utils.optimize_utils import _poll_optimization_job, _generate_optimized_model
66+
from sagemaker.serve.utils.optimize_utils import _generate_optimized_model
6767
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
6868
from sagemaker.serve.utils.hardware_detector import (
6969
_get_gpu_info,
@@ -972,7 +972,7 @@ def _model_builder_optimize_wrapper(
972972
env_vars: Optional[Dict] = None,
973973
vpc_config: Optional[Dict] = None,
974974
kms_key: Optional[str] = None,
975-
max_runtime_in_sec: Optional[int] = None,
975+
max_runtime_in_sec: Optional[int] = 36000,
976976
sagemaker_session: Optional[Session] = None,
977977
) -> Model:
978978
"""Runs a model optimization job.
@@ -998,7 +998,7 @@ def _model_builder_optimize_wrapper(
998998
kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
999999
to S3. Defaults to ``None``.
10001000
max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
1001-
``None``.
1001+
36000 seconds.
10021002
sagemaker_session (Optional[Session]): Session object which manages interactions
10031003
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
10041004
function creates one using the default AWS configuration chain.
@@ -1010,8 +1010,9 @@ def _model_builder_optimize_wrapper(
10101010
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
10111011
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
10121012

1013+
input_args = {}
10131014
if self._is_jumpstart_model_id():
1014-
self._optimize_for_jumpstart(
1015+
input_args = self._optimize_for_jumpstart(
10151016
output_path=output_path,
10161017
instance_type=instance_type,
10171018
role=role if role else self.role_arn,
@@ -1027,19 +1028,8 @@ def _model_builder_optimize_wrapper(
10271028
max_runtime_in_sec=max_runtime_in_sec,
10281029
)
10291030

1030-
# TODO: use the wait for job pattern similar to
1031-
# https://quip-amazon.com/TKaPAhJck5sD/PySDK-Model-Optimization#temp:C:YcX3f2b103dabb4431090568bca2
1032-
if not _poll_optimization_job(job_name, self.sagemaker_session):
1033-
raise Exception("Optimization job timed out.")
1034-
1035-
describe_optimization_job_res = (
1036-
self.sagemaker_session.sagemaker_client.describe_optimization_job(
1037-
OptimizationJobName=job_name
1038-
)
1039-
)
1040-
1041-
self.pysdk_model = _generate_optimized_model(
1042-
self.pysdk_model, describe_optimization_job_res
1043-
)
1031+
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
1032+
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1033+
self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status)
10441034

10451035
return self.pysdk_model

0 commit comments

Comments
 (0)