Skip to content

Commit e577dba

Browse files
Narrohagmollyheamazon
authored and committed
feat: add integ tests for training JumpStart models in private hub (aws#5076)
* feat: add integ tests for training JumpStart models in private hub * fixed formatting * remove unused imports * fix unused imports * fix unit test failure and fix bug around versioning * fix formatting * fix unit tests * fix model_uri usage issue * fix some formatting * separate private hub setup code * add try catch block * fix flake8 issue so except clause is not bare * black formatting
1 parent 6a64c1f commit e577dba

File tree

11 files changed

+285
-13
lines changed

11 files changed

+285
-13
lines changed

src/sagemaker/jumpstart/factory/estimator.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
JUMPSTART_LOGGER,
5757
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5858
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59+
JUMPSTART_MODEL_HUB_NAME,
5960
)
6061
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6162
from sagemaker.jumpstart.factory import model
@@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs(
313314
):
314315
"""Adds HubAccessConfig to kwargs inputs"""
315316

317+
dataset_uri = kwargs.specs.default_training_dataset_uri
316318
if isinstance(kwargs.inputs, str):
317-
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
319+
if dataset_uri is not None and dataset_uri == kwargs.inputs:
320+
kwargs.inputs = TrainingInput(
321+
s3_data=kwargs.inputs, hub_access_config=hub_access_config
322+
)
318323
elif isinstance(kwargs.inputs, TrainingInput):
319-
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
324+
if (
325+
dataset_uri is not None
326+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
327+
):
328+
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
320329
elif isinstance(kwargs.inputs, dict):
321330
for k, v in kwargs.inputs.items():
322331
if isinstance(v, str):
323-
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
332+
training_input = TrainingInput(s3_data=v)
333+
if dataset_uri is not None and dataset_uri == v:
334+
training_input.add_hub_access_config(hub_access_config=hub_access_config)
335+
kwargs.inputs[k] = training_input
324336
elif isinstance(kwargs.inputs, TrainingInput):
325-
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
337+
if (
338+
dataset_uri is not None
339+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
340+
):
341+
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
326342

327343
return kwargs
328344

@@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs(
616632

617633
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
618634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
619-
620-
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
635+
# hub_arn is by default None unless the user specifies the hub_name
636+
# If no hub_name is specified, it is assumed the public hub
637+
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
638+
if (
639+
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
640+
or is_private_hub
641+
):
621642
default_model_uri = model_uris.retrieve(
622643
model_scope=JumpStartScriptScope.TRAINING,
623644
instance_type=kwargs.instance_type,

src/sagemaker/jumpstart/hub/interfaces.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
630630
if json_obj.get("ValidationSupported")
631631
else None
632632
)
633-
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
634633
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
635634
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
636635
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
671670
)
672671

673672
if self.training_supported:
673+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
674+
"DefaultTrainingDatasetUri"
675+
)
674676
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
675677
"TrainingModelPackageArtifactUri"
676678
)

src/sagemaker/jumpstart/hub/parsers.py

+6
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
279279
specs["training_instance_type_variants"] = (
280280
hub_model_document.training_instance_type_variants
281281
)
282+
if hub_model_document.default_training_dataset_uri:
283+
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
284+
hub_model_document.default_training_dataset_uri
285+
)
286+
specs["default_training_dataset_key"] = default_training_dataset_key
287+
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
282288
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)

src/sagemaker/jumpstart/hub/utils.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
25+
from packaging import version
2526

2627
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
2728

