Skip to content

Commit c09435f

Browse files
committed
fix: jumpstart unit tests
1 parent edcfe67 commit c09435f

File tree

3 files changed

+142
-123
lines changed

3 files changed

+142
-123
lines changed

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 125 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,131 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
16+
SPECIAL_MODEL_SPECS_DICT = {
17+
"mock-model-training-prepacked-script-key": {
18+
"model_id": "sklearn-classification-linear",
19+
"url": "https://scikit-learn.org/stable/",
20+
"version": "1.0.0",
21+
"min_sdk_version": "2.68.1",
22+
"training_supported": True,
23+
"incremental_training_supported": False,
24+
"hosting_ecr_specs": {
25+
"framework": "sklearn",
26+
"framework_version": "0.23-1",
27+
"py_version": "py3",
28+
},
29+
"hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz",
30+
"hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz",
31+
"inference_vulnerable": False,
32+
"inference_dependencies": [],
33+
"inference_vulnerabilities": [],
34+
"training_vulnerable": False,
35+
"training_dependencies": [],
36+
"training_vulnerabilities": [],
37+
"deprecated": False,
38+
"hyperparameters": [
39+
{
40+
"name": "tol",
41+
"type": "float",
42+
"default": 0.0001,
43+
"min": 1e-20,
44+
"max": 50,
45+
"scope": "algorithm",
46+
},
47+
{
48+
"name": "penalty",
49+
"type": "text",
50+
"default": "l2",
51+
"options": ["l1", "l2", "elasticnet", "none"],
52+
"scope": "algorithm",
53+
},
54+
{
55+
"name": "alpha",
56+
"type": "float",
57+
"default": 0.0001,
58+
"min": 1e-20,
59+
"max": 999,
60+
"scope": "algorithm",
61+
},
62+
{
63+
"name": "l1_ratio",
64+
"type": "float",
65+
"default": 0.15,
66+
"min": 0,
67+
"max": 1,
68+
"scope": "algorithm",
69+
},
70+
{
71+
"name": "sagemaker_submit_directory",
72+
"type": "text",
73+
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
74+
"scope": "container",
75+
},
76+
{
77+
"name": "sagemaker_program",
78+
"type": "text",
79+
"default": "transfer_learning.py",
80+
"scope": "container",
81+
},
82+
{
83+
"name": "sagemaker_container_log_level",
84+
"type": "text",
85+
"default": "20",
86+
"scope": "container",
87+
},
88+
],
89+
"training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/"
90+
"v1.0.0/sourcedir.tar.gz",
91+
"training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz",
92+
"training_ecr_specs": {
93+
"framework_version": "0.23-1",
94+
"framework": "sklearn",
95+
"py_version": "py3",
96+
},
97+
"training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz",
98+
"inference_environment_variables": [
99+
{
100+
"name": "SAGEMAKER_PROGRAM",
101+
"type": "text",
102+
"default": "inference.py",
103+
"scope": "container",
104+
},
105+
{
106+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
107+
"type": "text",
108+
"default": "/opt/ml/model/code",
109+
"scope": "container",
110+
},
111+
{
112+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
113+
"type": "text",
114+
"default": "20",
115+
"scope": "container",
116+
},
117+
{
118+
"name": "MODEL_CACHE_ROOT",
119+
"type": "text",
120+
"default": "/opt/ml/model",
121+
"scope": "container",
122+
},
123+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
124+
{
125+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
126+
"type": "text",
127+
"default": "1",
128+
"scope": "container",
129+
},
130+
{
131+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
132+
"type": "text",
133+
"default": "3600",
134+
"scope": "container",
135+
},
136+
],
137+
}
138+
}
139+
15140
PROTOTYPICAL_MODEL_SPECS_DICT = {
16141
"pytorch-eqa-bert-base-cased": {
17142
"model_id": "pytorch-eqa-bert-base-cased",
@@ -1070,127 +1195,6 @@
10701195
},
10711196
],
10721197
},
1073-
"mock-model-training-prepacked-script-key": {
1074-
"model_id": "sklearn-classification-linear",
1075-
"url": "https://scikit-learn.org/stable/",
1076-
"version": "1.0.0",
1077-
"min_sdk_version": "2.68.1",
1078-
"training_supported": True,
1079-
"incremental_training_supported": False,
1080-
"hosting_ecr_specs": {
1081-
"framework": "sklearn",
1082-
"framework_version": "0.23-1",
1083-
"py_version": "py3",
1084-
},
1085-
"hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz",
1086-
"hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz",
1087-
"inference_vulnerable": False,
1088-
"inference_dependencies": [],
1089-
"inference_vulnerabilities": [],
1090-
"training_vulnerable": False,
1091-
"training_dependencies": [],
1092-
"training_vulnerabilities": [],
1093-
"deprecated": False,
1094-
"hyperparameters": [
1095-
{
1096-
"name": "tol",
1097-
"type": "float",
1098-
"default": 0.0001,
1099-
"min": 1e-20,
1100-
"max": 50,
1101-
"scope": "algorithm",
1102-
},
1103-
{
1104-
"name": "penalty",
1105-
"type": "text",
1106-
"default": "l2",
1107-
"options": ["l1", "l2", "elasticnet", "none"],
1108-
"scope": "algorithm",
1109-
},
1110-
{
1111-
"name": "alpha",
1112-
"type": "float",
1113-
"default": 0.0001,
1114-
"min": 1e-20,
1115-
"max": 999,
1116-
"scope": "algorithm",
1117-
},
1118-
{
1119-
"name": "l1_ratio",
1120-
"type": "float",
1121-
"default": 0.15,
1122-
"min": 0,
1123-
"max": 1,
1124-
"scope": "algorithm",
1125-
},
1126-
{
1127-
"name": "sagemaker_submit_directory",
1128-
"type": "text",
1129-
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
1130-
"scope": "container",
1131-
},
1132-
{
1133-
"name": "sagemaker_program",
1134-
"type": "text",
1135-
"default": "transfer_learning.py",
1136-
"scope": "container",
1137-
},
1138-
{
1139-
"name": "sagemaker_container_log_level",
1140-
"type": "text",
1141-
"default": "20",
1142-
"scope": "container",
1143-
},
1144-
],
1145-
"training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/"
1146-
"v1.0.0/sourcedir.tar.gz",
1147-
"training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz",
1148-
"training_ecr_specs": {
1149-
"framework_version": "0.23-1",
1150-
"framework": "sklearn",
1151-
"py_version": "py3",
1152-
},
1153-
"training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz",
1154-
"inference_environment_variables": [
1155-
{
1156-
"name": "SAGEMAKER_PROGRAM",
1157-
"type": "text",
1158-
"default": "inference.py",
1159-
"scope": "container",
1160-
},
1161-
{
1162-
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1163-
"type": "text",
1164-
"default": "/opt/ml/model/code",
1165-
"scope": "container",
1166-
},
1167-
{
1168-
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1169-
"type": "text",
1170-
"default": "20",
1171-
"scope": "container",
1172-
},
1173-
{
1174-
"name": "MODEL_CACHE_ROOT",
1175-
"type": "text",
1176-
"default": "/opt/ml/model",
1177-
"scope": "container",
1178-
},
1179-
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1180-
{
1181-
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1182-
"type": "text",
1183-
"default": "1",
1184-
"scope": "container",
1185-
},
1186-
{
1187-
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1188-
"type": "text",
1189-
"default": "3600",
1190-
"scope": "container",
1191-
},
1192-
],
1193-
},
11941198
}
11951199

