|
20 | 20 |
|
21 | 21 | from pathlib import Path
|
22 | 22 |
|
| 23 | +from sagemaker.s3 import S3Downloader |
| 24 | + |
23 | 25 | from sagemaker import Session
|
24 | 26 | from sagemaker.model import Model
|
25 | 27 | from sagemaker.base_predictor import PredictorBase
|
|
37 | 39 | from sagemaker.serve.builder.jumpstart_builder import JumpStart
|
38 | 40 | from sagemaker.serve.builder.transformers_builder import Transformers
|
39 | 41 | from sagemaker.predictor import Predictor
|
| 42 | +from sagemaker.serve.model_format.mlflow.constants import ( |
| 43 | + MLFLOW_MODEL_PATH, |
| 44 | + MLFLOW_METADATA_FILE, |
| 45 | + MLFLOW_PIP_DEPENDENCY_FILE, |
| 46 | +) |
| 47 | +from sagemaker.serve.model_format.mlflow.utils import ( |
| 48 | + _get_default_model_server_for_mlflow, |
| 49 | + _mlflow_input_is_local_path, |
| 50 | + _download_s3_artifacts, |
| 51 | + _select_container_for_mlflow_model, |
| 52 | + _generate_mlflow_artifact_path, |
| 53 | + _get_all_flavor_metadata, |
| 54 | + _get_deployment_flavor, |
| 55 | + _validate_input_for_mlflow, |
| 56 | + _copy_directory_contents, |
| 57 | +) |
40 | 58 | from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
|
41 | 59 | from sagemaker.serve.spec.inference_spec import InferenceSpec
|
42 | 60 | from sagemaker.serve.utils import task
|
@@ -145,8 +163,11 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
|
145 | 163 | to the model server). Possible values for this argument are
|
146 | 164 | ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
|
147 | 165 | ``TRITON``, and ``TGI``.
|
148 |
| - model_metadata (Optional[Dict[str, Any]): Dictionary used to override the HuggingFace |
149 |
| - model metadata. Currently ``HF_TASK`` is overridable. |
| 166 | + model_metadata (Optional[Dict[str, Any]]): Dictionary used to override model metadata. |
| 167 | + Currently, ``HF_TASK`` is overridable for HuggingFace models. ``MLFLOW_MODEL_PATH`` |
| 168 | + is available for providing a local path or S3 path to MLflow artifacts. However, |
| 169 | + ``MLFLOW_MODEL_PATH`` is experimental and is not intended for production use |
| 170 | + at this moment. |
150 | 171 | """
|
151 | 172 |
|
152 | 173 | model_path: Optional[str] = field(
|
@@ -245,7 +266,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
|
245 | 266 | )
|
246 | 267 | model_metadata: Optional[Dict[str, Any]] = field(
|
247 | 268 | default=None,
|
248 |
| - metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"}, |
| 269 | + metadata={ |
| 270 | + "help": "Define the model metadata to override, currently supports `HF_TASK`, " |
| 271 | + "`MLFLOW_MODEL_PATH`" |
| 272 | + }, |
249 | 273 | )
|
250 | 274 |
|
251 | 275 | def _build_validations(self):
|
@@ -297,6 +321,8 @@ def _save_model_inference_spec(self):
|
297 | 321 | save_pkl(code_path, (self._framework, self.schema_builder))
|
298 | 322 | else:
|
299 | 323 | save_pkl(code_path, (self.model, self.schema_builder))
|
| 324 | + elif self._is_mlflow_model: |
| 325 | + save_pkl(code_path, self.schema_builder) |
300 | 326 | else:
|
301 | 327 | raise ValueError("Cannot detect required model or inference spec")
|
302 | 328 |
|
@@ -577,6 +603,76 @@ def wrapper(*args, **kwargs):
|
577 | 603 |
|
578 | 604 | return wrapper
|
579 | 605 |
|
| 606 | + def _check_if_input_is_mlflow_model(self) -> bool: |
| 607 | + """Checks whether an MLmodel file exists in the given directory. |
| 608 | +
|
| 609 | + Returns: |
| 610 | + bool: True if the MLmodel file exists, False otherwise. |
| 611 | + """ |
| 612 | + if self.inference_spec or self.model: |
| 613 | + logger.info( |
| 614 | + "Either inference spec or model is provided. " |
| 615 | + "ModelBuilder is not handling MLflow model input" |
| 616 | + ) |
| 617 | + return False |
| 618 | + |
| 619 | + if not self.model_metadata: |
| 620 | + logger.info( |
| 621 | + "No ModelMetadata provided. ModelBuilder is not handling MLflow model input" |
| 622 | + ) |
| 623 | + return False |
| 624 | + |
| 625 | + path = self.model_metadata.get(MLFLOW_MODEL_PATH) |
| 626 | + if not path: |
| 627 | + logger.info( |
| 628 | + "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model " |
| 629 | + "input", |
| 630 | + MLFLOW_MODEL_PATH, |
| 631 | + ) |
| 632 | + return False |
| 633 | + |
| 634 | + # Check for S3 path |
| 635 | + if path.startswith("s3://"): |
| 636 | + s3_downloader = S3Downloader() |
| 637 | + if not path.endswith("/"): |
| 638 | + path += "/" |
| 639 | + s3_uri_to_mlmodel_file = f"{path}{MLFLOW_METADATA_FILE}" |
| 640 | + response = s3_downloader.list(s3_uri_to_mlmodel_file, self.sagemaker_session) |
| 641 | + return len(response) > 0 |
| 642 | + |
| 643 | + file_path = os.path.join(path, MLFLOW_METADATA_FILE) |
| 644 | + return os.path.isfile(file_path) |
| 645 | + |
| 646 | + def _initialize_for_mlflow(self) -> None: |
| 647 | + """Initialize mlflow model artifacts, image uri and model server.""" |
| 648 | + mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) |
| 649 | + if not _mlflow_input_is_local_path(mlflow_path): |
| 650 | + # TODO: extend to package arn, run id and etc. |
| 651 | + _download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session) |
| 652 | + else: |
| 653 | + _copy_directory_contents(mlflow_path, self.model_path) |
| 654 | + mlflow_model_metadata_path = _generate_mlflow_artifact_path( |
| 655 | + self.model_path, MLFLOW_METADATA_FILE |
| 656 | + ) |
| 657 | + # TODO: add validation on MLmodel file |
| 658 | + mlflow_model_dependency_path = _generate_mlflow_artifact_path( |
| 659 | + self.model_path, MLFLOW_PIP_DEPENDENCY_FILE |
| 660 | + ) |
| 661 | + flavor_metadata = _get_all_flavor_metadata(mlflow_model_metadata_path) |
| 662 | + deployment_flavor = _get_deployment_flavor(flavor_metadata) |
| 663 | + |
| 664 | + self.model_server = self.model_server or _get_default_model_server_for_mlflow( |
| 665 | + deployment_flavor |
| 666 | + ) |
| 667 | + self.image_uri = self.image_uri or _select_container_for_mlflow_model( |
| 668 | + mlflow_model_src_path=self.model_path, |
| 669 | + deployment_flavor=deployment_flavor, |
| 670 | + region=self.sagemaker_session.boto_region_name, |
| 671 | + instance_type=self.instance_type, |
| 672 | + ) |
| 673 | + self.env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) |
| 674 | + self.dependencies.update({"requirements": mlflow_model_dependency_path}) |
| 675 | + |
580 | 676 | # Model Builder is a class to build the model for deployment.
|
581 | 677 | # It supports two modes of deployment
|
582 | 678 | # 1/ SageMaker Endpoint
|
@@ -620,6 +716,14 @@ def build( # pylint: disable=R0911
|
620 | 716 | self.serve_settings = self._get_serve_setting()
|
621 | 717 |
|
622 | 718 | self._is_custom_image_uri = self.image_uri is not None
|
| 719 | + self._is_mlflow_model = self._check_if_input_is_mlflow_model() |
| 720 | + if self._is_mlflow_model: |
| 721 | + logger.warning( |
| 722 | + "Support of MLflow format models is experimental and is not intended" |
| 723 | + " for production at this moment." |
| 724 | + ) |
| 725 | + self._initialize_for_mlflow() |
| 726 | + _validate_input_for_mlflow(self.model_server) |
623 | 727 |
|
624 | 728 | if isinstance(self.model, str):
|
625 | 729 | model_task = None
|
|
0 commit comments