@@ -219,9 +220,12 @@ def get_hub_model_version(
219220
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
220221

221222
try:
222-
hub_content_summaries = sagemaker_session.list_hub_content_versions(
223-
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
224-
).get("HubContentSummaries")
223+
hub_content_summaries = _list_hub_content_versions_helper(
224+
hub_name=hub_name,
225+
hub_content_name=hub_model_name,
226+
hub_content_type=hub_model_type,
227+
sagemaker_session=sagemaker_session,
228+
)
225229
except Exception as ex:
226230
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
227231

@@ -238,13 +242,34 @@ def get_hub_model_version(
238242
raise
239243

240244

245+
def _list_hub_content_versions_helper(
    hub_name, hub_content_name, hub_content_type, sagemaker_session
):
    """Collect the hub content summaries from every page of a version listing.

    Calls ``sagemaker_session.list_hub_content_versions`` repeatedly, feeding
    the ``NextToken`` of each response back in as ``next_token``, until a
    response carries no (non-empty) token.

    Args:
        hub_name: Forwarded as the ``hub_name`` keyword argument.
        hub_content_name: Forwarded as the ``hub_content_name`` keyword argument.
        hub_content_type: Forwarded as the ``hub_content_type`` keyword argument.
        sagemaker_session: Object exposing ``list_hub_content_versions(**kwargs)``
            that returns a dict-like response.

    Returns:
        list: The ``HubContentSummaries`` entries of all pages, in page order.
    """
    request_kwargs = {
        "hub_name": hub_name,
        "hub_content_name": hub_content_name,
        "hub_content_type": hub_content_type,
    }
    all_hub_content_summaries = []

    while True:
        response = sagemaker_session.list_hub_content_versions(**request_kwargs)

        # A page may omit the key (or carry None); extend(None) would raise TypeError.
        all_hub_content_summaries.extend(response.get("HubContentSummaries") or [])

        # Treat a missing, None or empty token as "no more pages" so that a
        # null token can never be fed back into the next request.
        next_token = response.get("NextToken")
        if not next_token:
            return all_hub_content_summaries

        # Only follow-up requests carry next_token; the first call omits it.
        request_kwargs["next_token"] = next_token
264+
265+
241266
def _get_hub_model_version_for_open_weight_version(
242267
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
243268
) -> str:
244269
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
245270

246271
if hub_model_version == "*" or hub_model_version is None:
247-
return str(max(available_model_versions))
272+
return str(max(version.parse(v) for v in available_model_versions))
248273

249274
try:
250275
spec = SpecifierSet(f"=={hub_model_version}")

src/sagemaker/jumpstart/types.py

+8
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12791279
"hosting_neuron_model_version",
12801280
"hub_content_type",
12811281
"_is_hub_content",
1282+
"default_training_dataset_key",
1283+
"default_training_dataset_uri",
12821284
]
12831285

12841286
_non_serializable_slots = ["_is_hub_content"]
@@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
14621464
else None
14631465
)
14641466
self.model_subscription_link = json_obj.get("model_subscription_link")
1467+
self.default_training_dataset_key: Optional[str] = json_obj.get(
1468+
"default_training_dataset_key"
1469+
)
1470+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
1471+
"default_training_dataset_uri"
1472+
)
14651473

14661474
def to_json(self) -> Dict[str, Any]:
14671475
"""Returns json representation of JumpStartMetadataBaseFields object."""

tests/integ/sagemaker/jumpstart/constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4949
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
50-
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
50+
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
5151
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5252
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
5353
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import time
17+
18+
import pytest
19+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
20+
from sagemaker.jumpstart.hub.hub import Hub
21+
22+
from sagemaker.jumpstart.estimator import JumpStartEstimator
23+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
25+
from tests.integ.sagemaker.jumpstart.constants import (
26+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
27+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
28+
JUMPSTART_TAG,
29+
)
30+
from tests.integ.sagemaker.jumpstart.utils import (
31+
get_public_hub_model_arn,
32+
get_sm_session,
33+
with_exponential_backoff,
34+
get_training_dataset_for_model_and_version,
35+
)
36+
37+
MAX_INIT_TIME_SECONDS = 5
38+
39+
TEST_MODEL_IDS = {
40+
"huggingface-spc-bert-base-cased",
41+
"meta-textgeneration-llama-2-7b",
42+
"catboost-regression-model",
43+
}
44+
45+
46+
@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
    """Create a model reference in the hub on a best-effort basis.

    Any exception raised by ``hub_instance.create_model_reference`` is still
    swallowed so the caller keeps going, but it is now logged instead of being
    discarded silently, which keeps real failures diagnosable from test output.

    Args:
        hub_instance: Object exposing ``create_model_reference(model_arn=...)``.
        model_arn: Passed through unchanged as the ``model_arn`` keyword argument.
    """
    import logging  # local import so this block does not depend on file-level imports

    try:
        hub_instance.create_model_reference(model_arn=model_arn)
    except Exception as error:  # pylint: disable=broad-except
        # NOTE(review): presumably the usual failure is "reference already exists" — confirm.
        logging.getLogger(__name__).warning(
            "Ignoring failure to create model reference for %s: %s", model_arn, error
        )
