Skip to content

Commit fd48fa8

Browse files
authored
Merge branch 'master' into cv_s3_upload
2 parents 8a5fa72 + 07dd8ba commit fd48fa8

File tree

7 files changed

+165
-6
lines changed

7 files changed

+165
-6
lines changed

CHANGELOG.md

+14
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,19 @@
11
# Changelog
22

3+
## v2.141.0 (2023-03-24)
4+
5+
### Features
6+
7+
* AutoGluon 0.7.0 image_uris update
8+
* Add DJL FasterTransformer image uris
9+
* EMR step runtime role support
10+
* Customizable locations for EMR configuration and Spark dependencies
11+
* Adding support for 1P Algorithms in ZAZ, ZRH, HYD, MEL
12+
13+
### Documentation Changes
14+
15+
* Update FeatureGroup kms key id documentation
16+
317
## v2.140.1 (2023-03-21)
418

519
### Bug Fixes and Other Changes

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.140.2.dev0
1+
2.141.1.dev0

src/sagemaker/jumpstart/artifacts.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ def _retrieve_image_uri(
173173
def _retrieve_model_uri(
174174
model_id: str,
175175
model_version: str,
176-
model_scope: Optional[str],
177-
region: Optional[str],
178-
tolerate_vulnerable_model: bool,
179-
tolerate_deprecated_model: bool,
176+
model_scope: Optional[str] = None,
177+
region: Optional[str] = None,
178+
tolerate_vulnerable_model: bool = False,
179+
tolerate_deprecated_model: bool = False,
180180
):
181181
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
182182
@@ -219,7 +219,11 @@ def _retrieve_model_uri(
219219
)
220220

221221
if model_scope == JumpStartScriptScope.INFERENCE:
222-
model_artifact_key = model_specs.hosting_artifact_key
222+
model_artifact_key = (
223+
getattr(model_specs, "hosting_prepacked_artifact_key", None)
224+
or model_specs.hosting_artifact_key
225+
)
226+
223227
elif model_scope == JumpStartScriptScope.TRAINING:
224228
model_artifact_key = model_specs.training_artifact_key
225229

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
293293
"training_vulnerabilities",
294294
"deprecated",
295295
"metrics",
296+
"hosting_prepacked_artifact_key",
296297
]
297298

298299
def __init__(self, spec: Dict[str, Any]):
@@ -330,6 +331,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
330331
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
331332
self.deprecated: bool = bool(json_obj["deprecated"])
332333
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
334+
self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
335+
"hosting_prepacked_artifact_key", None
336+
)
333337

334338
if self.training_supported:
335339
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

tests/unit/sagemaker/jumpstart/constants.py

+86
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,91 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
# Model specs that exercise non-standard JumpStart code paths (here: a model
# that ships a prepacked inference artifact via "hosting_prepacked_artifact_key").
# Keyed by model ID; used by tests needing fields the prototypical specs omit.
SPECIAL_MODEL_SPECS_DICT = {
    "huggingface-text2text-flan-t5-xxl-fp16": {
        "model_id": "huggingface-text2text-flan-t5-xxl-fp16",
        "url": "https://huggingface.co/google/flan-t5-xxl",
        "version": "1.0.0",
        "min_sdk_version": "2.130.0",
        "training_supported": False,
        "incremental_training_supported": False,
        "hosting_ecr_specs": {
            "framework": "pytorch",
            "framework_version": "1.12.0",
            "py_version": "py38",
            "huggingface_transformers_version": "4.17.0",
        },
        "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
        "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
        # Implicit string concatenation: one long S3 key split for line length.
        "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
        "text2text-flan-t5-xxl-fp16.tar.gz",
        "hosting_prepacked_artifact_version": "1.0.0",
        "inference_vulnerable": False,
        "inference_dependencies": [
            "accelerate==0.16.0",
            "bitsandbytes==0.37.0",
            "filelock==3.9.0",
            "huggingface-hub==0.12.0",
            "regex==2022.7.9",
            "tokenizers==0.13.2",
            "transformers==4.26.0",
        ],
        "inference_vulnerabilities": [],
        "training_vulnerable": False,
        "training_dependencies": [],
        "training_vulnerabilities": [],
        "deprecated": False,
        "inference_environment_variables": [
            {
                "name": "SAGEMAKER_PROGRAM",
                "type": "text",
                "default": "inference.py",
                "scope": "container",
            },
            {
                "name": "SAGEMAKER_SUBMIT_DIRECTORY",
                "type": "text",
                "default": "/opt/ml/model/code",
                "scope": "container",
            },
            {
                "name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
                "type": "text",
                "default": "20",
                "scope": "container",
            },
            {
                "name": "MODEL_CACHE_ROOT",
                "type": "text",
                "default": "/opt/ml/model",
                "scope": "container",
            },
            {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
            {
                "name": "SAGEMAKER_MODEL_SERVER_WORKERS",
                "type": "text",
                "default": "1",
                "scope": "container",
            },
            {
                "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
                "type": "text",
                "default": "3600",
                "scope": "container",
            },
        ],
        "metrics": [],
        "default_inference_instance_type": "ml.g5.12xlarge",
        "supported_inference_instance_types": [
            "ml.g5.12xlarge",
            "ml.g5.24xlarge",
            "ml.p3.8xlarge",
            "ml.p3.16xlarge",
            "ml.g4dn.12xlarge",
        ],
    }
}
99+
15100
PROTOTYPICAL_MODEL_SPECS_DICT = {
16101
"pytorch-eqa-bert-base-cased": {
17102
"model_id": "pytorch-eqa-bert-base-cased",
@@ -1093,6 +1178,7 @@
10931178
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
10941179
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
10951180
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
1181+
"hosting_prepacked_artifact_key": None,
10961182
"hyperparameters": [
10971183
{
10981184
"name": "epochs",

tests/unit/sagemaker/jumpstart/utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BASE_MANIFEST,
3030
BASE_SPEC,
3131
BASE_HEADER,
32+
SPECIAL_MODEL_SPECS_DICT,
3233
)
3334

3435

@@ -92,6 +93,18 @@ def get_prototype_model_spec(
9293
return specs
9394

9495

96+
def get_special_model_spec(
    region: str = None, model_id: str = None, version: str = None
) -> JumpStartModelSpecs:
    """Return the special-case model spec registered under ``model_id``.

    Mock replacement for the cache accessor's ``get_model_specs``: the lookup
    uses only the model ID, while ``region`` and ``version`` are accepted (to
    match the real accessor's call signature) and ignored. Reserved for
    special specs kept in ``SPECIAL_MODEL_SPECS_DICT``.

    Raises:
        KeyError: if ``model_id`` is not a key of ``SPECIAL_MODEL_SPECS_DICT``.
    """
    return JumpStartModelSpecs(SPECIAL_MODEL_SPECS_DICT[model_id])
106+
107+
95108
def get_spec_from_base_spec(
96109
_obj: JumpStartModelsCache = None,
97110
region: str = None,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_combined_artifacts(patched_get_model_specs):
    """model_uris.retrieve prefers the prepacked (combined) inference artifact."""
    patched_get_model_specs.side_effect = get_special_model_spec

    combined_artifact_model_id = "huggingface-text2text-flan-t5-xxl-fp16"

    retrieved_uri = model_uris.retrieve(
        region="us-west-2",
        model_scope="inference",
        model_id=combined_artifact_model_id,
        model_version="*",
    )

    expected_uri = (
        "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
        "prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
    )
    assert retrieved_uri == expected_uri

0 commit comments

Comments
 (0)