Skip to content

Commit 3e6ca28

Browse files
authored
Merge branch 'master' into fix/pipeline-variable-kms-key
2 parents 62ce65a + 90a5399 commit 3e6ca28

File tree

4 files changed

+82
-3
lines changed

4 files changed

+82
-3
lines changed

tests/integ/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4545
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4646
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
4747
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
48+
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
4849
}
4950

5051

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

+56
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,36 @@
1313
from __future__ import absolute_import
1414
import os
1515
import time
16+
17+
import pytest
1618
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1719

1820
from sagemaker.jumpstart.estimator import JumpStartEstimator
21+
import tests
1922
from tests.integ.sagemaker.jumpstart.constants import (
2023
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
2124
JUMPSTART_TAG,
2225
)
2326
from tests.integ.sagemaker.jumpstart.utils import (
2427
get_sm_session,
2528
get_training_dataset_for_model_and_version,
29+
x_fail_if_ice,
2630
)
2731

2832
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2933

3034

3135
MAX_INIT_TIME_SECONDS = 5
3236

37+
# Regions in which the gated-model training test is allowed to run; the
# ``skipif`` marker on ``test_gated_model_training`` checks membership here.
GATED_TRAINING_MODEL_SUPPORTED_REGIONS = {
    "ap-southeast-1",
    "ap-southeast-2",
    "eu-west-1",
    "us-east-1",
    "us-east-2",
    "us-west-2",
}
45+
3346

3447
def test_jumpstart_estimator(setup):
3548

@@ -63,6 +76,49 @@ def test_jumpstart_estimator(setup):
6376
assert response is not None
6477

6578

79+
@x_fail_if_ice
@pytest.mark.skipif(
    tests.integ.test_region() not in GATED_TRAINING_MODEL_SUPPORTED_REGIONS,
    reason=f"JumpStart gated training models unavailable in {tests.integ.test_region()}.",
)
def test_gated_model_training(setup):
    """Train, deploy and query the gated ``meta-textgeneration-llama-2-7b`` model.

    The test is skipped when the current test region is not listed in
    ``GATED_TRAINING_MODEL_SUPPORTED_REGIONS`` and is wrapped by the
    ``x_fail_if_ice`` helper imported from the JumpStart test utilities.
    It passes when ``predict`` returns a non-``None`` response.
    """

    # NOTE: ``model_version`` is only used below to look up the training
    # dataset; it is not passed to ``JumpStartEstimator``.
    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

    estimator = JumpStartEstimator(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        # The EULA is accepted explicitly for training via the environment.
        environment={"accept_eula": "true"},
        max_run=259200,  # avoid exceeding resource limits
    )

    # uses ml.g5.12xlarge instance
    # The dataset URI is built from the content bucket for
    # JUMPSTART_DEFAULT_REGION_NAME, not from the current test region.
    estimator.fit(
        {
            "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
            f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
        }
    )

    # uses ml.g5.2xlarge instance
    predictor = estimator.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    # Placeholder prompt; the assertion below only checks that a response
    # comes back, not its content.
    payload = {
        "inputs": "some-payload",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
    }

    # The EULA is accepted again at inference time via custom attributes.
    response = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert response is not None
120+
121+
66122
def test_instatiating_estimator_not_too_slow(setup):
67123

68124
model_id = "xgboost-classification-model"

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@
3333

3434
MAX_INIT_TIME_SECONDS = 5
3535

36-
MODEL_PACKAGE_ARN_SUPPORTED_REGIONS = {"us-west-2", "us-east-1"}
36+
# Regions in which the gated-model inference test is allowed to run; the
# ``skipif`` marker on ``test_model_package_arn_jumpstart_model`` checks
# membership here.
GATED_INFERENCE_MODEL_SUPPORTED_REGIONS = {
    "ap-southeast-1",
    "ap-southeast-2",
    "eu-west-1",
    "us-east-1",
    "us-east-2",
    "us-west-2",
}
3744

3845

3946
def test_non_prepacked_jumpstart_model(setup):
@@ -80,8 +87,8 @@ def test_prepacked_jumpstart_model(setup):
8087

8188

8289
@pytest.mark.skipif(
83-
tests.integ.test_region() not in MODEL_PACKAGE_ARN_SUPPORTED_REGIONS,
84-
reason=f"JumpStart Model Package models unavailable in {tests.integ.test_region()}.",
90+
tests.integ.test_region() not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS,
91+
reason=f"JumpStart gated inference models unavailable in {tests.integ.test_region()}.",
8592
)
8693
def test_model_package_arn_jumpstart_model(setup):
8794

tests/integ/sagemaker/jumpstart/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import functools
1415
import json
1516

1617
import uuid
@@ -19,6 +20,7 @@
1920
import pandas as pd
2021
import os
2122
from botocore.config import Config
23+
import pytest
2224

2325

2426
from tests.integ.sagemaker.jumpstart.constants import (
@@ -50,6 +52,19 @@ def get_training_dataset_for_model_and_version(model_id: str, version: str) -> d
5052
return TRAINING_DATASET_MODEL_DICT[(model_id, version)]
5153

5254

55+
def x_fail_if_ice(func):
    """Decorate ``func`` so capacity errors become an expected failure.

    The returned wrapper forwards all positional and keyword arguments to
    ``func`` and returns its result. If ``func`` raises an ``Exception``
    whose string form contains ``"CapacityError"``, ``pytest.xfail`` is
    called with that message; any other exception is re-raised unchanged.
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:
            message = str(err)
            if "CapacityError" in message:
                pytest.xfail(message)
            # Not a capacity problem: let the original failure surface.
            raise
        else:
            return outcome

    return _guarded
66+
67+
5368
def download_inference_assets():
5469

5570
if not os.path.exists(TMP_DIRECTORY_PATH):

0 commit comments

Comments
 (0)