Skip to content

Commit ae5ae1e

Browse files
author
Anton Repushko
committed
Add unit tests for all AutoMLV2 problem types
1 parent 240579b commit ae5ae1e

File tree

2 files changed

+256
-5
lines changed

2 files changed

+256
-5
lines changed

src/sagemaker/automl/automlv2.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ def from_response_dict(cls, api_problem_type_config: dict):
289289
text_generation_hyper_params=api_problem_type_config.get(
290290
"TextGenerationHyperParameters"
291291
),
292-
accept_eula=api_problem_type_config.get("AcceptEula", None),
292+
accept_eula=api_problem_type_config.get(
293+
"ModelAccessConfig", {}).get("AcceptEula", None),
293294
)
294295

295296
def to_request_dict(self):
@@ -309,7 +310,7 @@ def to_request_dict(self):
309310
] = self.max_runtime_per_training_job_in_seconds
310311
if self.max_total_job_runtime_in_seconds is not None:
311312
config["CompletionCriteria"][
312-
"MaxTotalRuntimeInSeconds"
313+
"MaxAutoMLJobRuntimeInSeconds"
313314
] = self.max_total_job_runtime_in_seconds
314315

315316
if self.base_model_name is not None:
@@ -427,7 +428,7 @@ def from_response_dict(cls, api_problem_type_config: dict):
427428
grouping_attribute_names=api_problem_type_config.get("TimeSeriesConfig", {}).get(
428429
"GroupingAttributeNames"
429430
),
430-
holiday_config=api_problem_type_config.get("HolidayConfig", {}).get("CountryCode"),
431+
holiday_config=api_problem_type_config.get("HolidayConfig", [{}])[0].get("CountryCode"),
431432
)
432433

433434
def to_request_dict(self):
@@ -462,8 +463,7 @@ def to_request_dict(self):
462463

463464
if self.holiday_config:
464465
config["HolidayConfig"] = []
465-
for country in self.holiday_config:
466-
config["HolidayConfig"].append({"CountryCode": country})
466+
config["HolidayConfig"].append({"CountryCode": self.holiday_config})
467467

468468
if self.aggregation or self.filling:
469469
config["Transformations"] = {}
+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
25+
from __future__ import absolute_import
26+
27+
from sagemaker import (
28+
AutoMLTabularConfig,
29+
AutoMLImageClassificationConfig,
30+
AutoMLTextGenerationConfig,
31+
AutoMLTextClassificationConfig,
32+
AutoMLTimeSeriesForecastingConfig,
33+
)
34+
35+
# Common params shared by every problem-type fixture below.
MAX_CANDIDATES = 10
MAX_RUNTIME_PER_TRAINING_JOB = 3600
TOTAL_JOB_RUNTIME = 36000
BUCKET_NAME = "mybucket"
FEATURE_SPECIFICATION_S3_URI = "s3://{}/features.json".format(BUCKET_NAME)

# Tabular params
# "AutoMLAlgorithms" is an array of algorithm names in the SageMaker API,
# so the fixture uses a list rather than a bare string.
AUTO_ML_TABULAR_ALGORITHMS = ["xgboost"]
MODE = "ENSEMBLING"
GENERATE_CANDIDATE_DEFINITIONS_ONLY = True
PROBLEM_TYPE = "BinaryClassification"
TARGET_ATTRIBUTE_NAME = "target"
SAMPLE_WEIGHT_ATTRIBUTE_NAME = "sampleWeight"

# Expected API representation of a tabular problem config.
TABULAR_PROBLEM_CONFIG = {
    "CompletionCriteria": {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    },
    "CandidateGenerationConfig": {
        "AlgorithmsConfig": [{"AutoMLAlgorithms": AUTO_ML_TABULAR_ALGORITHMS}],
    },
    "FeatureSpecificationS3Uri": FEATURE_SPECIFICATION_S3_URI,
    "Mode": MODE,
    "GenerateCandidateDefinitionsOnly": GENERATE_CANDIDATE_DEFINITIONS_ONLY,
    "ProblemType": PROBLEM_TYPE,
    "TargetAttributeName": TARGET_ATTRIBUTE_NAME,
    "SampleWeightAttributeName": SAMPLE_WEIGHT_ATTRIBUTE_NAME,
}

# Image classification params
# Only the completion criteria are configured for this problem type.
IMAGE_CLASSIFICATION_PROBLEM_CONFIG = {
    "CompletionCriteria": {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    },
}

# Text classification params
# CONTEXT_COLUMN holds the value used for the "ContentColumn" field.
CONTEXT_COLUMN = "text"
TARGET_LABEL_COLUMN = "class"

TEXT_CLASSIFICATION_PROBLEM_CONFIG = {
    "CompletionCriteria": {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    },
    "ContentColumn": CONTEXT_COLUMN,
    "TargetLabelColumn": TARGET_LABEL_COLUMN,
}

# Text generation params
BASE_MODEL_NAME = "base_model"
TEXT_GENERATION_HYPER_PARAMS = {"test": 1}
ACCEPT_EULA = True

