"""Placeholder docstring"""
from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Type
from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.model_server.triton.server import SageMakerTritonServer
from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer
from sagemaker.serve.model_server.smd.server import SageMakerSmdServer
logger = logging.getLogger(__name__)

# pylint: disable=R0901
class SageMakerEndpointMode(
    SageMakerTorchServe,
    SageMakerTritonServer,
    SageMakerDjlServing,
    SageMakerTgiServing,
    SageMakerMultiModelServer,
    SageMakerTensorflowServing,
    SageMakerSmdServer,
):
    """Holds the methods required to deploy a model to a SageMaker endpoint."""

    def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServer):
        super().__init__()
        # Explicitly run __init__ for the classes that follow
        # SageMakerTritonServer in the MRO as well.
        # pylint: disable=bad-super-call
        super(SageMakerTritonServer, self).__init__()

        self.inference_spec = inference_spec
        self.model_server = model_server
        self._tei_serving = SageMakerTeiServing()

    def load(self, model_path: str):
        """Load the model from ``<model_path>/model`` via the inference spec's ``model_fn``."""
        path = Path(model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        model_dir = path.joinpath("model")
        return self.inference_spec.model_fn(str(model_dir))
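
    # A minimal sketch of the contract ``load`` relies on (hypothetical;
    # ``MyInferenceSpec`` and the torch model file are illustrative, not part
    # of this module): the inference spec must expose a ``model_fn`` that
    # receives ``<model_path>/model`` and returns the loaded model object.
    #
    #     class MyInferenceSpec(InferenceSpec):
    #         def model_fn(self, model_dir: str):
    #             return torch.load(f"{model_dir}/model.pt")
    #         ...  # plus any other methods InferenceSpec requires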

    def prepare(
        self,
        model_path: str,
        secret_key: str,
        s3_model_data_url: str = None,
        sagemaker_session: Session = None,
        image: str = None,
        jumpstart: bool = False,
        should_upload_artifacts: bool = False,
    ):
        """Prepare the model artifacts for the configured model server.

        Dispatches to the ``_upload_*_artifacts`` helper that matches
        ``self.model_server`` and returns that helper's result.
        """
        try:
            sagemaker_session = sagemaker_session or Session()
        except Exception as e:
            raise RuntimeError(
                "Failed to set up a default SageMaker session. Please allow a default "
                "session to be created or supply `sagemaker_session` to @serve.invoke."
            ) from e

        upload_artifacts = None
        if self.model_server == ModelServer.TORCHSERVE:
            upload_artifacts = self._upload_torchserve_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                secret_key=secret_key,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=True,
            )
        elif self.model_server == ModelServer.TRITON:
            upload_artifacts = self._upload_triton_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                secret_key=secret_key,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=True,
            )
        elif self.model_server == ModelServer.DJL_SERVING:
            upload_artifacts = self._upload_djl_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=True,
            )
        elif self.model_server == ModelServer.TENSORFLOW_SERVING:
            upload_artifacts = self._upload_tensorflow_serving_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                secret_key=secret_key,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=True,
            )
        # By default, artifacts are not uploaded to S3 for the servers below.
        # In the case of optimization, however, artifacts do need to be uploaded
        # to S3, so `should_upload_artifacts` must be passed in by the caller
        # of `prepare`.
        elif self.model_server == ModelServer.TGI:
            upload_artifacts = self._upload_tgi_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                s3_model_data_url=s3_model_data_url,
                image=image,
                jumpstart=jumpstart,
                should_upload_artifacts=should_upload_artifacts,
            )
        elif self.model_server == ModelServer.MMS:
            upload_artifacts = self._upload_server_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                s3_model_data_url=s3_model_data_url,
                secret_key=secret_key,
                image=image,
                should_upload_artifacts=should_upload_artifacts,
            )
        elif self.model_server == ModelServer.TEI:
            upload_artifacts = self._tei_serving._upload_tei_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=should_upload_artifacts,
            )
        elif self.model_server == ModelServer.SMD:
            upload_artifacts = self._upload_smd_artifacts(
                model_path=model_path,
                sagemaker_session=sagemaker_session,
                secret_key=secret_key,
                s3_model_data_url=s3_model_data_url,
                image=image,
                should_upload_artifacts=True,
            )

        if upload_artifacts is not None:
            return upload_artifacts

        raise ValueError(f"{self.model_server} model server is not supported")
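

# A minimal usage sketch (hypothetical; `my_spec` and the argument values are
# illustrative, and the two-element unpacking assumes the torchserve upload
# helper returns a `(model_data_url, env_vars)` pair):
#
#     mode = SageMakerEndpointMode(
#         inference_spec=my_spec, model_server=ModelServer.TORCHSERVE
#     )
#     model_data_url, env_vars = mode.prepare(
#         model_path="/path/to/model_dir",
#         secret_key="...",
#         sagemaker_session=Session(),
#     )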