Skip to content

Commit f6ade25

Browse files
committed
feat: integration tests for jumpstart sdk retrieve functions
1 parent a427d4b commit f6ade25

File tree

11 files changed

+903
-1
lines changed

11 files changed

+903
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ venv/
2727
*.swp
2828
.docker/
2929
env/
30-
.vscode/
30+
.vscode/
31+
**/tmp

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def read_version():
7979
"fabric>=2.0",
8080
"requests>=2.20.0, <3",
8181
"sagemaker-experiments",
82+
"regex",
8283
],
8384
)
8485

tests/integ/sagemaker/jumpstart/__init__.py

Whitespace-only changes.

tests/integ/sagemaker/jumpstart/retrieve_uri/__init__.py

Whitespace-only changes.
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import boto3
16+
import pytest
17+
18+
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
19+
get_test_cache_bucket,
20+
get_test_suite_id,
21+
)
22+
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
23+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
24+
JUMPSTART_TAG,
25+
)
26+
import os
27+
28+
29+
def _setup():
30+
print("Setting up...")
31+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()})
32+
33+
34+
def _teardown():
35+
print("Tearing down...")
36+
37+
test_cache_bucket = get_test_cache_bucket()
38+
39+
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
40+
41+
sagemaker_client = boto3.client("sagemaker")
42+
43+
search_endpoints_result = sagemaker_client.search(
44+
Resource="Endpoint",
45+
SearchExpression={
46+
"Filters": [
47+
{"Name": f"Tags.{JUMPSTART_TAG}", "Operator": "Equals", "Value": test_suite_id}
48+
]
49+
},
50+
)
51+
52+
endpoint_names = [
53+
endpoint_info["Endpoint"]["EndpointName"]
54+
for endpoint_info in search_endpoints_result["Results"]
55+
]
56+
endpoint_config_names = [
57+
endpoint_info["Endpoint"]["EndpointConfigName"]
58+
for endpoint_info in search_endpoints_result["Results"]
59+
]
60+
model_names = [
61+
sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)[
62+
"ProductionVariants"
63+
][0]["ModelName"]
64+
for endpoint_config_name in endpoint_config_names
65+
]
66+
67+
# delete test-suite-tagged endpoints
68+
for endpoint_name in endpoint_names:
69+
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
70+
71+
# delete endpoint configs for test-suite-tagged endpoints
72+
for endpoint_config_name in endpoint_config_names:
73+
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
74+
75+
# delete models for test-suite-tagged endpoints
76+
for model_name in model_names:
77+
sagemaker_client.delete_model(ModelName=model_name)
78+
79+
# delete test artifact/cache s3 folder
80+
s3_resource = boto3.resource("s3")
81+
bucket = s3_resource.Bucket(test_cache_bucket)
82+
bucket.objects.filter(Prefix=test_suite_id + "/").delete()
83+
84+
85+
@pytest.fixture(scope="session", autouse=True)
86+
def setup(request):
87+
_setup()
88+
89+
request.addfinalizer(_teardown)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from enum import Enum
16+
from typing import Dict
17+
from typing import Optional
18+
from typing import Union
19+
import os
20+
21+
22+
def _to_s3_path(filename: str, s3_folder: Optional[str]) -> str:
23+
return filename if not s3_folder else f"{s3_folder}/{filename}"
24+
25+
26+
_NB_ASSETS_S3_FOLDER = "inference-notebook-assets"
27+
_TF_FLOWERS_S3_FOLDER = "training-datasets/tf_flowers"
28+
29+
TMP_DIRECTORY_PATH = os.path.join(
30+
os.path.abspath(os.path.join(os.path.abspath(__file__), os.pardir)), "tmp"
31+
)
32+
33+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID"
34+
35+
JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id"
36+
37+
HYPERPARAMETER_MODEL_DICT = {
38+
("huggingface-spc-bert-base-cased", "*"): {
39+
"epochs": "1",
40+
"adam-learning-rate": "2e-05",
41+
"batch-size": "8",
42+
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
43+
"sagemaker_program": "transfer_learning.py",
44+
"sagemaker_container_log_level": "20",
45+
},
46+
}
47+
48+
TRAINING_DATASET_MODEL_DICT = {
49+
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
50+
}
51+
52+
53+
class ContentType(str, Enum):
54+
"""Possible value for content type argument of SageMakerRuntime.invokeEndpoint."""
55+
56+
X_IMAGE = "application/x-image"
57+
LIST_TEXT = "application/list-text"
58+
X_TEXT = "application/x-text"
59+
TEXT_CSV = "text/csv"
60+
61+
62+
class InferenceImageFilename(str, Enum):
63+
"""Filename of the inference asset in JumpStart distribution buckets."""
64+
65+
DOG = "dog.jpg"
66+
CAT = "cat.jpg"
67+
DAISY = "100080576_f52e8ee070_n.jpg"
68+
DAISY_2 = "10140303196_b88d3d6cec.jpg"
69+
ROSE = "102501987_3cdb8e5394_n.jpg"
70+
NAXOS_TAVERNA = "Naxos_Taverna.jpg"
71+
PEDESTRIAN = "img_pedestrian.png"
72+
73+
74+
class InferenceTabularDataname(str, Enum):
75+
"""Filename of the tabular data example in JumpStart distribution buckets."""
76+
77+
REGRESSION_ONEHOT = "regressonehot_data.csv"
78+
REGRESSION = "regress_data.csv"
79+
MULTICLASS = "multiclass_data.csv"
80+
81+
82+
class ClassLabelFile(str, Enum):
83+
"""Filename in JumpStart distribution buckets for the map of the class index to human readable labels."""
84+
85+
IMAGE_NET = "ImageNetLabels.txt"
86+
87+
88+
TEST_ASSETS_SPECS: Dict[
89+
Union[InferenceImageFilename, InferenceTabularDataname, ClassLabelFile], str
90+
] = {
91+
InferenceImageFilename.DOG: _to_s3_path(InferenceImageFilename.DOG, _NB_ASSETS_S3_FOLDER),
92+
InferenceImageFilename.CAT: _to_s3_path(InferenceImageFilename.CAT, _NB_ASSETS_S3_FOLDER),
93+
InferenceImageFilename.DAISY: _to_s3_path(
94+
InferenceImageFilename.DAISY, f"{_TF_FLOWERS_S3_FOLDER}/daisy"
95+
),
96+
InferenceImageFilename.DAISY_2: _to_s3_path(
97+
InferenceImageFilename.DAISY_2, f"{_TF_FLOWERS_S3_FOLDER}/daisy"
98+
),
99+
InferenceImageFilename.ROSE: _to_s3_path(
100+
InferenceImageFilename.ROSE, f"{_TF_FLOWERS_S3_FOLDER}/roses"
101+
),
102+
InferenceImageFilename.NAXOS_TAVERNA: _to_s3_path(
103+
InferenceImageFilename.NAXOS_TAVERNA, _NB_ASSETS_S3_FOLDER
104+
),
105+
InferenceImageFilename.PEDESTRIAN: _to_s3_path(
106+
InferenceImageFilename.PEDESTRIAN, _NB_ASSETS_S3_FOLDER
107+
),
108+
ClassLabelFile.IMAGE_NET: _to_s3_path(ClassLabelFile.IMAGE_NET, _NB_ASSETS_S3_FOLDER),
109+
InferenceTabularDataname.REGRESSION_ONEHOT: _to_s3_path(
110+
InferenceTabularDataname.REGRESSION_ONEHOT, _NB_ASSETS_S3_FOLDER
111+
),
112+
InferenceTabularDataname.REGRESSION: _to_s3_path(
113+
InferenceTabularDataname.REGRESSION, _NB_ASSETS_S3_FOLDER
114+
),
115+
InferenceTabularDataname.MULTICLASS: _to_s3_path(
116+
InferenceTabularDataname.MULTICLASS, _NB_ASSETS_S3_FOLDER
117+
),
118+
}

0 commit comments

Comments
 (0)