Skip to content

Commit e5c22c1

Browse files
ananth102, grenmester, Jacky Lee, jiapinw, benieric
committed
feat(sagemaker-mlflow): New features for SageMaker MLflow (aws#4744)
* feat: add support for mlflow inputs (aws#1441) * feat: add support for mlflow inputs * fix: typo * fix: doc * fix: S3 regex * fix: refactor * fix: refactor typo * fix: pylint * fix: pylint * fix: black and pylint --------- Co-authored-by: Jacky Lee <[email protected]> * fix: lineage tracking bug (aws#1447) * fix: lineage bug * fix: lineage * fix: add validation for tracking ARN input with MLflow input type * fix: bug * fix: unit tests * fix: mock * fix: args --------- Co-authored-by: Jacky Lee <[email protected]> * [Fix] regex for RunId to handle empty artifact path and change mlflow plugin name (aws#1455) * [Fix] run id regex pattern such that empty artifact path is handled * Change mlflow plugin name as per legal team requirement * Update describe_mlflow_tracking_server call to align with api changes (aws#1466) * feat: (sagemaker-mlflow) Adding Presigned Url function to SDK (aws#1462) (aws#1477) * mlflow presigned url changes * addressing design feedback * test changes * change: mlflow plugin name (aws#1489) Co-authored-by: Jacky Lee <[email protected]> --------- Co-authored-by: Jacky Lee <[email protected]> Co-authored-by: Jacky Lee <[email protected]> Co-authored-by: jiapinw <[email protected]> Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 116118f commit e5c22c1

File tree

15 files changed

+598
-92
lines changed

15 files changed

+598
-92
lines changed

requirements/extras/test_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ nbformat>=5.9,<6
3737
accelerate>=0.24.1,<=0.27.0
3838
schema==0.7.5
3939
tensorflow>=2.1,<=2.16
40+
mlflow>=2.12.2,<2.13

src/sagemaker/mlflow/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
15+
"""This module contains code related to the Mlflow Tracking Server."""
16+
17+
from __future__ import absolute_import
18+
from typing import Optional, TYPE_CHECKING
19+
from sagemaker.apiutils import _utils
20+
21+
if TYPE_CHECKING:
22+
from sagemaker import Session
23+
24+
25+
def generate_mlflow_presigned_url(
    name: str,
    expires_in_seconds: Optional[int] = None,
    session_expiration_duration_in_seconds: Optional[int] = None,
    sagemaker_session: Optional["Session"] = None,
) -> str:
    """Generate a presigned URL to access the MLflow UI.

    Args:
        name (str): Name of the MLflow Tracking Server.
        expires_in_seconds (int): Expiration time, in seconds, of the first
            usage of the presigned URL.
        session_expiration_duration_in_seconds (int): Session duration of the
            presigned URL, in seconds, after the first use.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.

    Returns:
        (str): Authorized URL to access the MLflow UI.
    """
    # Fall back to the default session only when the caller supplied none.
    if not sagemaker_session:
        sagemaker_session = _utils.default_session()

    response = sagemaker_session.create_presigned_mlflow_tracking_server_url(
        name,
        expires_in_seconds,
        session_expiration_duration_in_seconds,
    )
    return response["AuthorizedUrl"]

src/sagemaker/serve/builder/model_builder.py

+131-27
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
"""Holds the ModelBuilder class and the ModelServer enum."""
1414
from __future__ import absolute_import
15+
16+
import importlib.util
1517
import uuid
1618
from typing import Any, Type, List, Dict, Optional, Union
1719
from dataclasses import dataclass, field
1820
import logging
1921
import os
22+
import re
2023

2124
from pathlib import Path
2225

@@ -44,12 +47,15 @@
4447
from sagemaker.predictor import Predictor
4548
from sagemaker.serve.model_format.mlflow.constants import (
4649
MLFLOW_MODEL_PATH,
50+
MLFLOW_TRACKING_ARN,
51+
MLFLOW_RUN_ID_REGEX,
52+
MLFLOW_REGISTRY_PATH_REGEX,
53+
MODEL_PACKAGE_ARN_REGEX,
4754
MLFLOW_METADATA_FILE,
4855
MLFLOW_PIP_DEPENDENCY_FILE,
4956
)
5057
from sagemaker.serve.model_format.mlflow.utils import (
5158
_get_default_model_server_for_mlflow,
52-
_mlflow_input_is_local_path,
5359
_download_s3_artifacts,
5460
_select_container_for_mlflow_model,
5561
_generate_mlflow_artifact_path,
@@ -278,8 +284,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
278284
default=None,
279285
metadata={
280286
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
281-
"`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
282-
"the Hub, Adding unsupported task types will throw an exception"
287+
"`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
288+
"models without task metadata in the Hub, Adding unsupported task types will "
289+
"throw an exception"
283290
},
284291
)
285292

@@ -504,6 +511,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
504511
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
505512
s3_upload_path=self.s3_upload_path,
506513
sagemaker_session=self.sagemaker_session,
514+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
507515
)
508516
return new_model_package
509517

@@ -574,6 +582,7 @@ def _model_builder_deploy_wrapper(
574582
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
575583
s3_upload_path=self.s3_upload_path,
576584
sagemaker_session=self.sagemaker_session,
585+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
577586
)
578587
return predictor
579588

@@ -627,11 +636,30 @@ def wrapper(*args, **kwargs):
627636

628637
return wrapper
629638

630-
def _check_if_input_is_mlflow_model(self) -> bool:
631-
"""Checks whether an MLmodel file exists in the given directory.
639+
def _handle_mlflow_input(self):
640+
"""Check whether an MLflow model is present and handle accordingly"""
641+
self._is_mlflow_model = self._has_mlflow_arguments()
642+
if not self._is_mlflow_model:
643+
return
644+
645+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
646+
artifact_path = self._get_artifact_path(mlflow_model_path)
647+
if not self._mlflow_metadata_exists(artifact_path):
648+
logger.info(
649+
"MLflow model metadata not detected in %s. ModelBuilder is not "
650+
"handling MLflow model input",
651+
mlflow_model_path,
652+
)
653+
return
654+
655+
self._initialize_for_mlflow(artifact_path)
656+
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
657+
658+
def _has_mlflow_arguments(self) -> bool:
659+
"""Check whether MLflow model arguments are present
632660
633661
Returns:
634-
bool: True if the MLmodel file exists, False otherwise.
662+
bool: True if MLflow arguments are present, False otherwise.
635663
"""
636664
if self.inference_spec or self.model:
637665
logger.info(
@@ -646,16 +674,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
646674
)
647675
return False
648676

649-
path = self.model_metadata.get(MLFLOW_MODEL_PATH)
650-
if not path:
677+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
678+
if not mlflow_model_path:
651679
logger.info(
652680
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
653681
"input",
654682
MLFLOW_MODEL_PATH,
655683
)
656684
return False
657685

658-
# Check for S3 path
686+
return True
687+
688+
def _get_artifact_path(self, mlflow_model_path: str) -> str:
    """Retrieves the model artifact location given the MLflow model input.

    Supported inputs are a run path (matching ``MLFLOW_RUN_ID_REGEX``), a
    registry path (matching ``MLFLOW_REGISTRY_PATH_REGEX``, with either a
    version segment or an ``@alias``), a SageMaker model package ARN, or any
    other string, which is returned unchanged.

    Args:
        mlflow_model_path (str): The MLflow model path input.

    Returns:
        str: The path to the model artifact.

    Raises:
        ValueError: If a run or registry path is given without a tracking
            server ARN in the model metadata, or if a registry path carries
            neither a version segment nor an alias.
        ImportError: If ``sagemaker_mlflow`` cannot be found.
    """
    run_id_match = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)
    needs_tracking_server = bool(run_id_match) or bool(
        re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path)
    )

    if not needs_tracking_server:
        if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
            model_package = self.sagemaker_session.sagemaker_client.describe_model_package(
                ModelPackageName=mlflow_model_path
            )
            return model_package["SourceUri"]
        # Anything else is passed through untouched.
        return mlflow_model_path

    tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
    if not tracking_arn:
        raise ValueError(
            "%s is not provided in ModelMetadata or through set_tracking_arn "
            "but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
        )

    if not importlib.util.find_spec("sagemaker_mlflow"):
        raise ImportError(
            "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
        )

    import mlflow

    mlflow.set_tracking_uri(tracking_arn)

    if run_id_match:
        # "runs:/<run_id>/<relative path>"
        _, run_id, relative_path = mlflow_model_path.split("/", 2)
        artifact_root = mlflow.get_run(run_id).info.artifact_uri
        if not artifact_root.endswith("/"):
            artifact_root += "/"
        return artifact_root + relative_path

    client = mlflow.MlflowClient()
    # A trailing slash guarantees the final split segment (the relative path)
    # exists even when the input stops at the version or alias.
    registry_path = mlflow_model_path
    if not registry_path.endswith("/"):
        registry_path += "/"

    if "@" in registry_path:
        # "models:/<name>@<alias>/<relative path>"
        _, name_and_alias, relative_path = registry_path.split("/", 2)
        model_name, model_alias = name_and_alias.split("@")
        version_info = client.get_model_version_by_alias(model_name, model_alias)
    else:
        # "models:/<name>/<version>/<relative path>"
        segments = registry_path.split("/", 3)
        if len(segments) != 4:
            # The registry regex also admits "models:/<name>"; report that
            # clearly instead of failing on tuple unpacking.
            raise ValueError(
                "Invalid MLflow registry path: %s. Expected models:/<name>/<version> "
                "or models:/<name>@<alias>." % mlflow_model_path
            )
        _, model_name, model_version, relative_path = segments
        version_info = client.get_model_version(model_name, model_version)

    source = version_info.source
    if not source.endswith("/"):
        source += "/"
    return source + relative_path
746+
747+
def _mlflow_metadata_exists(self, path: str) -> bool:
748+
"""Checks whether an MLmodel file exists in the given directory.
749+
750+
Returns:
751+
bool: True if the MLmodel file exists, False otherwise.
752+
"""
659753
if path.startswith("s3://"):
660754
s3_downloader = S3Downloader()
661755
if not path.endswith("/"):
@@ -667,17 +761,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
667761
file_path = os.path.join(path, MLFLOW_METADATA_FILE)
668762
return os.path.isfile(file_path)
669763

670-
def _initialize_for_mlflow(self) -> None:
671-
"""Initialize mlflow model artifacts, image uri and model server."""
672-
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
673-
if not _mlflow_input_is_local_path(mlflow_path):
674-
# TODO: extend to package arn, run id and etc.
675-
logger.info(
676-
"Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
677-
)
678-
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
764+
def _initialize_for_mlflow(self, artifact_path: str) -> None:
765+
"""Initialize mlflow model artifacts, image uri and model server.
766+
767+
Args:
768+
artifact_path (str): The path to the artifact store.
769+
"""
770+
if artifact_path.startswith("s3://"):
771+
_download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session)
772+
elif os.path.exists(artifact_path):
773+
_copy_directory_contents(artifact_path, self.model_path)
679774
else:
680-
_copy_directory_contents(mlflow_path, self.model_path)
775+
raise ValueError("Invalid path: %s" % artifact_path)
681776
mlflow_model_metadata_path = _generate_mlflow_artifact_path(
682777
self.model_path, MLFLOW_METADATA_FILE
683778
)
@@ -730,6 +825,8 @@ def build( # pylint: disable=R0911
730825
self.role_arn = role_arn
731826
self.sagemaker_session = sagemaker_session or Session()
732827

828+
self.sagemaker_session.settings._local_download_dir = self.model_path
829+
733830
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
734831
# decorate to_string() due to
735832
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -741,14 +838,8 @@ def build( # pylint: disable=R0911
741838
self.serve_settings = self._get_serve_setting()
742839

743840
self._is_custom_image_uri = self.image_uri is not None
744-
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
745-
if self._is_mlflow_model:
746-
logger.warning(
747-
"Support of MLflow format models is experimental and is not intended"
748-
" for production at this moment."
749-
)
750-
self._initialize_for_mlflow()
751-
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
841+
842+
self._handle_mlflow_input()
752843

753844
if isinstance(self.model, str):
754845
model_task = None
@@ -844,6 +935,19 @@ def validate(self, model_dir: str) -> Type[bool]:
844935

845936
return get_metadata(model_dir)
846937

938+
def set_tracking_arn(self, arn: str):
    """Set the MLflow tracking server ARN for this builder.

    Points ``mlflow`` at the given tracking URI and stores the ARN in the
    model metadata under ``MLFLOW_TRACKING_ARN``.

    Args:
        arn (str): The tracking server ARN.

    Raises:
        ImportError: If ``sagemaker_mlflow`` cannot be found.
    """
    # TODO: support native MLflow URIs
    if not importlib.util.find_spec("sagemaker_mlflow"):
        raise ImportError(
            "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
        )

    import mlflow

    mlflow.set_tracking_uri(arn)
    # The metadata field may still be unset (None) when the builder was
    # created without it; create the dict so the assignment cannot fail.
    if self.model_metadata is None:
        self.model_metadata = {}
    self.model_metadata[MLFLOW_TRACKING_ARN] = arn
950+
847951
def _hf_schema_builder_init(self, model_task: str):
848952
"""Initialize the schema builder for the given HF_TASK
849953

src/sagemaker/serve/model_format/mlflow/constants.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
MODEL_PACKAGE_ARN_REGEX = (
2323
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
2424
)
25-
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
26-
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
25+
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$"
26+
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$"
2727
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
28+
MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN"
2829
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
2930
MLFLOW_METADATA_FILE = "MLmodel"
3031
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"

src/sagemaker/serve/model_format/mlflow/utils.py

-22
Original file line numberDiff line numberDiff line change
@@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file(
227227
raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.")
228228

229229

230-
def _mlflow_input_is_local_path(model_path: str) -> bool:
231-
"""Checks if the given model_path is a local filesystem path.
232-
233-
Args:
234-
- model_path (str): The model path to check.
235-
236-
Returns:
237-
- bool: True if model_path is a local path, False otherwise.
238-
"""
239-
if model_path.startswith("s3://"):
240-
return False
241-
242-
if "/runs/" in model_path or model_path.startswith("runs:"):
243-
return False
244-
245-
# Check if it's not a local file path
246-
if not os.path.exists(model_path):
247-
return False
248-
249-
return True
250-
251-
252230
def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None:
253231
"""Downloads all artifacts from a specified S3 path to a local destination path.
254232

src/sagemaker/serve/utils/lineage_constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
LINEAGE_POLLER_INTERVAL_SECS = 15
1818
LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
19+
TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$"
20+
TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
1921
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
2022
MLFLOW_S3_PATH = "S3"
2123
MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"

0 commit comments

Comments
 (0)