
Commit a5c6229

Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (#4662)
* Initial commit for tensorflow_serving support of MLflow
* Add integ tests for mlflow tf_serving
* fix style issues
* remove unused attributes from tf builder
* Add deep ping for tf_serving local mode
* Initial commit for lineage impl
* Add integ tests and uts
* fix local mode for tf_serving
* Allow lineage tracking only in sagemaker endpoint mode
* fix regex pattern
* fix style issues
* fix regex pattern and hard coded py version in ut
* fix missing session
* Resolve pr comments and fix regex for mlflow registry and ids
1 parent d549e7d commit a5c6229

33 files changed: +2,275 −33 lines

requirements/extras/test_requirements.txt (+1)

@@ -36,3 +36,4 @@ onnx>=1.15.0
 nbformat>=5.9,<6
 accelerate>=0.24.1,<=0.27.0
 schema==0.7.5
+tensorflow>=2.1,<=2.16

src/sagemaker/serve/builder/model_builder.py (+23 −4)

@@ -29,6 +29,7 @@
 from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
 from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
@@ -59,6 +60,7 @@
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
@@ -89,12 +91,13 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.TENSORFLOW_SERVING,
 }


-# pylint: disable=attribute-defined-outside-init, disable=E1101
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
     """Class that builds a deployable model.

     Args:
@@ -493,6 +496,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
         self.pysdk_model.model_package_arn = new_model_package.model_package_arn
         new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
         self.model_package = new_model_package
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
         return new_model_package

     def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
@@ -551,12 +560,19 @@ def _model_builder_deploy_wrapper(

         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
-        return self._original_deploy(
+        predictor = self._original_deploy(
             *args,
             instance_type=instance_type,
             initial_instance_count=initial_instance_count,
             **kwargs,
         )
+        if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+            _maintain_lineage_tracking_for_mlflow_model(
+                mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+                s3_upload_path=self.s3_upload_path,
+                sagemaker_session=self.sagemaker_session,
+            )
+        return predictor

     def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         """Mode overwritten by customer during model.deploy()"""
@@ -728,7 +744,7 @@ def build(  # pylint: disable=R0911
                 " for production at this moment."
             )
             self._initialize_for_mlflow()
-            _validate_input_for_mlflow(self.model_server)
+            _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

         if isinstance(self.model, str):
             model_task = None
@@ -767,6 +783,9 @@ def build(  # pylint: disable=R0911
         if self.model_server == ModelServer.TRITON:
             return self._build_for_triton()

+        if self.model_server == ModelServer.TENSORFLOW_SERVING:
+            return self._build_for_tensorflow_serving()
+
         raise ValueError("%s model server is not supported" % self.model_server)

     def save(
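Taken together, the model_builder.py changes let an MLflow model be built for TensorFlow Serving and deployed to a SageMaker endpoint, with lineage tracking running after a successful deploy or register. A minimal usage sketch, assuming placeholder role/bucket values and illustrative sample shapes (none of these literals come from this commit):

import numpy as np
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer

builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    model_server=ModelServer.TENSORFLOW_SERVING,
    schema_builder=SchemaBuilder(
        sample_input=np.zeros((1, 28, 28), dtype="float32"),   # illustrative shapes
        sample_output=np.zeros((1, 10), dtype="float32"),
    ),
    model_metadata={"MLFLOW_MODEL_PATH": "s3://my-bucket/my-mlflow-model/"},  # placeholder path
    role_arn="arn:aws:iam::111122223333:role/SageMakerRole",                  # placeholder role
)

model = builder.build()  # dispatches to _build_for_tensorflow_serving()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.c6i.xlarge")
# _model_builder_deploy_wrapper then runs the lineage tracking shown above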
src/sagemaker/serve/builder/tf_serving_builder.py (new file, +129)

@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds mixin logic to support deployment of Model ID"""
+from __future__ import absolute_import
+import logging
+import os
+from pathlib import Path
+from abc import ABC, abstractmethod
+
+from sagemaker import Session
+from sagemaker.serve.detector.pickler import save_pkl
+from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
+from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor
+
+logger = logging.getLogger(__name__)
+
+_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
+_CODE_FOLDER = "code"
+
+
+# pylint: disable=attribute-defined-outside-init, disable=E1101
+class TensorflowServing(ABC):
+    """TensorflowServing build logic for ModelBuilder()"""
+
+    def __init__(self):
+        self.model = None
+        self.serve_settings = None
+        self.sagemaker_session = None
+        self.model_path = None
+        self.dependencies = None
+        self.modes = None
+        self.mode = None
+        self.model_server = None
+        self.image_uri = None
+        self._is_custom_image_uri = False
+        self.image_config = None
+        self.vpc_config = None
+        self._original_deploy = None
+        self.secret_key = None
+        self.engine = None
+        self.pysdk_model = None
+        self.schema_builder = None
+        self.env_vars = None
+
+    @abstractmethod
+    def _prepare_for_mode(self):
+        """Prepare model artifacts based on mode."""
+
+    @abstractmethod
+    def _get_client_translators(self):
+        """Set up client marshaller based on schema builder."""
+
+    def _save_schema_builder(self):
+        """Save schema builder for tensorflow serving."""
+        if not os.path.exists(self.model_path):
+            os.makedirs(self.model_path)
+
+        code_path = Path(self.model_path).joinpath("code")
+        save_pkl(code_path, self.schema_builder)
+
+    def _get_tensorflow_predictor(
+        self, endpoint_name: str, sagemaker_session: Session
+    ) -> TensorFlowPredictor:
+        """Creates a TensorFlowPredictor object"""
+        serializer, deserializer = self._get_client_translators()
+
+        return TensorFlowPredictor(
+            endpoint_name=endpoint_name,
+            sagemaker_session=sagemaker_session,
+            serializer=serializer,
+            deserializer=deserializer,
+        )
+
+    def _validate_for_tensorflow_serving(self):
+        """Validate for tensorflow serving"""
+        if not getattr(self, "_is_mlflow_model", False):
+            raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")
+
+    def _create_tensorflow_model(self):
+        """Creates a TensorFlow model object"""
+        self.pysdk_model = TensorFlowModel(
+            image_uri=self.image_uri,
+            image_config=self.image_config,
+            vpc_config=self.vpc_config,
+            model_data=self.s3_upload_path,
+            role=self.serve_settings.role_arn,
+            env=self.env_vars,
+            sagemaker_session=self.sagemaker_session,
+            predictor_cls=self._get_tensorflow_predictor,
+        )
+
+        self.pysdk_model.mode = self.mode
+        self.pysdk_model.modes = self.modes
+        self.pysdk_model.serve_settings = self.serve_settings
+
+        self._original_deploy = self.pysdk_model.deploy
+        self.pysdk_model.deploy = self._model_builder_deploy_wrapper
+        self._original_register = self.pysdk_model.register
+        self.pysdk_model.register = self._model_builder_register_wrapper
+        self.model_package = None
+        return self.pysdk_model
+
+    def _build_for_tensorflow_serving(self):
+        """Build the model for Tensorflow Serving"""
+        self._validate_for_tensorflow_serving()
+        self._save_schema_builder()
+
+        if not self.image_uri:
+            raise ValueError("image_uri is not set for tensorflow serving")
+
+        self.secret_key = prepare_for_tf_serving(
+            model_path=self.model_path,
+            shared_libs=self.shared_libs,
+            dependencies=self.dependencies,
+        )
+
+        self._prepare_for_mode()
+
+        return self._create_tensorflow_model()
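Note how _create_tensorflow_model uses a wrap-and-restore pattern: the real deploy and register methods are stashed on the builder and replaced with wrappers so post-deploy logic (the lineage tracking above) can run. A stripped-down sketch of that pattern with hypothetical names, not SDK code:

class _DeployHookMixin:
    """Illustrative only: shows how the builder swaps deploy for a wrapper."""

    def _attach_hooks(self, pysdk_model):
        self._original_deploy = pysdk_model.deploy   # keep a handle to the real method
        pysdk_model.deploy = self._deploy_wrapper    # swap in the wrapper

    def _deploy_wrapper(self, *args, **kwargs):
        predictor = self._original_deploy(*args, **kwargs)  # real deployment
        # a post-deploy hook (e.g. lineage tracking) would run here
        return predictor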

src/sagemaker/serve/mode/local_container_mode.py (+16 −1)

@@ -11,6 +11,7 @@
 import docker

 from sagemaker.base_predictor import PredictorBase
+from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.utils.logging_agent import pull_logs
@@ -34,7 +35,12 @@


 class LocalContainerMode(
-    LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
+    LocalTorchServe,
+    LocalDJLServing,
+    LocalTritonServer,
+    LocalTgiServing,
+    LocalMultiModelServer,
+    LocalTensorflowServing,
 ):
     """A class that holds methods to deploy model to a container in local environment"""

@@ -141,6 +147,15 @@ def create_server(
                 env_vars=env_vars if env_vars else self.env_vars,
             )
             self._ping_container = self._multi_model_server_deep_ping
+        elif self.model_server == ModelServer.TENSORFLOW_SERVING:
+            self._start_tensorflow_serving(
+                client=self.client,
+                image=image,
+                model_path=model_path if model_path else self.model_path,
+                secret_key=secret_key,
+                env_vars=env_vars if env_vars else self.env_vars,
+            )
+            self._ping_container = self._tensorflow_serving_deep_ping

         # allow some time for container to be ready
         time.sleep(10)
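The deep ping wired up here goes beyond a TCP check: it exercises the running server before local deployment is declared healthy. As a rough illustration of the underlying idea only (not this commit's code; the default REST port 8501 and model name "model" are assumptions), TF Serving's standard model-status endpoint can be polled like this:

import requests

def tf_serving_is_ready(base_url: str = "http://localhost:8501") -> bool:
    """Rough illustration: query TF Serving's model-status REST endpoint."""
    try:
        resp = requests.get(f"{base_url}/v1/models/model", timeout=5)
        resp.raise_for_status()
        states = resp.json().get("model_version_status", [])
        return any(s.get("state") == "AVAILABLE" for s in states)
    except requests.RequestException:
        return False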

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py (+11)

@@ -6,6 +6,7 @@
 import logging
 from typing import Type

+from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
 from sagemaker.session import Session
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -24,6 +25,7 @@ class SageMakerEndpointMode(
     SageMakerDjlServing,
     SageMakerTgiServing,
     SageMakerMultiModelServer,
+    SageMakerTensorflowServing,
 ):
     """Holds the required method to deploy a model to a SageMaker Endpoint"""

@@ -107,4 +109,13 @@ def prepare(
                 image=image,
             )

+        if self.model_server == ModelServer.TENSORFLOW_SERVING:
+            return self._upload_tensorflow_serving_artifacts(
+                model_path=model_path,
+                sagemaker_session=sagemaker_session,
+                secret_key=secret_key,
+                s3_model_data_url=s3_model_data_url,
+                image=image,
+            )
+
        raise ValueError("%s model server is not supported" % self.model_server)
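Both modes dispatch on ModelServer with if-chains that fall through to a ValueError, and the new branch mirrors its siblings. The same contract sketched as a table-driven equivalent, as a hypothetical refactor rather than SDK code:

from sagemaker.serve.utils.types import ModelServer

def dispatch_prepare(model_server: ModelServer, handlers: dict):
    """Illustrative only: look up the per-server prepare handler or fail loudly."""
    try:
        handler = handlers[model_server]
    except KeyError:
        raise ValueError("%s model server is not supported" % model_server) from None
    return handler()

# usage sketch (handler name is hypothetical):
# dispatch_prepare(ModelServer.TENSORFLOW_SERVING,
#                  {ModelServer.TENSORFLOW_SERVING: upload_tf_serving_artifacts})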

src/sagemaker/serve/model_format/mlflow/constants.py (+15 −5)

@@ -19,6 +19,12 @@
     "py39": "1.13.1",
     "py310": "2.2.0",
 }
+MODEL_PACAKGE_ARN_REGEX = (
+    r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
+)
+MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
+MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
+S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
 MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
 MLFLOW_METADATA_FILE = "MLmodel"
 MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
@@ -34,8 +40,12 @@
     "spark": "pyspark",
     "onnx": "onnxruntime",
 }
-FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [  # will extend to keras and tf
-    "sklearn",
-    "pytorch",
-    "xgboost",
-]
+TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"
+FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = {
+    "sklearn": "sklearn",
+    "pytorch": "pytorch",
+    "xgboost": "xgboost",
+    "tensorflow": "tensorflow",
+    "keras": "tensorflow",
+}
+FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"]
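These patterns tell MLflow run IDs, registry paths, and S3 URLs apart when resolving MLFLOW_MODEL_PATH. A quick sanity check against made-up inputs (the sample IDs are illustrative, not from the commit):

import re

MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"

assert re.match(MLFLOW_RUN_ID_REGEX, "runs:/a1b2c3d4e5/model")      # run ID plus artifact path
assert re.match(MLFLOW_REGISTRY_PATH_REGEX, "models:/my-model/3")   # registered model, version 3
assert re.match(MLFLOW_REGISTRY_PATH_REGEX, "models:/my-model")     # version suffix is optional
assert re.match(S3_PATH_REGEX, "s3://my-bucket/path/to/model/")
assert not re.match(MLFLOW_RUN_ID_REGEX, "runs:a1b2c3")             # missing "/" after the scheme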
