|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Holds mixin logic to support deployment of Model ID""" |
| 14 | +from __future__ import absolute_import |
| 15 | +import logging |
| 16 | +import os |
| 17 | +from pathlib import Path |
| 18 | +from abc import ABC, abstractmethod |
| 19 | + |
| 20 | +from sagemaker import Session |
| 21 | +from sagemaker.serve.detector.pickler import save_pkl |
| 22 | +from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving |
| 23 | +from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor |
| 24 | + |
| 25 | +logger = logging.getLogger(__name__) |
| 26 | + |
| 27 | +_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py" |
| 28 | +_CODE_FOLDER = "code" |
| 29 | + |
| 30 | + |
# pylint: disable=attribute-defined-outside-init, disable=E1101
class TensorflowServing(ABC):
    """TensorflowServing build logic for ModelBuilder().

    Mixin that knows how to package a (currently mlflow-only) model for the
    SageMaker TensorFlow Serving model server and produce a ``TensorFlowModel``
    wired to a ``TensorFlowPredictor``. Entry point: ``_build_for_tensorflow_serving``.
    Several attributes (e.g. ``shared_libs``, ``s3_upload_path``,
    ``_model_builder_deploy_wrapper``) are provided by sibling ModelBuilder
    mixins and are not initialized here.
    """

    def __init__(self):
        # State shared with the rest of the ModelBuilder mixin hierarchy;
        # ModelBuilder populates these before _build_for_tensorflow_serving()
        # is invoked.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.secret_key = None
        self.engine = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare model artifacts based on mode."""

    @abstractmethod
    def _get_client_translators(self):
        """Set up client marshaller based on schema builder."""

    def _save_schema_builder(self):
        """Pickle the schema builder into ``<model_path>/code``.

        Raises whatever ``save_pkl`` raises on serialization failure.
        """
        # exist_ok=True avoids the check-then-create race of the previous
        # os.path.exists() + os.makedirs() pair.
        os.makedirs(self.model_path, exist_ok=True)

        # Use the module-level constant so the folder name stays consistent
        # with the rest of this model server's code.
        code_path = Path(self.model_path).joinpath(_CODE_FOLDER)
        save_pkl(code_path, self.schema_builder)

    def _get_tensorflow_predictor(
        self, endpoint_name: str, sagemaker_session: Session
    ) -> TensorFlowPredictor:
        """Create a TensorFlowPredictor bound to *endpoint_name*.

        The serializer/deserializer pair comes from the schema builder via
        ``_get_client_translators()``, so client-side marshalling matches the
        schema the model was built with.
        """
        serializer, deserializer = self._get_client_translators()

        return TensorFlowPredictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )

    def _validate_for_tensorflow_serving(self):
        """Validate for tensorflow serving.

        Raises:
            ValueError: If the model being built is not an mlflow model
                (``_is_mlflow_model`` is set by the mlflow detection path
                elsewhere in ModelBuilder — TODO confirm).
        """
        if not getattr(self, "_is_mlflow_model", False):
            raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")

    def _create_tensorflow_model(self):
        """Create and return a ``TensorFlowModel`` wrapped for ModelBuilder.

        Monkey-patches ``deploy``/``register`` on the model with the
        ModelBuilder wrappers (defined in a sibling mixin) while keeping the
        originals so the wrappers can delegate to them.
        """
        # NOTE(review): self.s3_upload_path is assumed to be set by
        # _prepare_for_mode() before this is called — verify against callers.
        self.pysdk_model = TensorFlowModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            model_data=self.s3_upload_path,
            role=self.serve_settings.role_arn,
            env=self.env_vars,
            sagemaker_session=self.sagemaker_session,
            predictor_cls=self._get_tensorflow_predictor,
        )

        # Carry ModelBuilder bookkeeping onto the SDK model object.
        self.pysdk_model.mode = self.mode
        self.pysdk_model.modes = self.modes
        self.pysdk_model.serve_settings = self.serve_settings

        # Intercept deploy/register so ModelBuilder can inject its own
        # pre/post logic; originals are kept for delegation.
        self._original_deploy = self.pysdk_model.deploy
        self.pysdk_model.deploy = self._model_builder_deploy_wrapper
        self._original_register = self.pysdk_model.register
        self.pysdk_model.register = self._model_builder_register_wrapper
        self.model_package = None
        return self.pysdk_model

    def _build_for_tensorflow_serving(self):
        """Build the model for Tensorflow Serving.

        Orchestrates: validation -> schema persistence -> artifact prep ->
        mode-specific preparation -> SDK model creation.

        Returns:
            TensorFlowModel: The prepared, deploy-wrapped model object.

        Raises:
            ValueError: If the model is not an mlflow model, or if
                ``image_uri`` was not supplied.
        """
        self._validate_for_tensorflow_serving()
        self._save_schema_builder()

        if not self.image_uri:
            raise ValueError("image_uri is not set for tensorflow serving")

        # prepare_for_tf_serving lays out model artifacts/dependencies and
        # returns the HMAC secret key used to sign the pickled payloads.
        self.secret_key = prepare_for_tf_serving(
            model_path=self.model_path,
            shared_libs=self.shared_libs,
            dependencies=self.dependencies,
        )

        self._prepare_for_mode()

        return self._create_tensorflow_model()
0 commit comments