|
13 | 13 | from __future__ import absolute_import
|
14 | 14 | import os
|
15 | 15 | import time
|
| 16 | + |
| 17 | +import pytest |
16 | 18 | from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
|
17 | 19 |
|
18 | 20 | from sagemaker.jumpstart.estimator import JumpStartEstimator
|
| 21 | +import tests |
19 | 22 | from tests.integ.sagemaker.jumpstart.constants import (
|
20 | 23 | ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
|
21 | 24 | JUMPSTART_TAG,
|
22 | 25 | )
|
23 | 26 | from tests.integ.sagemaker.jumpstart.utils import (
|
24 | 27 | get_sm_session,
|
25 | 28 | get_training_dataset_for_model_and_version,
|
| 29 | + x_fail_if_ice, |
26 | 30 | )
|
27 | 31 |
|
28 | 32 | from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
|
29 | 33 |
|
30 | 34 |
|
# Time budget in seconds; presumably the ceiling enforced by the
# estimator-instantiation timing test in this module -- confirm against its use.
MAX_INIT_TIME_SECONDS = 5

# Regions in which the gated-model training test is allowed to run; the test is
# skipped when the active test region is not listed here. Stored as a frozenset
# so this module-level constant cannot be mutated by accident; membership
# checks (`in` / `not in`) behave exactly as they do for a plain set.
GATED_TRAINING_MODEL_SUPPORTED_REGIONS = frozenset(
    {
        "us-west-2",
        "us-east-1",
        "eu-west-1",
        "ap-southeast-1",
        "us-east-2",
        "ap-southeast-2",
    }
)
| 45 | + |
33 | 46 |
|
34 | 47 | def test_jumpstart_estimator(setup):
|
35 | 48 |
|
@@ -63,6 +76,49 @@ def test_jumpstart_estimator(setup):
|
63 | 76 | assert response is not None
|
64 | 77 |
|
65 | 78 |
|
@x_fail_if_ice
@pytest.mark.skipif(
    tests.integ.test_region() not in GATED_TRAINING_MODEL_SUPPORTED_REGIONS,
    reason=f"JumpStart gated training models unavailable in {tests.integ.test_region()}.",
)
def test_gated_model_training(setup):
    """Train, deploy and query the gated ``meta-textgeneration-llama-2-7b`` model.

    The test is skipped when the active test region is not one of
    ``GATED_TRAINING_MODEL_SUPPORTED_REGIONS``. It accepts the model EULA both
    for training (``environment``) and for inference (``custom_attributes``),
    fits the estimator on the model's training dataset, deploys the result and
    asserts that a prediction request returns a non-``None`` response.

    NOTE(review): ``x_fail_if_ice`` presumably marks the test as an expected
    failure on insufficient-capacity errors -- confirm in the utils module.
    """
    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

    # Resolve the session, execution role and suite tag once and reuse them for
    # both the estimator and the deployment, instead of building a new session
    # (and re-resolving the caller identity) for every keyword argument.
    sagemaker_session = get_sm_session()
    role = sagemaker_session.get_caller_identity_arn()
    suite_tag = {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}

    estimator = JumpStartEstimator(
        model_id=model_id,
        role=role,
        sagemaker_session=sagemaker_session,
        # Fresh list per call, in case the SDK appends to the list it is given.
        tags=[dict(suite_tag)],
        environment={"accept_eula": "true"},
        max_run=259200,  # avoid exceeding resource limits
    )

    training_data_uri = (
        f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
        f"{get_training_dataset_for_model_and_version(model_id, model_version)}"
    )

    # uses ml.g5.12xlarge instance
    estimator.fit({"training": training_data_uri})

    # uses ml.g5.2xlarge instance
    predictor = estimator.deploy(
        tags=[dict(suite_tag)],
        role=role,
        sagemaker_session=sagemaker_session,
    )

    payload = {
        "inputs": "some-payload",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
    }

    response = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert response is not None
| 120 | + |
| 121 | + |
66 | 122 | def test_instatiating_estimator_not_too_slow(setup):
|
67 | 123 |
|
68 | 124 | model_id = "xgboost-classification-model"
|
|
0 commit comments