Skip to content

Commit c6a574f

Browse files
cj-zhangJoseph Zhang
authored and
Pravali Uppugunduri
committed
feature:support custom workflow deployment in ModelBuilder using SMD image. (#1661)
* feature:support custom workflow deployment in ModelBuilder using SMD inference image. * Rename test case and pass session. * Address PR comments. * Tweak resource cleanup logic in integ test. * Fixing CodeBuild integ test failures. * Renamed integ test. * Remove unused integ test, restore once GA. --------- Co-authored-by: Joseph Zhang <[email protected]>
1 parent 201500c commit c6a574f

14 files changed

+1025
-37
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"scope": ["inference"],
4+
"version_aliases": {
5+
"3.0": "3.0.0"
6+
},
7+
"versions": {
8+
"3.0.0": {
9+
"registries": {
10+
"us-east-1": "885854791233",
11+
"us-east-2": "137914896644",
12+
"us-west-1": "053634841547",
13+
"us-west-2": "542918446943",
14+
"af-south-1": "238384257742",
15+
"ap-east-1": "523751269255",
16+
"ap-south-1": "245090515133",
17+
"ap-northeast-2": "064688005998",
18+
"ap-southeast-1": "022667117163",
19+
"ap-southeast-2": "648430277019",
20+
"ap-northeast-1": "010972774902",
21+
"ca-central-1": "481561238223",
22+
"eu-central-1": "545423591354",
23+
"eu-west-1": "819792524951",
24+
"eu-west-2": "021081402939",
25+
"eu-west-3": "856416204555",
26+
"eu-north-1": "175620155138",
27+
"eu-south-1": "810671768855",
28+
"sa-east-1": "567556641782",
29+
"ap-northeast-3": "564864627153",
30+
"ap-southeast-3": "370607712162",
31+
"me-south-1": "523774347010",
32+
"me-central-1": "358593528301"
33+
},
34+
"repository": "sagemaker-distribution-prod"
35+
}
36+
}
37+
}

src/sagemaker/serve/builder/model_builder.py

+442-36
Large diffs are not rendered by default.

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

+14
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,21 @@
1616
from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
1717
from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
1818
from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer
19+
from sagemaker.serve.model_server.smd.server import SageMakerSmdServer
20+
1921

2022
logger = logging.getLogger(__name__)
2123

2224

25+
# pylint: disable=R0901
2326
class SageMakerEndpointMode(
2427
SageMakerTorchServe,
2528
SageMakerTritonServer,
2629
SageMakerDjlServing,
2730
SageMakerTgiServing,
2831
SageMakerMultiModelServer,
2932
SageMakerTensorflowServing,
33+
SageMakerSmdServer,
3034
):
3135
"""Holds the required method to deploy a model to a SageMaker Endpoint"""
3236

@@ -144,6 +148,16 @@ def prepare(
144148
should_upload_artifacts=should_upload_artifacts,
145149
)
146150

