Skip to content

Commit db599df

Browse files
committed
Merge remote-tracking branch 'origin' into feat/jumpstart-instance-types
2 parents 9bc9629 + 07dd8ba commit db599df

File tree

4 files changed

+144
-5
lines changed

4 files changed

+144
-5
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _retrieve_image_uri(
179179
def _retrieve_model_uri(
180180
model_id: str,
181181
model_version: str,
182-
model_scope: Optional[str],
182+
model_scope: Optional[str] = None,
183183
region: Optional[str] = None,
184184
tolerate_vulnerable_model: bool = False,
185185
tolerate_deprecated_model: bool = False,
@@ -225,7 +225,11 @@ def _retrieve_model_uri(
225225
)
226226

227227
if model_scope == JumpStartScriptScope.INFERENCE:
228-
model_artifact_key = model_specs.hosting_artifact_key
228+
model_artifact_key = (
229+
getattr(model_specs, "hosting_prepacked_artifact_key", None)
230+
or model_specs.hosting_artifact_key
231+
)
232+
229233
elif model_scope == JumpStartScriptScope.TRAINING:
230234
model_artifact_key = model_specs.training_artifact_key
231235

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
297297
"default_training_instance_type",
298298
"supported_training_instance_types",
299299
"metrics",
300+
"hosting_prepacked_artifact_key",
300301
]
301302

302303
def __init__(self, spec: Dict[str, Any]):
@@ -346,6 +347,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
346347
"supported_training_instance_types"
347348
)
348349
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
350+
self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
351+
"hosting_prepacked_artifact_key", None
352+
)
349353

350354
if self.training_supported:
351355
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,110 @@
118118
"scope": "container",
119119
},
120120
],
121+
"default_inference_instance_type": "",
122+
"supported_inference_instance_types": None,
123+
"default_training_instance_type": None,
124+
"supported_training_instance_types": [],
125+
"inference_vulnerable": False,
126+
"inference_dependencies": [],
127+
"inference_vulnerabilities": [],
128+
"training_vulnerable": False,
129+
"training_dependencies": [],
130+
"training_vulnerabilities": [],
131+
"deprecated": False,
132+
"metrics": [],
133+
},
134+
"huggingface-text2text-flan-t5-xxl-fp16": {
135+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
136+
"url": "https://huggingface.co/google/flan-t5-xxl",
137+
"version": "1.0.0",
138+
"min_sdk_version": "2.130.0",
139+
"training_supported": False,
140+
"incremental_training_supported": False,
141+
"hosting_ecr_specs": {
142+
"framework": "pytorch",
143+
"framework_version": "1.12.0",
144+
"py_version": "py38",
145+
"huggingface_transformers_version": "4.17.0",
146+
},
147+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
148+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
149+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
150+
"text2text-flan-t5-xxl-fp16.tar.gz",
151+
"hosting_prepacked_artifact_version": "1.0.0",
152+
"inference_vulnerable": False,
153+
"inference_dependencies": [
154+
"accelerate==0.16.0",
155+
"bitsandbytes==0.37.0",
156+
"filelock==3.9.0",
157+
"huggingface-hub==0.12.0",
158+
"regex==2022.7.9",
159+
"tokenizers==0.13.2",
160+
"transformers==4.26.0",
161+
],
162+
"inference_vulnerabilities": [],
163+
"training_vulnerable": False,
164+
"training_dependencies": [],
165+
"training_vulnerabilities": [],
166+
"deprecated": False,
167+
"inference_environment_variables": [
168+
{
169+
"name": "SAGEMAKER_PROGRAM",
170+
"type": "text",
171+
"default": "inference.py",
172+
"scope": "container",
173+
},
174+
{
175+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
176+
"type": "text",
177+
"default": "/opt/ml/model/code",
178+
"scope": "container",
179+
},
180+
{
181+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
182+
"type": "text",
183+
"default": "20",
184+
"scope": "container",
185+
},
186+
{
187+
"name": "MODEL_CACHE_ROOT",
188+
"type": "text",
189+
"default": "/opt/ml/model",
190+
"scope": "container",
191+
},
192+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
193+
{
194+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
195+
"type": "text",
196+
"default": "1",
197+
"scope": "container",
198+
},
199+
{
200+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
201+
"type": "text",
202+
"default": "3600",
203+
"scope": "container",
204+
},
205+
],
121206
"inference_vulnerable": False,
122207
"inference_dependencies": [],
123208
"inference_vulnerabilities": [],
124209
"training_vulnerable": False,
125210
"training_dependencies": [],
126211
"training_vulnerabilities": [],
127212
"deprecated": False,
128-
"default_inference_instance_type": "",
129-
"supported_inference_instance_types": None,
130213
"default_training_instance_type": None,
131214
"supported_training_instance_types": [],
132-
}
215+
"metrics": [],
216+
"default_inference_instance_type": "ml.g5.12xlarge",
217+
"supported_inference_instance_types": [
218+
"ml.g5.12xlarge",
219+
"ml.g5.24xlarge",
220+
"ml.p3.8xlarge",
221+
"ml.p3.16xlarge",
222+
"ml.g4dn.12xlarge",
223+
],
224+
},
133225
}
134226

135227

@@ -1214,6 +1306,7 @@
12141306
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
12151307
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
12161308
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
1309+
"hosting_prepacked_artifact_key": None,
12171310
"hyperparameters": [
12181311
{
12191312
"name": "epochs",
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_combined_artifacts(patched_get_model_specs):
    """Check the inference model URI resolved for the flan-t5-xxl-fp16 model id.

    ``JumpStartModelsAccessor.get_model_specs`` is patched so that every spec
    lookup is routed through ``get_special_model_spec``. The URI returned by
    ``model_uris.retrieve`` for the inference scope must then equal the
    expected ``prepack`` artifact path in the us-west-2 bucket.
    """
    # Route spec lookups to the test helper instead of the real accessor.
    patched_get_model_specs.side_effect = get_special_model_spec

    expected_uri = (
        "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
        "prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
    )

    retrieve_kwargs = {
        "region": "us-west-2",
        "model_scope": "inference",
        "model_id": "huggingface-text2text-flan-t5-xxl-fp16",
        "model_version": "*",
    }
    resolved_uri = model_uris.retrieve(**retrieve_kwargs)

    assert resolved_uri == expected_uri

0 commit comments

Comments
 (0)