Skip to content

Commit 7ba88ce

Browse files
jiapinw and Jacky Lee
authored and committed
[Feat] Support MLflow Model Format Through ModelBuilder (aws#4564)
* Initial commit for mlflow support in ModelBuilder Fix some formating issues Fix style issues and typo Fix s3 downloading for mlflow Fix logging in inference script for torchserve Add validation for mlflow fix unused param in s3 downloader * Add mlflow tensorflow flavor load in inference script [Fix] enhance error handling for mlflow model support * Add unit tests for utils fix black format * mlflow integration integ tests initial commit Fix schema builder and constant import Fix typo Add integ tests for mlflow xgboost flavor fix style issues for integ tests adding ut fix import fix import add missed mock add mock for path.isfile add mock for open try another way of patching test patching fix patching fix patching fix patching fix patching fix patching fix patching fix patching fix patching add ut for s3 input add negative test case fix indent increase test coverage fix patching module fix session patch fix test input debug remov debug messages fix ut prepare for initial pr fix doc8 fix ut name collision fix integ tests naming fix ut run in PR Add ignore path to doc8 * Move mlflow inputs under model_metadata instead fix ut failure fix assertion * Fix local path input for mlflow fix ut fix ut normalize path befor copying over fix for ut fix black format * mark local container integ test as flaky as they all bind to the same host * marg mlflow xgboost integ local container test as flaky * skip copying files if src and dst are the same * resolve pr comments * keep resolving pr comments * update docstrings so that mlflow support is mark as beta * fix pylint failures --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent e03daf7 commit 7ba88ce

File tree

24 files changed

+1897
-7
lines changed

24 files changed

+1897
-7
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 107 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@
2020

2121
from pathlib import Path
2222

23+
from sagemaker.s3 import S3Downloader
24+
2325
from sagemaker import Session
2426
from sagemaker.model import Model
2527
from sagemaker.base_predictor import PredictorBase
@@ -37,6 +39,22 @@
3739
from sagemaker.serve.builder.jumpstart_builder import JumpStart
3840
from sagemaker.serve.builder.transformers_builder import Transformers
3941
from sagemaker.predictor import Predictor
42+
from sagemaker.serve.model_format.mlflow.constants import (
43+
MLFLOW_MODEL_PATH,
44+
MLFLOW_METADATA_FILE,
45+
MLFLOW_PIP_DEPENDENCY_FILE,
46+
)
47+
from sagemaker.serve.model_format.mlflow.utils import (
48+
_get_default_model_server_for_mlflow,
49+
_mlflow_input_is_local_path,
50+
_download_s3_artifacts,
51+
_select_container_for_mlflow_model,
52+
_generate_mlflow_artifact_path,
53+
_get_all_flavor_metadata,
54+
_get_deployment_flavor,
55+
_validate_input_for_mlflow,
56+
_copy_directory_contents,
57+
)
4058
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4159
from sagemaker.serve.spec.inference_spec import InferenceSpec
4260
from sagemaker.serve.utils import task
@@ -145,8 +163,11 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
145163
to the model server). Possible values for this argument are
146164
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
147165
``TRITON``, and``TGI``.
148-
model_metadata (Optional[Dict[str, Any]): Dictionary used to override the HuggingFace
149-
model metadata. Currently ``HF_TASK`` is overridable.
166+
model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
167+
Currently, ``HF_TASK`` is overridable for HuggingFace model. ``MLFLOW_MODEL_PATH``
168+
is available for providing local path or s3 path to MLflow artifacts. However,
169+
``MLFLOW_MODEL_PATH`` is experimental and is not intended for production use
170+
at this moment.
150171
"""
151172

152173
model_path: Optional[str] = field(
@@ -245,7 +266,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
245266
)
246267
model_metadata: Optional[Dict[str, Any]] = field(
247268
default=None,
248-
metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"},
269+
metadata={
270+
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
271+
"`MLFLOW_MODEL_PATH`"
272+
},
249273
)
250274

251275
def _build_validations(self):
@@ -297,6 +321,8 @@ def _save_model_inference_spec(self):
297321
save_pkl(code_path, (self._framework, self.schema_builder))
298322
else:
299323
save_pkl(code_path, (self.model, self.schema_builder))
324+
elif self._is_mlflow_model:
325+
save_pkl(code_path, self.schema_builder)
300326
else:
301327
raise ValueError("Cannot detect required model or inference spec")
302328

@@ -577,6 +603,76 @@ def wrapper(*args, **kwargs):
577603

578604
return wrapper
579605

606+
def _check_if_input_is_mlflow_model(self) -> bool:
607+
"""Checks whether an MLmodel file exists in the given directory.
608+
609+
Returns:
610+
bool: True if the MLmodel file exists, False otherwise.
611+
"""
612+
if self.inference_spec or self.model:
613+
logger.info(
614+
"Either inference spec or model is provided. "
615+
"ModelBuilder is not handling MLflow model input"
616+
)
617+
return False
618+
619+
if not self.model_metadata:
620+
logger.info(
621+
"No ModelMetadata provided. ModelBuilder is not handling MLflow model input"
622+
)
623+
return False
624+
625+
path = self.model_metadata.get(MLFLOW_MODEL_PATH)
626+
if not path:
627+
logger.info(
628+
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
629+
"input",
630+
MLFLOW_MODEL_PATH,
631+
)
632+
return False
633+
634+
# Check for S3 path
635+
if path.startswith("s3://"):
636+
s3_downloader = S3Downloader()
637+
if not path.endswith("/"):
638+
path += "/"
639+
s3_uri_to_mlmodel_file = f"{path}{MLFLOW_METADATA_FILE}"
640+
response = s3_downloader.list(s3_uri_to_mlmodel_file, self.sagemaker_session)
641+
return len(response) > 0
642+
643+
file_path = os.path.join(path, MLFLOW_METADATA_FILE)
644+
return os.path.isfile(file_path)
645+
646+
def _initialize_for_mlflow(self) -> None:
647+
"""Initialize mlflow model artifacts, image uri and model server."""
648+
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
649+
if not _mlflow_input_is_local_path(mlflow_path):
650+
# TODO: extend to package arn, run id and etc.
651+
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
652+
else:
653+
_copy_directory_contents(mlflow_path, self.model_path)
654+
mlflow_model_metadata_path = _generate_mlflow_artifact_path(
655+
self.model_path, MLFLOW_METADATA_FILE
656+
)
657+
# TODO: add validation on MLmodel file
658+
mlflow_model_dependency_path = _generate_mlflow_artifact_path(
659+
self.model_path, MLFLOW_PIP_DEPENDENCY_FILE
660+
)
661+
flavor_metadata = _get_all_flavor_metadata(mlflow_model_metadata_path)
662+
deployment_flavor = _get_deployment_flavor(flavor_metadata)
663+
664+
self.model_server = self.model_server or _get_default_model_server_for_mlflow(
665+
deployment_flavor
666+
)
667+
self.image_uri = self.image_uri or _select_container_for_mlflow_model(
668+
mlflow_model_src_path=self.model_path,
669+
deployment_flavor=deployment_flavor,
670+
region=self.sagemaker_session.boto_region_name,
671+
instance_type=self.instance_type,
672+
)
673+
self.env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"})
674+
self.dependencies.update({"requirements": mlflow_model_dependency_path})
675+
580676
# Model Builder is a class to build the model for deployment.
581677
# It supports two modes of deployment
582678
# 1/ SageMaker Endpoint
@@ -620,6 +716,14 @@ def build( # pylint: disable=R0911
620716
self.serve_settings = self._get_serve_setting()
621717

622718
self._is_custom_image_uri = self.image_uri is not None
719+
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
720+
if self._is_mlflow_model:
721+
logger.warning(
722+
"Support of MLflow format models is experimental and is not intended"
723+
" for production at this moment."
724+
)
725+
self._initialize_for_mlflow()
726+
_validate_input_for_mlflow(self.model_server)
623727

624728
if isinstance(self.model, str):
625729
model_task = None
Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds constants used for interpreting MLflow models."""
14+
from __future__ import absolute_import
15+
16+
DEFAULT_FW_USED_FOR_DEFAULT_IMAGE = "pytorch"
17+
DEFAULT_PYTORCH_VERSION = {
18+
"py38": "1.12.1",
19+
"py39": "1.13.1",
20+
"py310": "2.2.0",
21+
}
22+
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
23+
MLFLOW_METADATA_FILE = "MLmodel"
24+
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
25+
MLFLOW_PYFUNC = "python_function"
26+
MLFLOW_FLAVOR_TO_PYTHON_PACKAGE_MAP = {
27+
"sklearn": "scikit-learn",
28+
"pytorch": "torch",
29+
"tensorflow": "tensorflow",
30+
"keras": "tensorflow",
31+
"xgboost": "xgboost",
32+
"lightgbm": "lightgbm",
33+
"h2o": "h2o",
34+
"spark": "pyspark",
35+
"onnx": "onnxruntime",
36+
}
37+
FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [ # will extend to keras and tf
38+
"sklearn",
39+
"pytorch",
40+
"xgboost",
41+
]

0 commit comments

Comments
 (0)