151+
if self.model_server == ModelServer.SMD:
152+
upload_artifacts = self._upload_smd_artifacts(
153+
model_path=model_path,
154+
sagemaker_session=sagemaker_session,
155+
secret_key=secret_key,
156+
s3_model_data_url=s3_model_data_url,
157+
image=image,
158+
should_upload_artifacts=True,
159+
)
160+
147161
if upload_artifacts or isinstance(self.model_server, ModelServer):
148162
return upload_artifacts
149163

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is for SageMaker inference.py."""
14+
15+
from __future__ import absolute_import
16+
import asyncio
17+
import os
18+
import platform
19+
import cloudpickle
20+
import logging
21+
from pathlib import Path
22+
from sagemaker.serve.validations.check_integrity import perform_integrity_check
23+
24+
logger = LOGGER = logging.getLogger("sagemaker")
25+
26+
27+
def initialize_custom_orchestrator():
    """Load the pickled payload stored as ``serve.pkl`` in the inference code directory.

    The directory is read from the ``SAGEMAKER_INFERENCE_CODE_DIRECTORY``
    environment variable.

    Returns:
        Whatever object ``cloudpickle`` deserializes from ``serve.pkl``.
    """
    pickle_location = Path(os.environ.get("SAGEMAKER_INFERENCE_CODE_DIRECTORY")) / "serve.pkl"
    # NOTE: unpickling executes code embedded in the file; this module runs its
    # preflight integrity check before calling this function.
    with open(str(pickle_location), mode="rb") as serialized:
        payload = cloudpickle.load(serialized)
    return payload
33+
34+
35+
def _run_preflight_diagnostics():
    """Run the start-up checks in order: Python version parity, then pickle integrity."""
    for check in (_py_vs_parity_check, _pickle_file_integrity_check):
        check()
38+
39+
40+
def _py_vs_parity_check():
    """Warn when the local and container Python versions do not match.

    The local version is read from the ``LOCAL_PYTHON`` environment variable and
    compared with the interpreter running this module. Only the ``major.minor``
    components are compared. A warning is also logged when ``LOCAL_PYTHON`` is
    unset or empty. The function never raises on a malformed version string.
    """
    container_py_vs = platform.python_version()
    local_py_vs = os.getenv("LOCAL_PYTHON")

    # Compare the leading [major, minor] slices: slicing tolerates values with
    # fewer components (e.g. "3"), whereas indexing [1] would raise IndexError
    # while this module is being imported.
    if not local_py_vs or container_py_vs.split(".")[:2] != local_py_vs.split(".")[:2]:
        logger.warning(
            "The local python version %s differs from the python version "
            "%s on the container. Please align the two to avoid unexpected behavior",
            local_py_vs,
            container_py_vs,
        )
49+
50+
51+
def _pickle_file_integrity_check():
    """Check ``serve.pkl`` against ``metadata.json`` via ``perform_integrity_check``.

    Both files are read from the directory named by the
    ``SAGEMAKER_INFERENCE_CODE_DIRECTORY`` environment variable, which is the
    same location ``initialize_custom_orchestrator`` loads ``serve.pkl`` from,
    so the file that is verified is the file that gets deserialized. When the
    variable is unset, ``/opt/ml/model/code`` is used.
    """
    code_dir = Path(os.getenv("SAGEMAKER_INFERENCE_CODE_DIRECTORY", "/opt/ml/model/code"))
    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()

    metadata_path = code_dir.joinpath("metadata.json")
    perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
57+
58+
59+
# Executed at import time: run the preflight checks before anything is unpickled.
_run_preflight_diagnostics()
# The loaded payload is unpacked as a 2-tuple; only its first element is kept.
# NOTE(review): the meaning of the discarded second element is not visible here.
custom_orchestrator, _ = initialize_custom_orchestrator()
61+
62+
63+
async def handler(request):
    """Entry point invoked for each inference request.

    Delegates to ``custom_orchestrator.handle``, awaiting it when it is a
    coroutine function and calling it directly otherwise.

    :param request: raw input from request
    :return: outputs to be sent back to client
    """
    handle = custom_orchestrator.handle
    if not asyncio.iscoroutinefunction(handle):
        return handle(request)
    return await handle(request)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Summary of MyModule.
2+
3+
Extended discussion of my module.
4+
"""
5+
6+
from __future__ import absolute_import
7+
import os
8+
from pathlib import Path
9+
import shutil
10+
from typing import List
11+
12+
from sagemaker.serve.spec.inference_spec import InferenceSpec
13+
from sagemaker.serve.detector.dependency_manager import capture_dependencies
14+
from sagemaker.serve.validations.check_integrity import (
15+
generate_secret_key,
16+
compute_hash,
17+
)
18+
from sagemaker.remote_function.core.serialization import _MetaData
19+
from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator
20+
21+
22+
def prepare_for_smd(
    model_path: str,
    shared_libs: List[str],
    dependencies: dict,
    inference_spec: InferenceSpec = None,
) -> str:
    """Stage the local artifacts needed to deploy a model with the SMD image.

    Creates ``<model_path>/code`` and ``<model_path>/shared_libs``, copies the
    custom-orchestrator entry point (when applicable) and the shared libraries,
    captures dependencies into the code directory, and writes ``metadata.json``
    containing a hash of the existing ``code/serve.pkl``.

    Args:
        model_path (str): Local directory that receives the artifacts. It is
            created when missing (its parent must already exist).
        shared_libs (List[str]): Paths of files to copy into
            ``<model_path>/shared_libs``.
        dependencies (dict): Passed unchanged to ``capture_dependencies``.
        inference_spec (InferenceSpec, optional): When it is an
            ``InferenceSpec``, its ``prepare`` hook is called with the model
            path. When it is a ``CustomOrchestrator`` or
            ``AsyncCustomOrchestrator``, ``custom_execution_inference.py`` is
            installed as ``code/inference.py``. (default is None)

    Returns:
        str: The secret key used to compute the hash of ``serve.pkl``.

    Raises:
        Exception: If ``model_path`` exists but is not a directory.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        model_path.mkdir()
    elif not model_path.is_dir():
        raise Exception("model_dir is not a valid directory")

    if inference_spec and isinstance(inference_spec, InferenceSpec):
        inference_spec.prepare(str(model_path))

    code_dir = model_path.joinpath("code")
    code_dir.mkdir(exist_ok=True)

    if inference_spec and isinstance(inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator)):
        # Copy straight to the final name instead of copy-then-rename: a rename
        # onto an existing inference.py fails on Windows, which broke re-running
        # the preparation against the same model directory.
        shutil.copy2(
            Path(__file__).parent.joinpath("custom_execution_inference.py"),
            code_dir.joinpath("inference.py"),
        )

    shared_libs_dir = model_path.joinpath("shared_libs")
    shared_libs_dir.mkdir(exist_ok=True)
    for shared_lib in shared_libs:
        shutil.copy2(Path(shared_lib), shared_libs_dir)

    capture_dependencies(dependencies=dependencies, work_dir=code_dir)

    # serve.pkl must already be present in the code directory at this point;
    # its hash is stored next to it for an integrity check against metadata.json.
    secret_key = generate_secret_key()
    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()
    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())

    return secret_key
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Module for SMD Server"""
2+
3+
from __future__ import absolute_import
4+
5+
import logging
6+
import platform
7+
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
8+
from sagemaker.session import Session
9+
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
10+
from sagemaker import fw_utils
11+
from sagemaker.serve.utils.uploader import upload
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class SageMakerSmdServer:
    """Artifact-upload support for deploying with the SMD model server."""

    def _upload_smd_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Resolve the S3 location of the model artifacts and build the container environment.

        Args:
            model_path (str): Local path of the artifacts, or an S3 URI. An S3
                URI is returned as-is and nothing is uploaded.
            sagemaker_session (Session): Session used for the upload and for
                the region written into the environment.
            secret_key (str): Value exported as ``SAGEMAKER_SERVE_SECRET_KEY``.
            s3_model_data_url (str, optional): S3 URL whose bucket and key
                prefix seed the upload destination. (default is None)
            image (str, optional): Image URI passed to the key-prefix helper.
                (default is None)
            should_upload_artifacts (bool): Whether a local ``model_path``
                should be uploaded. (default is False)

        Returns:
            tuple: ``(s3_upload_path, env_vars)``. ``s3_upload_path`` is None
            when ``model_path`` is local and no upload was requested.
        """
        if _is_s3_uri(model_path):
            artifact_location = model_path
        elif not should_upload_artifacts:
            artifact_location = None
        else:
            # Start from the caller-supplied S3 URL when there is one.
            bucket, key_prefix = (
                parse_s3_url(url=s3_model_data_url) if s3_model_data_url else (None, None)
            )

            bucket, code_key_prefix = determine_bucket_and_prefix(
                bucket=bucket,
                key_prefix=fw_utils.model_code_key_prefix(key_prefix, None, image),
                sagemaker_session=sagemaker_session,
            )

            logger.debug(
                "Uploading the model resources to bucket=%s, key_prefix=%s.",
                bucket,
                code_key_prefix,
            )
            artifact_location = upload(sagemaker_session, model_path, bucket, code_key_prefix)
            logger.debug("Model resources uploaded to: %s", artifact_location)

        container_env = {
            "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
            "SAGEMAKER_INFERENCE_CODE": "inference.handler",
            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
            # Python version of the machine building the model, not the container.
            "LOCAL_PYTHON": platform.python_version(),
        }
        return artifact_location, container_env
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds templated classes to enable users to provide custom inference scripting capabilities"""
14+
from __future__ import absolute_import
15+
from abc import ABC, abstractmethod
16+
17+
18+
class CustomOrchestrator(ABC):
    """Base class that standardizes synchronous, entrypoint-based inference scripts.

    Subclasses must implement ``handle``.
    """

    @property
    def client(self):
        """Return a boto3 ``sagemaker-runtime`` client.

        A new boto3 session and client are created on every access.
        """
        # Imported lazily so boto3 is only needed when the client is requested.
        import boto3

        return boto3.Session().client("sagemaker-runtime")

    @abstractmethod
    def handle(self, data, context=None):
        """Abstract method defining the entry point for the model server."""
        return NotImplemented
