1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
14
+ #
15
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
16
+ # may not use this file except in compliance with the License. A copy of
17
+ # the License is located at
18
+ #
19
+ # http://aws.amazon.com/apache2.0/
20
+ #
21
+ # or in the "license" file accompanying this file. This file is
22
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
23
+ # ANY KIND, either express or implied. See the License for the specific
24
+ # language governing permissions and limitations under the License.
25
+ from __future__ import absolute_import
26
+
27
+ from sagemaker import (
28
+ AutoMLTabularConfig ,
29
+ AutoMLImageClassificationConfig ,
30
+ AutoMLTextGenerationConfig ,
31
+ AutoMLTextClassificationConfig ,
32
+ AutoMLTimeSeriesForecastingConfig ,
33
+ )
34
+
35
+ # Common params
36
+ MAX_CANDIDATES = 10
37
+ MAX_RUNTIME_PER_TRAINING_JOB = 3600
38
+ TOTAL_JOB_RUNTIME = 36000
39
+ BUCKET_NAME = "mybucket"
40
+ FEATURE_SPECIFICATION_S3_URI = "s3://{}/features.json" .format (BUCKET_NAME )
41
+
42
+ # Tabular params
43
+ AUTO_ML_TABULAR_ALGORITHMS = "xgboost"
44
+ MODE = "ENSEMBLING"
45
+ GENERATE_CANDIDATE_DEFINITIONS_ONLY = True
46
+ PROBLEM_TYPE = "BinaryClassification"
47
+ TARGET_ATTRIBUTE_NAME = "target"
48
+ SAMPLE_WEIGHT_ATTRIBUTE_NAME = "sampleWeight"
49
+
50
+ TABULAR_PROBLEM_CONFIG = {
51
+ "CompletionCriteria" : {
52
+ "MaxCandidates" : MAX_CANDIDATES ,
53
+ "MaxRuntimePerTrainingJobInSeconds" : MAX_RUNTIME_PER_TRAINING_JOB ,
54
+ "MaxAutoMLJobRuntimeInSeconds" : TOTAL_JOB_RUNTIME ,
55
+ },
56
+ "CandidateGenerationConfig" : {
57
+ "AlgorithmsConfig" : [{"AutoMLAlgorithms" : AUTO_ML_TABULAR_ALGORITHMS }],
58
+ },
59
+ "FeatureSpecificationS3Uri" : FEATURE_SPECIFICATION_S3_URI ,
60
+ "Mode" : MODE ,
61
+ "GenerateCandidateDefinitionsOnly" : GENERATE_CANDIDATE_DEFINITIONS_ONLY ,
62
+ "ProblemType" : PROBLEM_TYPE ,
63
+ "TargetAttributeName" : TARGET_ATTRIBUTE_NAME ,
64
+ "SampleWeightAttributeName" : SAMPLE_WEIGHT_ATTRIBUTE_NAME ,
65
+ }
66
+
67
+ # Image classification params
68
+
69
+ IMAGE_CLASSIFICATION_PROBLEM_CONFIG = {
70
+ "CompletionCriteria" : {
71
+ "MaxCandidates" : MAX_CANDIDATES ,
72
+ "MaxRuntimePerTrainingJobInSeconds" : MAX_RUNTIME_PER_TRAINING_JOB ,
73
+ "MaxAutoMLJobRuntimeInSeconds" : TOTAL_JOB_RUNTIME ,
74
+ },
75
+ }
76
+
77
+ # Text classification
78
+ CONTEXT_COLUMN = "text"
79
+ TARGET_LABEL_COLUMN = "class"
80
+
81
+ TEXT_CLASSIFICATION_PROBLEM_CONFIG = {
82
+ "CompletionCriteria" : {
83
+ "MaxCandidates" : MAX_CANDIDATES ,
84
+ "MaxRuntimePerTrainingJobInSeconds" : MAX_RUNTIME_PER_TRAINING_JOB ,
85
+ "MaxAutoMLJobRuntimeInSeconds" : TOTAL_JOB_RUNTIME ,
86
+ },
87
+ "ContentColumn" : CONTEXT_COLUMN ,
88
+ "TargetLabelColumn" : TARGET_LABEL_COLUMN ,
89
+ }
90
+
91
+ # Text generation params
92
+ BASE_MODEL_NAME = "base_model"
93
+ TEXT_GENERATION_HYPER_PARAMS = {"test" : 1 }
94
+ ACCEPT_EULA = True
95
+
96
+ TEXT_GENERATION_PROBLEM_CONFIG = {
97
+ "CompletionCriteria" : {
98
+ "MaxCandidates" : MAX_CANDIDATES ,
99
+ "MaxRuntimePerTrainingJobInSeconds" : MAX_RUNTIME_PER_TRAINING_JOB ,
100
+ "MaxAutoMLJobRuntimeInSeconds" : TOTAL_JOB_RUNTIME ,
101
+ },
102
+ "BaseModelName" : BASE_MODEL_NAME ,
103
+ "TextGenerationHyperParameters" : TEXT_GENERATION_HYPER_PARAMS ,
104
+ "ModelAccessConfig" : {
105
+ "AcceptEula" : ACCEPT_EULA ,
106
+ }
107
+ }
108
+
109
+ # Time series forecasting params
110
+ FORECAST_FREQUENCY = "1D"
111
+ FORECAST_HORIZON = 5
112
+ ITEM_IDENTIFIER_ATTRIBUTE_NAME = "identifier_attribute"
113
+ TIMESTAMP_ATTRIBUTE_NAME = "timestamp_attribute"
114
+ FORECAST_QUANTILES = ["p1" ]
115
+ HOLIDAY_CONFIG = "DE"
116
+
117
+
118
+ TIME_SERIES_FORECASTING_PROBLEM_CONFIG = {
119
+ "CompletionCriteria" : {
120
+ "MaxCandidates" : MAX_CANDIDATES ,
121
+ "MaxRuntimePerTrainingJobInSeconds" : MAX_RUNTIME_PER_TRAINING_JOB ,
122
+ "MaxAutoMLJobRuntimeInSeconds" : TOTAL_JOB_RUNTIME ,
123
+ },
124
+ "FeatureSpecificationS3Uri" : FEATURE_SPECIFICATION_S3_URI ,
125
+ "ForecastFrequency" : FORECAST_FREQUENCY ,
126
+ "ForecastHorizon" : FORECAST_HORIZON ,
127
+ "TimeSeriesConfig" : {
128
+ "ItemIdentifierAttributeName" : ITEM_IDENTIFIER_ATTRIBUTE_NAME ,
129
+ "TargetAttributeName" : TARGET_ATTRIBUTE_NAME ,
130
+ "TimestampAttributeName" : TIMESTAMP_ATTRIBUTE_NAME ,
131
+ },
132
+ "ForecastQuantiles" : FORECAST_QUANTILES ,
133
+ "HolidayConfig" : [{
134
+ "CountryCode" : HOLIDAY_CONFIG ,
135
+ }],
136
+ }
137
+
138
+ def test_tabular_problem_config_from_response ():
139
+ problem_config = AutoMLTabularConfig .from_response_dict (TABULAR_PROBLEM_CONFIG )
140
+ assert problem_config .algorithms_config == AUTO_ML_TABULAR_ALGORITHMS
141
+ assert problem_config .feature_specification_s3_uri == FEATURE_SPECIFICATION_S3_URI
142
+ assert problem_config .generate_candidate_definitions_only == GENERATE_CANDIDATE_DEFINITIONS_ONLY
143
+ assert problem_config .max_candidates == MAX_CANDIDATES
144
+ assert problem_config .max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
145
+ assert problem_config .max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
146
+ assert problem_config .mode == MODE
147
+ assert problem_config .problem_type == PROBLEM_TYPE
148
+ assert problem_config .sample_weight_attribute_name == SAMPLE_WEIGHT_ATTRIBUTE_NAME
149
+ assert problem_config .target_attribute_name == TARGET_ATTRIBUTE_NAME
150
+
151
+ def test_tabular_problem_config_to_request ():
152
+ problem_config = AutoMLTabularConfig (
153
+ target_attribute_name = TARGET_ATTRIBUTE_NAME ,
154
+ algorithms_config = AUTO_ML_TABULAR_ALGORITHMS ,
155
+ feature_specification_s3_uri = FEATURE_SPECIFICATION_S3_URI ,
156
+ generate_candidate_definitions_only = GENERATE_CANDIDATE_DEFINITIONS_ONLY ,
157
+ mode = MODE ,
158
+ problem_type = PROBLEM_TYPE ,
159
+ sample_weight_attribute_name = SAMPLE_WEIGHT_ATTRIBUTE_NAME ,
160
+ max_candidates = MAX_CANDIDATES ,
161
+ max_total_job_runtime_in_seconds = TOTAL_JOB_RUNTIME ,
162
+ max_runtime_per_training_job_in_seconds = MAX_RUNTIME_PER_TRAINING_JOB ,
163
+ )
164
+
165
+ assert problem_config .to_request_dict ()["TabularJobConfig" ] == TABULAR_PROBLEM_CONFIG
166
+
167
+ def test_image_classification_problem_config_from_response ():
168
+ problem_config = AutoMLImageClassificationConfig .from_response_dict (IMAGE_CLASSIFICATION_PROBLEM_CONFIG )
169
+ assert problem_config .max_candidates == MAX_CANDIDATES
170
+ assert problem_config .max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
171
+ assert problem_config .max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
172
+
173
+ def test_image_classification_problem_config_to_request ():
174
+ problem_config = AutoMLImageClassificationConfig (
175
+ max_candidates = MAX_CANDIDATES ,
176
+ max_total_job_runtime_in_seconds = TOTAL_JOB_RUNTIME ,
177
+ max_runtime_per_training_job_in_seconds = MAX_RUNTIME_PER_TRAINING_JOB ,
178
+ )
179
+
180
+ assert problem_config .to_request_dict ()["ImageClassificationJobConfig" ] == IMAGE_CLASSIFICATION_PROBLEM_CONFIG
181
+
182
+ def test_text_classification_problem_config_from_response ():
183
+ problem_config = AutoMLTextClassificationConfig .from_response_dict (TEXT_CLASSIFICATION_PROBLEM_CONFIG )
184
+ assert problem_config .content_column == CONTEXT_COLUMN
185
+ assert problem_config .target_label_column == TARGET_LABEL_COLUMN
186
+ assert problem_config .max_candidates == MAX_CANDIDATES
187
+ assert problem_config .max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
188
+ assert problem_config .max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
189
+
190
+ def test_text_classification_to_request ():
191
+ problem_config = AutoMLTextClassificationConfig (
192
+ content_column = CONTEXT_COLUMN ,
193
+ target_label_column = TARGET_LABEL_COLUMN ,
194
+ max_candidates = MAX_CANDIDATES ,
195
+ max_total_job_runtime_in_seconds = TOTAL_JOB_RUNTIME ,
196
+ max_runtime_per_training_job_in_seconds = MAX_RUNTIME_PER_TRAINING_JOB ,
197
+ )
198
+
199
+ assert problem_config .to_request_dict ()["TextClassificationJobConfig" ] == TEXT_CLASSIFICATION_PROBLEM_CONFIG
200
+
201
+ def test_text_generation_problem_config_from_response ():
202
+ problem_config = AutoMLTextGenerationConfig .from_response_dict (TEXT_GENERATION_PROBLEM_CONFIG )
203
+ assert problem_config .accept_eula == ACCEPT_EULA
204
+ assert problem_config .base_model_name == BASE_MODEL_NAME
205
+ assert problem_config .max_candidates == MAX_CANDIDATES
206
+ assert problem_config .max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
207
+ assert problem_config .max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
208
+ assert problem_config .text_generation_hyper_params == TEXT_GENERATION_HYPER_PARAMS
209
+
210
+ def test_text_generation_problem_config_to_request ():
211
+ problem_config = AutoMLTextGenerationConfig (
212
+ accept_eula = ACCEPT_EULA ,
213
+ base_model_name = BASE_MODEL_NAME ,
214
+ text_generation_hyper_params = TEXT_GENERATION_HYPER_PARAMS ,
215
+ max_candidates = MAX_CANDIDATES ,
216
+ max_total_job_runtime_in_seconds = TOTAL_JOB_RUNTIME ,
217
+ max_runtime_per_training_job_in_seconds = MAX_RUNTIME_PER_TRAINING_JOB ,
218
+ )
219
+
220
+ assert problem_config .to_request_dict ()["TextGenerationJobConfig" ] == TEXT_GENERATION_PROBLEM_CONFIG
221
+
222
+ def test_time_series_forecasting_problem_config_from_response ():
223
+ problem_config = AutoMLTimeSeriesForecastingConfig .from_response_dict (TIME_SERIES_FORECASTING_PROBLEM_CONFIG )
224
+ assert problem_config .forecast_frequency == FORECAST_FREQUENCY
225
+ assert problem_config .forecast_horizon == FORECAST_HORIZON
226
+ assert problem_config .item_identifier_attribute_name == ITEM_IDENTIFIER_ATTRIBUTE_NAME
227
+ assert problem_config .target_attribute_name == TARGET_ATTRIBUTE_NAME
228
+ assert problem_config .timestamp_attribute_name == TIMESTAMP_ATTRIBUTE_NAME
229
+ assert problem_config .max_candidates == MAX_CANDIDATES
230
+ assert problem_config .max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
231
+ assert problem_config .max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
232
+ assert problem_config .forecast_quantiles == FORECAST_QUANTILES
233
+ assert problem_config .holiday_config == HOLIDAY_CONFIG
234
+ assert problem_config .feature_specification_s3_uri == FEATURE_SPECIFICATION_S3_URI
235
+
236
+ def test_time_series_forecasting_problem_config_to_request ():
237
+ problem_config = AutoMLTimeSeriesForecastingConfig (
238
+ forecast_frequency = FORECAST_FREQUENCY ,
239
+ forecast_horizon = FORECAST_HORIZON ,
240
+ item_identifier_attribute_name = ITEM_IDENTIFIER_ATTRIBUTE_NAME ,
241
+ target_attribute_name = TARGET_ATTRIBUTE_NAME ,
242
+ timestamp_attribute_name = TIMESTAMP_ATTRIBUTE_NAME ,
243
+ forecast_quantiles = FORECAST_QUANTILES ,
244
+ holiday_config = HOLIDAY_CONFIG ,
245
+ feature_specification_s3_uri = FEATURE_SPECIFICATION_S3_URI ,
246
+ max_candidates = MAX_CANDIDATES ,
247
+ max_total_job_runtime_in_seconds = TOTAL_JOB_RUNTIME ,
248
+ max_runtime_per_training_job_in_seconds = MAX_RUNTIME_PER_TRAINING_JOB ,
249
+ )
250
+
251
+ assert problem_config .to_request_dict ()["TimeSeriesForecastingJobConfig" ] == TIME_SERIES_FORECASTING_PROBLEM_CONFIG
0 commit comments