52+
53+
54+
@pytest.fixture(scope="session")
def add_model_references():
    """Session fixture: create a model reference for every id in TEST_MODEL_IDS.

    The hub is the one named by the ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME
    environment variable; each model ARN is resolved with
    ``get_public_hub_model_arn`` before the reference is created.
    """
    test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    session = get_sm_session()
    test_hub = Hub(hub_name=test_hub_name, sagemaker_session=session)

    for model_id in TEST_MODEL_IDS:
        create_model_reference(test_hub, get_public_hub_model_arn(test_hub, model_id))
63+
64+
65+
def test_jumpstart_hub_estimator(setup, add_model_references):
    """Fit, re-attach, deploy and invoke a hub-backed estimator end to end."""
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "*"

    hub_estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    hub_estimator.fit(inputs={"training": f"s3://{content_bucket}/{dataset_key}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    attached_estimator = JumpStartEstimator.attach(
        training_job_name=hub_estimator.latest_training_job.name,
        model_id=model_id,
        model_version=model_version,
    )

    # uses ml.p3.2xlarge instance
    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = attached_estimator.deploy(tags=endpoint_tags)

    prediction = predictor.predict(["hello", "world"])

    assert prediction is not None
96+
97+
98+
def test_jumpstart_hub_estimator_with_session(setup, add_model_references):
    """Same flow as the plain hub estimator test, with an explicit session and role."""
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "*"

    session = get_sm_session()

    hub_estimator = JumpStartEstimator(
        model_id=model_id,
        role=session.get_caller_identity_arn(),
        sagemaker_session=session,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    hub_estimator.fit(inputs={"training": f"s3://{content_bucket}/{dataset_key}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    attached_estimator = JumpStartEstimator.attach(
        training_job_name=hub_estimator.latest_training_job.name,
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=get_sm_session(),
    )

    # uses ml.p3.2xlarge instance
    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = attached_estimator.deploy(
        tags=endpoint_tags,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    prediction = predictor.predict(["hello", "world"])

    assert prediction is not None
137+
138+
139+
def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
    """Train a gated model with ``accept_eula=True``, then deploy and invoke it."""
    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    gated_estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    gated_estimator.fit(
        accept_eula=True,
        inputs={"training": f"s3://{content_bucket}/{dataset_key}"},
    )

    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = gated_estimator.deploy(
        tags=endpoint_tags,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    generation_parameters = {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}
    payload = {"inputs": "some-payload", "parameters": generation_parameters}

    prediction = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert prediction is not None
171+
172+
173+
def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):
    """Expect ``fit`` on a gated model to raise when ``accept_eula`` is not passed."""
    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    gated_estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    with pytest.raises(Exception):
        # The input URI is built inside the context manager, exactly where the
        # original evaluated it, so any error raised here is captured as well.
        content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
        dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
        gated_estimator.fit(inputs={"training": f"s3://{content_bucket}/{dataset_key}"})
189+
190+
191+
def test_instantiating_estimator(setup, add_model_references):
    """Check that constructing a hub estimator takes at most MAX_INIT_TIME_SECONDS."""
    model_id = "catboost-regression-model"

    started_at = time.perf_counter()
    JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )
    init_duration = time.perf_counter() - started_at

    assert init_duration <= MAX_INIT_TIME_SECONDS

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@
4848

4949
@with_exponential_backoff()
5050
def create_model_reference(hub_instance, model_arn):
51-
hub_instance.create_model_reference(model_arn=model_arn)
51+
try:
52+
hub_instance.create_model_reference(model_arn=model_arn)
53+
except Exception:
54+
pass
5255

5356

5457
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)