# "AcceptEula" is nested under "ModelAccessConfig" rather than at the top level.
TEXT_GENERATION_PROBLEM_CONFIG = {
    "CompletionCriteria": {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    },
    "BaseModelName": BASE_MODEL_NAME,
    "TextGenerationHyperParameters": TEXT_GENERATION_HYPER_PARAMS,
    "ModelAccessConfig": {
        "AcceptEula": ACCEPT_EULA,
    },
}

# Time series forecasting params
FORECAST_FREQUENCY = "1D"
FORECAST_HORIZON = 5
ITEM_IDENTIFIER_ATTRIBUTE_NAME = "identifier_attribute"
TIMESTAMP_ATTRIBUTE_NAME = "timestamp_attribute"
FORECAST_QUANTILES = ["p1"]
HOLIDAY_CONFIG = "DE"

# "HolidayConfig" is serialized as a list of {"CountryCode": ...} entries;
# the fixture carries a single country code.
TIME_SERIES_FORECASTING_PROBLEM_CONFIG = {
    "CompletionCriteria": {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    },
    "FeatureSpecificationS3Uri": FEATURE_SPECIFICATION_S3_URI,
    "ForecastFrequency": FORECAST_FREQUENCY,
    "ForecastHorizon": FORECAST_HORIZON,
    "TimeSeriesConfig": {
        "ItemIdentifierAttributeName": ITEM_IDENTIFIER_ATTRIBUTE_NAME,
        "TargetAttributeName": TARGET_ATTRIBUTE_NAME,
        "TimestampAttributeName": TIMESTAMP_ATTRIBUTE_NAME,
    },
    "ForecastQuantiles": FORECAST_QUANTILES,
    "HolidayConfig": [
        {
            "CountryCode": HOLIDAY_CONFIG,
        }
    ],
}
137+
138+
def test_tabular_problem_config_from_response():
    """Parse TABULAR_PROBLEM_CONFIG and check every attribute of the resulting config."""
    parsed = AutoMLTabularConfig.from_response_dict(TABULAR_PROBLEM_CONFIG)

    # (attribute name, expected value) pairs, checked in order.
    expectations = (
        ("algorithms_config", AUTO_ML_TABULAR_ALGORITHMS),
        ("feature_specification_s3_uri", FEATURE_SPECIFICATION_S3_URI),
        ("generate_candidate_definitions_only", GENERATE_CANDIDATE_DEFINITIONS_ONLY),
        ("max_candidates", MAX_CANDIDATES),
        ("max_runtime_per_training_job_in_seconds", MAX_RUNTIME_PER_TRAINING_JOB),
        ("max_total_job_runtime_in_seconds", TOTAL_JOB_RUNTIME),
        ("mode", MODE),
        ("problem_type", PROBLEM_TYPE),
        ("sample_weight_attribute_name", SAMPLE_WEIGHT_ATTRIBUTE_NAME),
        ("target_attribute_name", TARGET_ATTRIBUTE_NAME),
    )
    for attribute, expected in expectations:
        assert getattr(parsed, attribute) == expected, attribute
150+
151+
def test_tabular_problem_config_to_request():
    """Build a tabular config and compare its request dict with TABULAR_PROBLEM_CONFIG."""
    constructor_kwargs = dict(
        target_attribute_name=TARGET_ATTRIBUTE_NAME,
        algorithms_config=AUTO_ML_TABULAR_ALGORITHMS,
        feature_specification_s3_uri=FEATURE_SPECIFICATION_S3_URI,
        generate_candidate_definitions_only=GENERATE_CANDIDATE_DEFINITIONS_ONLY,
        mode=MODE,
        problem_type=PROBLEM_TYPE,
        sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )
    request = AutoMLTabularConfig(**constructor_kwargs).to_request_dict()

    assert request["TabularJobConfig"] == TABULAR_PROBLEM_CONFIG
166+
167+
def test_image_classification_problem_config_from_response():
    """Parse IMAGE_CLASSIFICATION_PROBLEM_CONFIG and check the completion criteria fields."""
    parsed = AutoMLImageClassificationConfig.from_response_dict(
        IMAGE_CLASSIFICATION_PROBLEM_CONFIG
    )

    # (attribute name, expected value) pairs, checked in order.
    expectations = (
        ("max_candidates", MAX_CANDIDATES),
        ("max_runtime_per_training_job_in_seconds", MAX_RUNTIME_PER_TRAINING_JOB),
        ("max_total_job_runtime_in_seconds", TOTAL_JOB_RUNTIME),
    )
    for attribute, expected in expectations:
        assert getattr(parsed, attribute) == expected, attribute
172+
173+
def test_image_classification_problem_config_to_request():
    """Build an image classification config and compare its request dict with the fixture."""
    constructor_kwargs = dict(
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )
    request = AutoMLImageClassificationConfig(**constructor_kwargs).to_request_dict()

    assert request["ImageClassificationJobConfig"] == IMAGE_CLASSIFICATION_PROBLEM_CONFIG
