Skip to content

Commit 53fea69

Browse files
evakravidoddaspk-amzn
authored andcommitted
feature: Combined inference and training script artifact (aws#3717)
1 parent f31e9fe commit 53fea69

File tree

4 files changed

+168
-2
lines changed

4 files changed

+168
-2
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def _retrieve_model_uri(
245245
def _retrieve_script_uri(
246246
model_id: str,
247247
model_version: str,
248-
script_scope: Optional[str],
248+
script_scope: Optional[str] = None,
249249
region: Optional[str] = None,
250250
tolerate_vulnerable_model: bool = False,
251251
tolerate_deprecated_model: bool = False,
@@ -294,7 +294,9 @@ def _retrieve_script_uri(
294294
if script_scope == JumpStartScriptScope.INFERENCE:
295295
model_script_key = model_specs.hosting_script_key
296296
elif script_scope == JumpStartScriptScope.TRAINING:
297-
model_script_key = model_specs.training_script_key
297+
model_script_key = (
298+
getattr(model_specs, "training_prepacked_script_key") or model_specs.training_script_key
299+
)
298300

299301
bucket = os.environ.get(
300302
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
297297
"default_training_instance_type",
298298
"supported_training_instance_types",
299299
"metrics",
300+
"training_prepacked_script_key",
300301
"hosting_prepacked_artifact_key",
301302
]
302303

@@ -347,6 +348,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
347348
"supported_training_instance_types"
348349
)
349350
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
351+
self.training_prepacked_script_key: Optional[str] = json_obj.get(
352+
"training_prepacked_script_key", None
353+
)
350354
self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
351355
"hosting_prepacked_artifact_key", None
352356
)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,127 @@
218218
"ml.g4dn.12xlarge",
219219
],
220220
},
221+
"mock-model-training-prepacked-script-key": {
222+
"model_id": "sklearn-classification-linear",
223+
"url": "https://scikit-learn.org/stable/",
224+
"version": "1.0.0",
225+
"min_sdk_version": "2.68.1",
226+
"training_supported": True,
227+
"incremental_training_supported": False,
228+
"hosting_ecr_specs": {
229+
"framework": "sklearn",
230+
"framework_version": "0.23-1",
231+
"py_version": "py3",
232+
},
233+
"hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz",
234+
"hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz",
235+
"inference_vulnerable": False,
236+
"inference_dependencies": [],
237+
"inference_vulnerabilities": [],
238+
"training_vulnerable": False,
239+
"training_dependencies": [],
240+
"training_vulnerabilities": [],
241+
"deprecated": False,
242+
"hyperparameters": [
243+
{
244+
"name": "tol",
245+
"type": "float",
246+
"default": 0.0001,
247+
"min": 1e-20,
248+
"max": 50,
249+
"scope": "algorithm",
250+
},
251+
{
252+
"name": "penalty",
253+
"type": "text",
254+
"default": "l2",
255+
"options": ["l1", "l2", "elasticnet", "none"],
256+
"scope": "algorithm",
257+
},
258+
{
259+
"name": "alpha",
260+
"type": "float",
261+
"default": 0.0001,
262+
"min": 1e-20,
263+
"max": 999,
264+
"scope": "algorithm",
265+
},
266+
{
267+
"name": "l1_ratio",
268+
"type": "float",
269+
"default": 0.15,
270+
"min": 0,
271+
"max": 1,
272+
"scope": "algorithm",
273+
},
274+
{
275+
"name": "sagemaker_submit_directory",
276+
"type": "text",
277+
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
278+
"scope": "container",
279+
},
280+
{
281+
"name": "sagemaker_program",
282+
"type": "text",
283+
"default": "transfer_learning.py",
284+
"scope": "container",
285+
},
286+
{
287+
"name": "sagemaker_container_log_level",
288+
"type": "text",
289+
"default": "20",
290+
"scope": "container",
291+
},
292+
],
293+
"training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/"
294+
"v1.0.0/sourcedir.tar.gz",
295+
"training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz",
296+
"training_ecr_specs": {
297+
"framework_version": "0.23-1",
298+
"framework": "sklearn",
299+
"py_version": "py3",
300+
},
301+
"training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz",
302+
"inference_environment_variables": [
303+
{
304+
"name": "SAGEMAKER_PROGRAM",
305+
"type": "text",
306+
"default": "inference.py",
307+
"scope": "container",
308+
},
309+
{
310+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
311+
"type": "text",
312+
"default": "/opt/ml/model/code",
313+
"scope": "container",
314+
},
315+
{
316+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
317+
"type": "text",
318+
"default": "20",
319+
"scope": "container",
320+
},
321+
{
322+
"name": "MODEL_CACHE_ROOT",
323+
"type": "text",
324+
"default": "/opt/ml/model",
325+
"scope": "container",
326+
},
327+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
328+
{
329+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
330+
"type": "text",
331+
"default": "1",
332+
"scope": "container",
333+
},
334+
{
335+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
336+
"type": "text",
337+
"default": "3600",
338+
"scope": "container",
339+
},
340+
],
341+
},
221342
}
222343

223344

@@ -1302,6 +1423,7 @@
13021423
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
13031424
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
13041425
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
1426+
"training_prepacked_script_key": None,
13051427
"hosting_prepacked_artifact_key": None,
13061428
"hyperparameters": [
13071429
{
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import script_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
23+
def test_jumpstart_combined_artifacts(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_special_model_spec
26+
27+
model_id_combined_script_artifact = "mock-model-training-prepacked-script-key"
28+
29+
uri = script_uris.retrieve(
30+
region="us-west-2",
31+
script_scope="training",
32+
model_id=model_id_combined_script_artifact,
33+
model_version="*",
34+
)
35+
assert (
36+
uri == "s3://jumpstart-cache-prod-us-west-2/some/key/to/"
37+
"training_prepacked_script_key.tar.gz"
38+
)

0 commit comments

Comments
 (0)