|
19 | 19 | from typing import Type, Any, List, Dict, Optional
|
20 | 20 | import logging
|
21 | 21 |
|
| 22 | +from sagemaker.jumpstart import enums |
| 23 | +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message |
22 | 24 | from sagemaker.model import Model
|
23 | 25 | from sagemaker import model_uris
|
24 | 26 | from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
|
|
33 | 35 | LocalModelLoadException,
|
34 | 36 | SkipTuningComboException,
|
35 | 37 | )
|
| 38 | +from sagemaker.serve.utils.optimize_utils import ( |
| 39 | + _extract_supported_deployment_config, |
| 40 | + _is_speculation_enabled, |
| 41 | + _is_compatible_with_optimization_job, |
| 42 | +) |
36 | 43 | from sagemaker.serve.utils.predictors import (
|
37 | 44 | DjlLocalModePredictor,
|
38 | 45 | TgiLocalModePredictor,
|
|
53 | 60 | from sagemaker.serve.utils.types import ModelServer
|
54 | 61 | from sagemaker.base_predictor import PredictorBase
|
55 | 62 | from sagemaker.jumpstart.model import JumpStartModel
|
| 63 | +from sagemaker.utils import Tags |
56 | 64 |
|
57 | 65 | _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
|
58 | 66 | _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID."
|
@@ -564,6 +572,148 @@ def _build_for_jumpstart(self):
|
564 | 572 |
|
565 | 573 | return self.pysdk_model
|
566 | 574 |
|
    def _optimize_for_jumpstart(
        self,
        output_path: str,
        instance_type: Optional[str] = None,
        role: Optional[str] = None,
        tags: Optional[Tags] = None,
        job_name: Optional[str] = None,
        accept_eula: Optional[bool] = None,
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,
        max_runtime_in_sec: Optional[int] = None,
    ) -> None:
        """Runs a SageMaker model optimization job for the current JumpStart model.

        Validates EULA acceptance for gated models, resolves a deployment
        configuration compatible with the optimization service (falling back to an
        alternative config when the current instance type / image is unsupported or
        speculative decoding is requested but not enabled), then submits a
        ``CreateOptimizationJob`` request via the SageMaker client.

        Args:
            output_path (str): Specifies where to store the compiled/quantized model.
            instance_type (Optional[str]): Target deployment instance type that
                the model is optimized for. Ignored (overridden) when the model
                already carries a deployment config. Defaults to ``None``.
            role (Optional[str]): Execution role ARN for the job. Defaults to ``None``.
            tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
            job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
            accept_eula (Optional[bool]): For models that require a Model Access Config,
                specify True or False to indicate whether model terms of use have been
                accepted. The `accept_eula` value must be explicitly defined as `True`
                in order to accept the end-user license agreement (EULA) that some
                models require. (Default: None).
            quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``.
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Merged over any environment taken from the selected deployment
                config. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
            kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
                to S3. Defaults to ``None``.
            max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
                ``None``.

        Raises:
            ValueError: If the model is gated and ``accept_eula`` is not explicitly
                ``True``, if the model is not backed by S3 model data, or if
                speculative decoding is requested but no compatible deployment
                config can be found.
        """
        # Resolve region-specific JumpStart specs so gating (EULA) status can be
        # checked before any job is created.
        model_specs = verify_model_region_and_return_specs(
            region=self.sagemaker_session.boto_region_name,
            model_id=self.pysdk_model.model_id,
            version=self.pysdk_model.model_version,
            sagemaker_session=self.sagemaker_session,
            scope=enums.JumpStartScriptScope.INFERENCE,
            model_type=self.pysdk_model.model_type,
        )

        # Gated models require an explicit opt-in; `accept_eula is not True`
        # deliberately rejects both None and False.
        if model_specs.is_gated_model() and accept_eula is not True:
            raise ValueError(get_eula_message(model_specs, self.sagemaker_session.boto_region_name))

        # Optimization jobs read artifacts from S3, so the model must expose an
        # S3DataSource in its model_data.
        if not (self.pysdk_model.model_data and self.pysdk_model.model_data.get("S3DataSource")):
            raise ValueError("Model Optimization Job only supports model backed by S3.")

        # A pre-selected deployment config (if any) supplies image/instance below.
        has_alternative_config = self.pysdk_model.deployment_config is not None
        merged_env_vars = None
        # TODO: Match Optimization Input Schema
        # NOTE(review): both "S3" and "SageMakerModel" keys are populated here —
        # confirm the service accepts both simultaneously once the schema is final.
        model_source = {
            "S3": {"S3Uri": self.pysdk_model.model_data.get("S3DataSource").get("S3Uri")},
            "SageMakerModel": {"ModelName": self.model},
        }

        if has_alternative_config:
            # An existing deployment config overrides both the image and the
            # caller-supplied instance_type.
            image_uri = self.pysdk_model.deployment_config.get("DeploymentArgs").get("ImageUri")
            instance_type = self.pysdk_model.deployment_config.get("InstanceType")
        else:
            image_uri = self.pysdk_model.image_uri

        # Fall back to an alternative deployment config when the current
        # instance/image pair is unsupported by the optimization service, or when
        # speculative decoding was requested but the current config lacks it.
        if not _is_compatible_with_optimization_job(instance_type, image_uri) or (
            speculative_decoding_config
            and not _is_speculation_enabled(self.pysdk_model.deployment_config)
        ):
            # NOTE(review): the second argument is `speculative_decoding_config is None`,
            # i.e. True when NO speculation was requested — verify the helper's flag
            # polarity (it looks like it should require speculation support when a
            # speculative_decoding_config IS provided).
            deployment_config = _extract_supported_deployment_config(
                self.pysdk_model.list_deployment_configs(), speculative_decoding_config is None
            )

            if deployment_config:
                # Adopt the compatible config; this also refreshes
                # self.pysdk_model.deployment_config read immediately below.
                self.pysdk_model.set_deployment_config(
                    config_name=deployment_config.get("DeploymentConfigName"),
                    instance_type=deployment_config.get("InstanceType"),
                )
                merged_env_vars = self.pysdk_model.deployment_config.get("Environment")

                if speculative_decoding_config:
                    # TODO: Match Optimization Input Schema
                    # Point the S3 source at the speculative-decoding (draft model)
                    # artifacts. NOTE(review): this update() replaces the base
                    # model's S3Uri set above — confirm that is intentional.
                    s3 = {
                        "S3Uri": self.pysdk_model.additional_model_data_sources[
                            "SpeculativeDecoding"
                        ][0]["S3DataSource"]["S3Uri"]
                    }
                    model_source["S3"].update(s3)
            elif speculative_decoding_config:
                # Speculation was requested but no compatible config exists.
                raise ValueError("Can't find deployment config for model optimization job.")

        # NOTE(review): optimization_config stays empty when neither quantization
        # nor compilation config is given; the request then carries [{}] — confirm
        # the service tolerates that.
        optimization_config = {}
        if env_vars:
            # Caller-supplied env vars win over the deployment config's environment.
            if merged_env_vars:
                merged_env_vars.update(env_vars)
            else:
                merged_env_vars = env_vars
        if quantization_config:
            optimization_config["ModelQuantizationConfig"] = quantization_config
        if compilation_config:
            optimization_config["ModelCompilationConfig"] = compilation_config

        if accept_eula:
            # Propagate EULA acceptance onto the model object, its model data, and
            # the job's model source so downstream deploy paths stay consistent.
            self.pysdk_model.accept_eula = accept_eula
            self.pysdk_model.model_data["S3DataSource"].update(
                {"ModelAccessConfig": {"AcceptEula": accept_eula}}
            )
            model_source["S3"].update({"ModelAccessConfig": {"AcceptEula": accept_eula}})

        output_config = {"S3OutputLocation": output_path}
        if kms_key:
            output_config["KmsKeyId"] = kms_key

        # Assemble the CreateOptimizationJob request. NOTE(review): "Environment"
        # may be None here — confirm the client/service accepts a null value.
        create_optimization_job_args = {
            "OptimizationJobName": job_name,
            "ModelSource": model_source,
            "DeploymentInstanceType": instance_type,
            "Environment": merged_env_vars,
            "OptimizationConfigs": [optimization_config],
            "OutputConfig": output_config,
            "RoleArn": role,
        }

        if max_runtime_in_sec:
            create_optimization_job_args["StoppingCondition"] = {
                "MaxRuntimeInSeconds": max_runtime_in_sec
            }
        if tags:
            create_optimization_job_args["Tags"] = tags
        if vpc_config:
            create_optimization_job_args["VpcConfig"] = vpc_config

        self.sagemaker_session.sagemaker_client.create_optimization_job(
            **create_optimization_job_args
        )
567 | 717 | def _is_gated_model(self, model) -> bool:
|
568 | 718 | """Determine if ``this`` Model is Gated
|
569 | 719 |
|
|
0 commit comments