181+
182+
def test_text_classification_problem_config_from_response():
    """Parse TEXT_CLASSIFICATION_PROBLEM_CONFIG and check every attribute of the result."""
    parsed = AutoMLTextClassificationConfig.from_response_dict(
        TEXT_CLASSIFICATION_PROBLEM_CONFIG
    )

    # (attribute name, expected value) pairs, checked in order.
    expectations = (
        ("content_column", CONTEXT_COLUMN),
        ("target_label_column", TARGET_LABEL_COLUMN),
        ("max_candidates", MAX_CANDIDATES),
        ("max_runtime_per_training_job_in_seconds", MAX_RUNTIME_PER_TRAINING_JOB),
        ("max_total_job_runtime_in_seconds", TOTAL_JOB_RUNTIME),
    )
    for attribute, expected in expectations:
        assert getattr(parsed, attribute) == expected, attribute
189+
190+
def test_text_classification_to_request():
    """Build a text classification config and compare its request dict with the fixture."""
    constructor_kwargs = dict(
        content_column=CONTEXT_COLUMN,
        target_label_column=TARGET_LABEL_COLUMN,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )
    request = AutoMLTextClassificationConfig(**constructor_kwargs).to_request_dict()

    assert request["TextClassificationJobConfig"] == TEXT_CLASSIFICATION_PROBLEM_CONFIG
200+
201+
def test_text_generation_problem_config_from_response():
    """Parse TEXT_GENERATION_PROBLEM_CONFIG and check every attribute of the result."""
    parsed = AutoMLTextGenerationConfig.from_response_dict(TEXT_GENERATION_PROBLEM_CONFIG)

    # (attribute name, expected value) pairs, checked in order.
    expectations = (
        ("accept_eula", ACCEPT_EULA),
        ("base_model_name", BASE_MODEL_NAME),
        ("max_candidates", MAX_CANDIDATES),
        ("max_runtime_per_training_job_in_seconds", MAX_RUNTIME_PER_TRAINING_JOB),
        ("max_total_job_runtime_in_seconds", TOTAL_JOB_RUNTIME),
        ("text_generation_hyper_params", TEXT_GENERATION_HYPER_PARAMS),
    )
    for attribute, expected in expectations:
        assert getattr(parsed, attribute) == expected, attribute
209+
210+
def test_text_generation_problem_config_to_request():
    """Build a text generation config and compare its request dict with the fixture."""
    constructor_kwargs = dict(
        accept_eula=ACCEPT_EULA,
        base_model_name=BASE_MODEL_NAME,
        text_generation_hyper_params=TEXT_GENERATION_HYPER_PARAMS,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )
    request = AutoMLTextGenerationConfig(**constructor_kwargs).to_request_dict()

    assert request["TextGenerationJobConfig"] == TEXT_GENERATION_PROBLEM_CONFIG
221+
222+
def test_time_series_forecasting_problem_config_from_response():
    """Parse TIME_SERIES_FORECASTING_PROBLEM_CONFIG and check every attribute of the result."""
    parsed = AutoMLTimeSeriesForecastingConfig.from_response_dict(
        TIME_SERIES_FORECASTING_PROBLEM_CONFIG
    )

    # (attribute name, expected value) pairs, checked in order.
    expectations = (
        ("forecast_frequency", FORECAST_FREQUENCY),
        ("forecast_horizon", FORECAST_HORIZON),
        ("item_identifier_attribute_name", ITEM_IDENTIFIER_ATTRIBUTE_NAME),
        ("target_attribute_name", TARGET_ATTRIBUTE_NAME),
        ("timestamp_attribute_name", TIMESTAMP_ATTRIBUTE_NAME),
        ("max_candidates", MAX_CANDIDATES),
        ("max_runtime_per_training_job_in_seconds", MAX_RUNTIME_PER_TRAINING_JOB),
        ("max_total_job_runtime_in_seconds", TOTAL_JOB_RUNTIME),
        ("forecast_quantiles", FORECAST_QUANTILES),
        ("holiday_config", HOLIDAY_CONFIG),
        ("feature_specification_s3_uri", FEATURE_SPECIFICATION_S3_URI),
    )
    for attribute, expected in expectations:
        assert getattr(parsed, attribute) == expected, attribute
235+
236+
def test_time_series_forecasting_problem_config_to_request():
    """Build a time series forecasting config and compare its request dict with the fixture."""
    constructor_kwargs = dict(
        forecast_frequency=FORECAST_FREQUENCY,
        forecast_horizon=FORECAST_HORIZON,
        item_identifier_attribute_name=ITEM_IDENTIFIER_ATTRIBUTE_NAME,
        target_attribute_name=TARGET_ATTRIBUTE_NAME,
        timestamp_attribute_name=TIMESTAMP_ATTRIBUTE_NAME,
        forecast_quantiles=FORECAST_QUANTILES,
        holiday_config=HOLIDAY_CONFIG,
        feature_specification_s3_uri=FEATURE_SPECIFICATION_S3_URI,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )
    request = AutoMLTimeSeriesForecastingConfig(**constructor_kwargs).to_request_dict()

    assert request["TimeSeriesForecastingJobConfig"] == TIME_SERIES_FORECASTING_PROBLEM_CONFIG

0 commit comments

Comments
 (0)