11961200
BASE_SPEC = {

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BASE_MANIFEST,
3030
BASE_SPEC,
3131
BASE_HEADER,
32+
SPECIAL_MODEL_SPECS_DICT,
3233
)
3334

3435

@@ -92,6 +93,18 @@ def get_prototype_model_spec(
9293
return specs
9394

9495

96+
def get_special_model_spec(
97+
region: str = None, model_id: str = None, version: str = None
98+
) -> JumpStartModelSpecs:
99+
"""This function mocks cache accessor functions. For this mock,
100+
we only retrieve model specs based on the model ID. This is reserved
101+
for special specs.
102+
"""
103+
104+
specs = JumpStartModelSpecs(SPECIAL_MODEL_SPECS_DICT[model_id])
105+
return specs
106+
107+
95108
def get_spec_from_base_spec(
96109
_obj: JumpStartModelsCache = None,
97110
region: str = None,

tests/unit/sagemaker/script_uris/jumpstart/test_combined_script_artifact.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from sagemaker import script_uris
1818
import pytest
1919

20-
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec, get_special_model_spec
2121

2222

2323
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2424
def test_jumpstart_combined_artifacts(patched_get_model_specs):
2525

26-
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
patched_get_model_specs.side_effect = get_special_model_spec
2727

2828
model_id_combined_script_artifact = "mock-model-training-prepacked-script-key"
2929

@@ -48,6 +48,8 @@ def test_jumpstart_combined_artifacts(patched_get_model_specs):
4848
include_training_script=True,
4949
)
5050

51+
patched_get_model_specs.side_effect = get_prototype_model_spec
52+
5153
model_id_combined_script_artifact_unsupported = "xgboost-classification-model"
5254

5355
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)