Skip to content

Commit edcfe67

Browse files
committed
feat: combined inference + training script artifact
1 parent 43e3571 commit edcfe67

File tree

5 files changed

+207
-5
lines changed

5 files changed

+207
-5
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,11 @@ def _retrieve_model_uri(
235235
def _retrieve_script_uri(
236236
model_id: str,
237237
model_version: str,
238-
script_scope: Optional[str],
239-
region: Optional[str],
240-
tolerate_vulnerable_model: bool,
241-
tolerate_deprecated_model: bool,
238+
script_scope: Optional[str] = None,
239+
region: Optional[str] = False,
240+
tolerate_vulnerable_model: bool = False,
241+
tolerate_deprecated_model: bool = False,
242+
include_training_script: bool = False,
242243
):
243244
"""Retrieves the script S3 URI associated with the model matching the given arguments.
244245
@@ -259,6 +260,8 @@ def _retrieve_script_uri(
259260
tolerate_deprecated_model (bool): True if deprecated versions of model
260261
specifications should be tolerated (exception not raised). If False, raises
261262
an exception if the version of the model is deprecated.
263+
include_training_script (bool): True if training script should be packaged along with
264+
inference script. (Default: False.)
262265
Returns:
263266
str: the model script URI for the corresponding model.
264267
@@ -281,8 +284,17 @@ def _retrieve_script_uri(
281284
)
282285

283286
if script_scope == JumpStartScriptScope.INFERENCE:
284-
model_script_key = model_specs.hosting_script_key
287+
if not include_training_script:
288+
model_script_key = model_specs.hosting_script_key
289+
else:
290+
model_script_key = getattr(model_specs, "training_prepacked_script_key", None)
291+
if model_script_key is None:
292+
raise ValueError(
293+
f"Cannot include training script for {model_id} with version {model_version}."
294+
)
285295
elif script_scope == JumpStartScriptScope.TRAINING:
296+
if include_training_script:
297+
raise ValueError("Can only include training script for inference jobs.")
286298
model_script_key = model_specs.training_script_key
287299

288300
bucket = os.environ.get(

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
293293
"training_vulnerabilities",
294294
"deprecated",
295295
"metrics",
296+
"training_prepacked_script_key",
296297
]
297298

298299
def __init__(self, spec: Dict[str, Any]):
@@ -330,6 +331,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
330331
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
331332
self.deprecated: bool = bool(json_obj["deprecated"])
332333
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
334+
self.training_prepacked_script_key: Optional[str] = json_obj.get(
335+
"training_prepacked_script_key", None
336+
)
333337

334338
if self.training_supported:
335339
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

src/sagemaker/script_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def retrieve(
2929
script_scope=None,
3030
tolerate_vulnerable_model: bool = False,
3131
tolerate_deprecated_model: bool = False,
32+
include_training_script: bool = False,
3233
) -> str:
3334
"""Retrieves the script S3 URI associated with the model matching the given arguments.
3435
@@ -47,6 +48,8 @@ def retrieve(
4748
tolerate_deprecated_model (bool): ``True`` if deprecated models should be tolerated
4849
without raising an exception. ``False`` if these models should raise an exception.
4950
(Default: False).
51+
include_training_script (bool): True if training script should be packaged along with
52+
inference script. (Default: False.)
5053
Returns:
5154
str: The model script URI for the corresponding model.
5255
@@ -67,4 +70,5 @@ def retrieve(
6770
region,
6871
tolerate_vulnerable_model,
6972
tolerate_deprecated_model,
73+
include_training_script,
7074
)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,127 @@
10701070
},
10711071
],
10721072
},
1073+
"mock-model-training-prepacked-script-key": {
1074+
"model_id": "sklearn-classification-linear",
1075+
"url": "https://scikit-learn.org/stable/",
1076+
"version": "1.0.0",
1077+
"min_sdk_version": "2.68.1",
1078+
"training_supported": True,
1079+
"incremental_training_supported": False,
1080+
"hosting_ecr_specs": {
1081+
"framework": "sklearn",
1082+
"framework_version": "0.23-1",
1083+
"py_version": "py3",
1084+
},
1085+
"hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz",
1086+
"hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz",
1087+
"inference_vulnerable": False,
1088+
"inference_dependencies": [],
1089+
"inference_vulnerabilities": [],
1090+
"training_vulnerable": False,
1091+
"training_dependencies": [],
1092+
"training_vulnerabilities": [],
1093+
"deprecated": False,
1094+
"hyperparameters": [
1095+
{
1096+
"name": "tol",
1097+
"type": "float",
1098+
"default": 0.0001,
1099+
"min": 1e-20,
1100+
"max": 50,
1101+
"scope": "algorithm",
1102+
},
1103+
{
1104+
"name": "penalty",
1105+
"type": "text",
1106+
"default": "l2",
1107+
"options": ["l1", "l2", "elasticnet", "none"],
1108+
"scope": "algorithm",
1109+
},
1110+
{
1111+
"name": "alpha",
1112+
"type": "float",
1113+
"default": 0.0001,
1114+
"min": 1e-20,
1115+
"max": 999,
1116+
"scope": "algorithm",
1117+
},
1118+
{
1119+
"name": "l1_ratio",
1120+
"type": "float",
1121+
"default": 0.15,
1122+
"min": 0,
1123+
"max": 1,
1124+
"scope": "algorithm",
1125+
},
1126+
{
1127+
"name": "sagemaker_submit_directory",
1128+
"type": "text",
1129+
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
1130+
"scope": "container",
1131+
},
1132+
{
1133+
"name": "sagemaker_program",
1134+
"type": "text",
1135+
"default": "transfer_learning.py",
1136+
"scope": "container",
1137+
},
1138+
{
1139+
"name": "sagemaker_container_log_level",
1140+
"type": "text",
1141+
"default": "20",
1142+
"scope": "container",
1143+
},
1144+
],
1145+
"training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/"
1146+
"v1.0.0/sourcedir.tar.gz",
1147+
"training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz",
1148+
"training_ecr_specs": {
1149+
"framework_version": "0.23-1",
1150+
"framework": "sklearn",
1151+
"py_version": "py3",
1152+
},
1153+
"training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz",
1154+
"inference_environment_variables": [
1155+
{
1156+
"name": "SAGEMAKER_PROGRAM",
1157+
"type": "text",
1158+
"default": "inference.py",
1159+
"scope": "container",
1160+
},
1161+
{
1162+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1163+
"type": "text",
1164+
"default": "/opt/ml/model/code",
1165+
"scope": "container",
1166+
},
1167+
{
1168+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1169+
"type": "text",
1170+
"default": "20",
1171+
"scope": "container",
1172+
},
1173+
{
1174+
"name": "MODEL_CACHE_ROOT",
1175+
"type": "text",
1176+
"default": "/opt/ml/model",
1177+
"scope": "container",
1178+
},
1179+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1180+
{
1181+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1182+
"type": "text",
1183+
"default": "1",
1184+
"scope": "container",
1185+
},
1186+
{
1187+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1188+
"type": "text",
1189+
"default": "3600",
1190+
"scope": "container",
1191+
},
1192+
],
1193+
},
10731194
}
10741195

