12
12
# language governing permissions and limitations under the License.
13
13
"""Holds the ModelBuilder class and the ModelServer enum."""
14
14
from __future__ import absolute_import
15
+
16
+ import importlib .util
15
17
import uuid
16
18
from typing import Any , Type , List , Dict , Optional , Union
17
19
from dataclasses import dataclass , field
18
20
import logging
19
21
import os
22
+ import re
20
23
21
24
from pathlib import Path
22
25
44
47
from sagemaker .predictor import Predictor
45
48
from sagemaker .serve .model_format .mlflow .constants import (
46
49
MLFLOW_MODEL_PATH ,
50
+ MLFLOW_TRACKING_ARN ,
51
+ MLFLOW_RUN_ID_REGEX ,
52
+ MLFLOW_REGISTRY_PATH_REGEX ,
53
+ MODEL_PACKAGE_ARN_REGEX ,
47
54
MLFLOW_METADATA_FILE ,
48
55
MLFLOW_PIP_DEPENDENCY_FILE ,
49
56
)
50
57
from sagemaker .serve .model_format .mlflow .utils import (
51
58
_get_default_model_server_for_mlflow ,
52
- _mlflow_input_is_local_path ,
53
59
_download_s3_artifacts ,
54
60
_select_container_for_mlflow_model ,
55
61
_generate_mlflow_artifact_path ,
@@ -278,8 +284,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
278
284
default = None ,
279
285
metadata = {
280
286
"help" : "Define the model metadata to override, currently supports `HF_TASK`, "
281
- "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
282
- "the Hub, Adding unsupported task types will throw an exception"
287
+ "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
288
+ "models without task metadata in the Hub, Adding unsupported task types will "
289
+ "throw an exception"
283
290
},
284
291
)
285
292
@@ -504,6 +511,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
504
511
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
505
512
s3_upload_path = self .s3_upload_path ,
506
513
sagemaker_session = self .sagemaker_session ,
514
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
507
515
)
508
516
return new_model_package
509
517
@@ -574,6 +582,7 @@ def _model_builder_deploy_wrapper(
574
582
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
575
583
s3_upload_path = self .s3_upload_path ,
576
584
sagemaker_session = self .sagemaker_session ,
585
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
577
586
)
578
587
return predictor
579
588
@@ -627,11 +636,30 @@ def wrapper(*args, **kwargs):
627
636
628
637
return wrapper
629
638
630
- def _check_if_input_is_mlflow_model (self ) -> bool :
631
- """Checks whether an MLmodel file exists in the given directory.
639
+ def _handle_mlflow_input (self ):
640
+ """Check whether an MLflow model is present and handle accordingly"""
641
+ self ._is_mlflow_model = self ._has_mlflow_arguments ()
642
+ if not self ._is_mlflow_model :
643
+ return
644
+
645
+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
646
+ artifact_path = self ._get_artifact_path (mlflow_model_path )
647
+ if not self ._mlflow_metadata_exists (artifact_path ):
648
+ logger .info (
649
+ "MLflow model metadata not detected in %s. ModelBuilder is not "
650
+ "handling MLflow model input" ,
651
+ mlflow_model_path ,
652
+ )
653
+ return
654
+
655
+ self ._initialize_for_mlflow (artifact_path )
656
+ _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
657
+
658
+ def _has_mlflow_arguments (self ) -> bool :
659
+ """Check whether MLflow model arguments are present
632
660
633
661
Returns:
634
- bool: True if the MLmodel file exists , False otherwise.
662
+ bool: True if MLflow arguments are present , False otherwise.
635
663
"""
636
664
if self .inference_spec or self .model :
637
665
logger .info (
@@ -646,16 +674,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
646
674
)
647
675
return False
648
676
649
- path = self .model_metadata .get (MLFLOW_MODEL_PATH )
650
- if not path :
677
+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
678
+ if not mlflow_model_path :
651
679
logger .info (
652
680
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
653
681
"input" ,
654
682
MLFLOW_MODEL_PATH ,
655
683
)
656
684
return False
657
685
658
- # Check for S3 path
686
+ return True
687
+
688
+ def _get_artifact_path (self , mlflow_model_path : str ) -> str :
689
+ """Retrieves the model artifact location given the Mlflow model input.
690
+
691
+ Args:
692
+ mlflow_model_path (str): The MLflow model path input.
693
+
694
+ Returns:
695
+ str: The path to the model artifact.
696
+ """
697
+ if (is_run_id_type := re .match (MLFLOW_RUN_ID_REGEX , mlflow_model_path )) or re .match (
698
+ MLFLOW_REGISTRY_PATH_REGEX , mlflow_model_path
699
+ ):
700
+ mlflow_tracking_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN )
701
+ if not mlflow_tracking_arn :
702
+ raise ValueError (
703
+ "%s is not provided in ModelMetadata or through set_tracking_arn "
704
+ "but MLflow model path was provided." % MLFLOW_TRACKING_ARN ,
705
+ )
706
+
707
+ if not importlib .util .find_spec ("sagemaker_mlflow" ):
708
+ raise ImportError (
709
+ "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
710
+ )
711
+
712
+ import mlflow
713
+
714
+ mlflow .set_tracking_uri (mlflow_tracking_arn )
715
+ if is_run_id_type :
716
+ _ , run_id , model_path = mlflow_model_path .split ("/" , 2 )
717
+ artifact_uri = mlflow .get_run (run_id ).info .artifact_uri
718
+ if not artifact_uri .endswith ("/" ):
719
+ artifact_uri += "/"
720
+ return artifact_uri + model_path
721
+
722
+ mlflow_client = mlflow .MlflowClient ()
723
+ if not mlflow_model_path .endswith ("/" ):
724
+ mlflow_model_path += "/"
725
+
726
+ if "@" in mlflow_model_path :
727
+ _ , model_name_and_alias , artifact_uri = mlflow_model_path .split ("/" , 2 )
728
+ model_name , model_alias = model_name_and_alias .split ("@" )
729
+ model_metadata = mlflow_client .get_model_version_by_alias (model_name , model_alias )
730
+ else :
731
+ _ , model_name , model_version , artifact_uri = mlflow_model_path .split ("/" , 3 )
732
+ model_metadata = mlflow_client .get_model_version (model_name , model_version )
733
+
734
+ source = model_metadata .source
735
+ if not source .endswith ("/" ):
736
+ source += "/"
737
+ return source + artifact_uri
738
+
739
+ if re .match (MODEL_PACKAGE_ARN_REGEX , mlflow_model_path ):
740
+ model_package = self .sagemaker_session .sagemaker_client .describe_model_package (
741
+ ModelPackageName = mlflow_model_path
742
+ )
743
+ return model_package ["SourceUri" ]
744
+
745
+ return mlflow_model_path
746
+
747
+ def _mlflow_metadata_exists (self , path : str ) -> bool :
748
+ """Checks whether an MLmodel file exists in the given directory.
749
+
750
+ Returns:
751
+ bool: True if the MLmodel file exists, False otherwise.
752
+ """
659
753
if path .startswith ("s3://" ):
660
754
s3_downloader = S3Downloader ()
661
755
if not path .endswith ("/" ):
@@ -667,17 +761,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
667
761
file_path = os .path .join (path , MLFLOW_METADATA_FILE )
668
762
return os .path .isfile (file_path )
669
763
670
- def _initialize_for_mlflow (self ) -> None :
671
- """Initialize mlflow model artifacts, image uri and model server."""
672
- mlflow_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
673
- if not _mlflow_input_is_local_path (mlflow_path ):
674
- # TODO: extend to package arn, run id and etc.
675
- logger .info (
676
- "Start downloading model artifacts from %s to %s" , mlflow_path , self .model_path
677
- )
678
- _download_s3_artifacts (mlflow_path , self .model_path , self .sagemaker_session )
764
+ def _initialize_for_mlflow (self , artifact_path : str ) -> None :
765
+ """Initialize mlflow model artifacts, image uri and model server.
766
+
767
+ Args:
768
+ artifact_path (str): The path to the artifact store.
769
+ """
770
+ if artifact_path .startswith ("s3://" ):
771
+ _download_s3_artifacts (artifact_path , self .model_path , self .sagemaker_session )
772
+ elif os .path .exists (artifact_path ):
773
+ _copy_directory_contents (artifact_path , self .model_path )
679
774
else :
680
- _copy_directory_contents ( mlflow_path , self . model_path )
775
+ raise ValueError ( "Invalid path: %s" % artifact_path )
681
776
mlflow_model_metadata_path = _generate_mlflow_artifact_path (
682
777
self .model_path , MLFLOW_METADATA_FILE
683
778
)
@@ -730,6 +825,8 @@ def build( # pylint: disable=R0911
730
825
self .role_arn = role_arn
731
826
self .sagemaker_session = sagemaker_session or Session ()
732
827
828
+ self .sagemaker_session .settings ._local_download_dir = self .model_path
829
+
733
830
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
734
831
# decorate to_string() due to
735
832
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -741,14 +838,8 @@ def build( # pylint: disable=R0911
741
838
self .serve_settings = self ._get_serve_setting ()
742
839
743
840
self ._is_custom_image_uri = self .image_uri is not None
744
- self ._is_mlflow_model = self ._check_if_input_is_mlflow_model ()
745
- if self ._is_mlflow_model :
746
- logger .warning (
747
- "Support of MLflow format models is experimental and is not intended"
748
- " for production at this moment."
749
- )
750
- self ._initialize_for_mlflow ()
751
- _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
841
+
842
+ self ._handle_mlflow_input ()
752
843
753
844
if isinstance (self .model , str ):
754
845
model_task = None
@@ -844,6 +935,19 @@ def validate(self, model_dir: str) -> Type[bool]:
844
935
845
936
return get_metadata (model_dir )
846
937
938
+ def set_tracking_arn (self , arn : str ):
939
+ """Set tracking server ARN"""
940
+ # TODO: support native MLflow URIs
941
+ if importlib .util .find_spec ("sagemaker_mlflow" ):
942
+ import mlflow
943
+
944
+ mlflow .set_tracking_uri (arn )
945
+ self .model_metadata [MLFLOW_TRACKING_ARN ] = arn
946
+ else :
947
+ raise ImportError (
948
+ "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
949
+ )
950
+
847
951
def _hf_schema_builder_init (self , model_task : str ):
848
952
"""Initialize the schema builder for the given HF_TASK
849
953
0 commit comments