Skip to content

Commit e0d90fb

Browse files
SSRraymondknikure
authored and committed
fix: HMAC signing for ModelBuilder Triton python backend (#1282)
1 parent 20cd3b6 commit e0d90fb

18 files changed

+503
-96
lines changed

doc/overview.rst

+193-24
Large diffs are not rendered by default.

src/sagemaker/model.py

-1
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,6 @@ def _create_sagemaker_model(
801801
Specifies configuration related to serverless endpoint. Instance type is
802802
not provided in serverless inference. So this is used to find image URIs.
803803
"""
804-
805804
if self.model_package_arn is not None or self.algorithm_arn is not None:
806805
model_package = ModelPackage(
807806
role=self.role,

src/sagemaker/serve/builder/model_builder.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ class ModelBuilder(Triton, DJL, JumpStart):
5656
5757
* 1/Launch on SageMaker endpoint
5858
* 2/Launch locally with a container
59-
* 3/Launch in process
6059
6160
shared_libs (List[str]): Any shared libraries you want to bring into
6261
the model.
@@ -336,6 +335,7 @@ def _create_model(self):
336335
self.pysdk_model.deploy = self._model_builder_deploy_wrapper
337336
self._original_register = self.pysdk_model.register
338337
self.pysdk_model.register = self._model_builder_register_wrapper
338+
self.model_package = None
339339
return self.pysdk_model
340340

341341
@_capture_telemetry("register")
@@ -347,9 +347,23 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
347347
if "response_types" not in kwargs:
348348
self.pysdk_model.response_types = deserializer.ACCEPT.split()
349349
new_model_package = self._original_register(*args, **kwargs)
350-
new_model_package.deploy = self._model_builder_deploy_wrapper
350+
self.pysdk_model.model_package_arn = new_model_package.model_package_arn
351+
new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
352+
self.model_package = new_model_package
351353
return new_model_package
352354

355+
def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
356+
"""Placeholder docstring"""
357+
if self.pysdk_model.model_package_arn is not None:
358+
return self._model_builder_register_wrapper(*args, **kwargs)
359+
360+
# need to set the model_package_arn
361+
# so that the model is created using the model_package's configs
362+
self.pysdk_model.model_package_arn = self.model_package.model_package_arn
363+
predictor = self._model_builder_register_wrapper(*args, **kwargs)
364+
self.pysdk_model.model_package_arn = None
365+
return predictor
366+
353367
@_capture_telemetry("torchserve.deploy")
354368
def _model_builder_deploy_wrapper(
355369
self, *args, container_timeout_in_second: int = 300, **kwargs

src/sagemaker/serve/detector/dependency_manager.py

+58-13
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919

2020
from pathlib import Path
2121
import logging
22-
import shutil
2322
import subprocess
2423
import sys
25-
24+
import re
2625

2726
_SUPPORTED_SUFFIXES = [".txt"]
2827
# TODO : Move PKL_FILE_NAME to common location
@@ -31,7 +30,7 @@
3130
logger = logging.getLogger(__name__)
3231

3332

34-
def capture_dependencies(dependencies: str, work_dir: Path, capture_all: bool = False):
33+
def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = False):
3534
"""Placeholder docstring"""
3635
path = work_dir.joinpath("requirements.txt")
3736
if "auto" in dependencies and dependencies["auto"]:
@@ -53,24 +52,44 @@ def capture_dependencies(dependencies: str, work_dir: Path, capture_all: bool =
5352
check=True,
5453
)
5554

56-
if "requirements" in dependencies:
57-
_capture_from_customer_provided_requirements(dependencies["requirements"], path)
55+
with open(path, "r") as f:
56+
autodetect_depedencies = f.read().splitlines()
57+
else:
58+
autodetect_depedencies = []
5859

60+
module_version_dict = _parse_dependency_list(autodetect_depedencies)
61+
62+
if "requirements" in dependencies:
63+
module_version_dict = _process_customer_provided_requirements(
64+
requirements_file=dependencies["requirements"], module_version_dict=module_version_dict
65+
)
5966
if "custom" in dependencies:
67+
module_version_dict = _process_custom_dependencies(
68+
custom_dependencies=dependencies.get("custom"), module_version_dict=module_version_dict
69+
)
70+
with open(path, "w") as f:
71+
for module, version in module_version_dict.items():
72+
f.write(f"{module}{version}\n")
6073

61-
with open(path, "a+") as f:
62-
for package in dependencies["custom"]:
63-
f.write(f"{package}\n")
6474

75+
def _process_custom_dependencies(custom_dependencies: list, module_version_dict: dict):
76+
"""Placeholder docstring"""
77+
custom_module_version_dict = _parse_dependency_list(custom_dependencies)
78+
module_version_dict.update(custom_module_version_dict)
79+
return module_version_dict
6580

66-
def _capture_from_customer_provided_requirements(requirements_file: str, output_path: Path):
81+
82+
def _process_customer_provided_requirements(requirements_file: str, module_version_dict: dict):
6783
"""Placeholder docstring"""
68-
input_path = Path(requirements_file)
69-
if not input_path.is_file() or not _is_valid_requirement_file(input_path):
84+
requirements_file = Path(requirements_file)
85+
if not requirements_file.is_file() or not _is_valid_requirement_file(requirements_file):
7086
raise Exception(f"Path: {requirements_file} to requirements.txt doesn't exist")
7187
logger.debug("Packaging provided requirements.txt from %s", requirements_file)
72-
with open(output_path, "a+") as f:
73-
shutil.copyfileobj(open(input_path, "r"), f)
88+
with open(requirements_file, "r") as f:
89+
custom_dependencies = f.read().splitlines()
90+
91+
module_version_dict.update(_parse_dependency_list(custom_dependencies))
92+
return module_version_dict
7493

7594

7695
def _is_valid_requirement_file(path):
@@ -82,6 +101,32 @@ def _is_valid_requirement_file(path):
82101
return False
83102

84103

104+
def _parse_dependency_list(depedency_list: list) -> dict:
105+
"""Placeholder docstring"""
106+
107+
# Divide a string into 2 parts: the first part is the module name
108+
# and second part is its version constraint or the url
109+
# check out tests/unit/sagemaker/serve/detector/test_dependency_manager.py
110+
# for examples
111+
pattern = r"^([\w.-]+)(@[^,\n]+|((?:[<>=!~]=?[\w.*-]+,?)+)?)$"
112+
113+
module_version_dict = {}
114+
115+
for dependency in depedency_list:
116+
if dependency.startswith("#"):
117+
continue
118+
match = re.match(pattern, dependency)
119+
if match:
120+
package = match.group(1)
121+
# Group 2 is either a URL or version constraint, if present
122+
url_or_version = match.group(2) if match.group(2) else ""
123+
module_version_dict.update({package: url_or_version})
124+
else:
125+
module_version_dict.update({dependency: ""})
126+
127+
return module_version_dict
128+
129+
85130
# only required for dev testing
86131
def prepare_wheel(code_artifact_client, whl_dir: str):
87132
"""Placeholder docstring"""

src/sagemaker/serve/detector/pickle_dependencies.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,14 @@ def get_requirements_for_pkl_file(pkl_path: Path, dest: Path):
108108
with open(dest, mode="w+") as out:
109109
for x in get_all_installed_packages():
110110
name = x["name"]
111+
version = x["version"]
111112
# skip only for dev
112113
if name == "sagemaker":
113-
out.write("/opt/ml/model/whl/sagemaker-2.185.1.dev0-py2.py3-none-any.whl\n")
114+
out.write("/opt/ml/model/whl/sagemaker-2.195.2.dev0-py2.py3-none-any.whl\n")
114115
elif name == "boto3":
115116
out.write("boto3==1.26.131\n")
116117
elif name in currently_used_packages:
117-
out.write(f"{name}\n")
118+
out.write(f"{name}=={version}\n")
118119

119120

120121
def get_all_requirements(dest: Path):
@@ -128,10 +129,8 @@ def get_all_requirements(dest: Path):
128129
# skip only for dev
129130
if name == "sagemaker":
130131
continue
131-
if name == "boto3":
132-
out.write("boto3==1.26.131\n")
133-
else:
134-
out.write(f"{name}=={version}\n")
132+
133+
out.write(f"{name}=={version}\n")
135134

136135

137136
def parse_args():

src/sagemaker/serve/mode/local_container_mode.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def create_server(
9393
docker_client=self.client,
9494
model_path=model_path if model_path else self.model_path,
9595
image_uri=image,
96+
secret_key=secret_key,
9697
env_vars=env_vars if env_vars else self.env_vars,
9798
)
9899
self._ping_container = self._triton_deep_ping

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def prepare(
7070
return self._upload_triton_artifacts(
7171
model_path=model_path,
7272
sagemaker_session=sagemaker_session,
73+
secret_key=secret_key,
7374
s3_model_data_url=s3_model_data_url,
7475
image=image,
7576
)

src/sagemaker/serve/model_server/torchserve/inference.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def _py_vs_parity_check():
9797
def _pickle_file_integrity_check():
9898
with open("/opt/ml/model/code/serve.pkl", "rb") as f:
9999
buffer = f.read()
100-
perform_integrity_check(buffer=buffer)
100+
101+
metadeata_path = Path("/opt/ml/model/code/metadata.json")
102+
perform_integrity_check(buffer=buffer, metadata_path=metadeata_path)
101103

102104

103105
# on import, execute

src/sagemaker/serve/model_server/triton/model.py

+30
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import logging
55
import ssl
66
from pathlib import Path
7+
import platform
78

89
import triton_python_backend_utils as pb_utils
910
import cloudpickle
11+
from sagemaker.serve.validations.check_integrity import perform_integrity_check
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -52,3 +54,31 @@ def execute(self, requests):
5254
responses.append(response)
5355

5456
return responses
57+
58+
59+
def _run_preflight_diagnostics():
60+
_py_vs_parity_check()
61+
_pickle_file_integrity_check()
62+
63+
64+
def _py_vs_parity_check():
65+
container_py_vs = platform.python_version()
66+
local_py_vs = os.getenv("LOCAL_PYTHON")
67+
68+
if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
69+
logger.warning(
70+
f"The local python version {local_py_vs} differs from the python version "
71+
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
72+
)
73+
74+
75+
def _pickle_file_integrity_check():
76+
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
77+
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
78+
with open(str(serve_path), "rb") as f:
79+
buffer = f.read()
80+
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
81+
82+
83+
# on import, execute
84+
_run_preflight_diagnostics()

src/sagemaker/serve/model_server/triton/server.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uuid
44
import logging
55
import importlib
6+
import platform
67

78
from sagemaker import fw_utils
89
from sagemaker import Session
@@ -29,13 +30,20 @@ def _start_triton_server(
2930
self,
3031
docker_client: docker.DockerClient,
3132
model_path: str,
33+
secret_key: str,
3234
image_uri: str,
3335
env_vars: dict,
3436
):
3537
"""Placeholder docstring"""
3638
self.container_name = "triton" + uuid.uuid1().hex
3739
model_repository = model_path + "/model_repository"
38-
env_vars.update({"TRITON_MODEL_DIR": "/models/model"})
40+
env_vars.update(
41+
{
42+
"TRITON_MODEL_DIR": "/models/model",
43+
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
44+
"LOCAL_PYTHON": platform.python_version(),
45+
}
46+
)
3947

4048
if "cpu" not in image_uri:
4149
self.container = docker_client.containers.run(
@@ -102,6 +110,7 @@ def _upload_triton_artifacts(
102110
self,
103111
model_path: str,
104112
sagemaker_session: Session,
113+
secret_key: str,
105114
s3_model_data_url: str = None,
106115
image: str = None,
107116
):
@@ -127,5 +136,7 @@ def _upload_triton_artifacts(
127136
env_vars = {
128137
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
129138
"TRITON_MODEL_DIR": "/opt/ml/model/model",
139+
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
140+
"LOCAL_PYTHON": platform.python_version(),
130141
}
131142
return s3_upload_path, env_vars

src/sagemaker/serve/model_server/triton/triton_builder.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
from sagemaker.base_deserializers import JSONDeserializer
2222
from sagemaker.serve.detector.pickler import save_pkl
2323
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
24+
from sagemaker.serve.validations.check_integrity import (
25+
generate_secret_key,
26+
compute_hash,
27+
)
28+
29+
from sagemaker.remote_function.core.serialization import _MetaData
2430

2531

2632
logger = logging.getLogger(__name__)
@@ -206,6 +212,8 @@ def _prepare_for_triton(self):
206212
export_path.mkdir(parents=True)
207213

208214
if self.model:
215+
self.secret_key = "dummy secret key for onnx backend"
216+
209217
if self.framework == "pytorch":
210218
self._export_pytorch_to_onnx(
211219
export_path=export_path, model=self.model, schema_builder=self.schema_builder
@@ -228,10 +236,26 @@ def _prepare_for_triton(self):
228236

229237
self._pack_conda_env(pkl_path=pkl_path)
230238

239+
self._hmac_signing()
240+
231241
return
232242

233243
raise ValueError("Either model or inference_spec should be provided to ModelBuilder.")
234244

245+
def _hmac_signing(self):
246+
"""Perform HMAC signing on picke file for integrity check"""
247+
secret_key = generate_secret_key()
248+
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
249+
250+
with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
251+
buffer = f.read()
252+
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
253+
254+
with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
255+
metadata.write(_MetaData(hash_value).to_json())
256+
257+
self.secret_key = secret_key
258+
235259
def _generate_config_pbtxt(self, pkl_path: Path):
236260
config_path = pkl_path.joinpath("config.pbtxt")
237261

@@ -436,8 +460,6 @@ def _build_for_triton(self):
436460

437461
self._auto_detect_image_for_triton()
438462

439-
self.secret_key = "dummy secret key"
440-
441463
self._save_inference_spec()
442464

443465
self._prepare_for_triton()

src/sagemaker/serve/validations/check_integrity.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,16 @@ def compute_hash(buffer: bytes, secret_key: str) -> str:
1919
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
2020

2121

22-
def perform_integrity_check(buffer: bytes):
22+
def perform_integrity_check(buffer: bytes, metadata_path: Path):
2323
"""Validates the integrity of bytes by comparing the hash value"""
2424
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
2525
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
2626

27-
metadata_path = Path("/opt/ml/model/code/metadata.json").resolve()
28-
2927
if not Path.exists(metadata_path):
30-
raise Exception("Path to metadata.json does not exist")
28+
raise ValueError("Path to metadata.json does not exist")
3129

32-
with open(metadata_path, "rb") as md:
30+
with open(str(metadata_path), "rb") as md:
3331
expected_hash_value = _MetaData.from_json(md.read()).sha256_hash
3432

3533
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
36-
raise Exception("Integrity check for the serialized function or data failed.")
34+
raise ValueError("Integrity check for the serialized function or data failed.")

0 commit comments

Comments
 (0)