10751196
BASE_SPEC = {
@@ -1093,6 +1214,7 @@
10931214
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
10941215
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
10951216
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
1217+
"training_prepacked_script_key": None,
10961218
"hyperparameters": [
10971219
{
10981220
"name": "epochs",
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import script_uris
18+
import pytest
19+
20+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
21+
22+
23+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
24+
def test_jumpstart_combined_artifacts(patched_get_model_specs):
25+
26+
patched_get_model_specs.side_effect = get_prototype_model_spec
27+
28+
model_id_combined_script_artifact = "mock-model-training-prepacked-script-key"
29+
30+
uri = script_uris.retrieve(
31+
region="us-west-2",
32+
script_scope="inference",
33+
model_id=model_id_combined_script_artifact,
34+
model_version="*",
35+
include_training_script=True,
36+
)
37+
assert (
38+
uri == "s3://jumpstart-cache-prod-us-west-2/some/key/to/"
39+
"training_prepacked_script_key.tar.gz"
40+
)
41+
42+
with pytest.raises(ValueError):
43+
script_uris.retrieve(
44+
region="us-west-2",
45+
script_scope="training",
46+
model_id=model_id_combined_script_artifact,
47+
model_version="*",
48+
include_training_script=True,
49+
)
50+
51+
model_id_combined_script_artifact_unsupported = "xgboost-classification-model"
52+
53+
with pytest.raises(ValueError):
54+
script_uris.retrieve(
55+
region="us-west-2",
56+
script_scope="inference",
57+
model_id=model_id_combined_script_artifact_unsupported,
58+
model_version="*",
59+
include_training_script=True,
60+
)

0 commit comments

Comments
 (0)