32+
33+
34+
class AsyncCustomOrchestrator(ABC):
    """Base class that standardizes asynchronous, entrypoint-based inference scripts.

    Subclasses must implement ``handle`` as a coroutine.
    """

    @abstractmethod
    async def handle(self, data, context=None):
        """Abstract method defining the asynchronous entry point for the model server."""
        return NotImplemented

src/sagemaker/serve/utils/telemetry_logger.py

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
str(ModelServer.TRITON): 5,
6565
str(ModelServer.TGI): 6,
6666
str(ModelServer.TEI): 7,
67+
str(ModelServer.SMD): 8,
6768
}
6869

6970
MLFLOW_MODEL_PATH_CODE = {

src/sagemaker/serve/utils/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __str__(self):
1919
TRITON = 5
2020
TGI = 6
2121
TEI = 7
22+
SMD = 8
2223

2324

2425
class HardwareType(Enum):

tests/integ/sagemaker/serve/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
PYTHON_VERSION_IS_NOT_38 = platform.python_version_tuple()[1] != "8"
2727
PYTHON_VERSION_IS_NOT_310 = platform.python_version_tuple()[1] != "10"
28+
PYTHON_VERSION_IS_NOT_312 = platform.python_version_tuple()[1] != "12"
2829

2930
XGB_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "xgboost")
3031
PYTORCH_SQUEEZENET_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "pytorch")

0 commit comments

Comments
 (0)