Skip to content

Commit c3cf139

Browse files
repushkoAnton Repushko
and
Anton Repushko
authored
feat: Add AutoMLV2 support (#4461)
* Add AutoMLV2 support * Improvements of the integration tests --------- Co-authored-by: Anton Repushko <[email protected]>
1 parent 36138da commit c3cf139

32 files changed

+5896
-2
lines changed

doc/api/training/automlv2.rst

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
AutoMLV2
2+
--------
3+
4+
.. automodule:: sagemaker.automl.automlv2
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

doc/api/training/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Training APIs
88
algorithm
99
analytics
1010
automl
11+
automlv2
1112
debugger
1213
estimators
1314
tuner

src/sagemaker/__init__.py

+11
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@
6161

6262
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401
6363
from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401
64+
from sagemaker.automl.automlv2 import ( # noqa: F401
65+
AutoMLV2,
66+
AutoMLJobV2,
67+
LocalAutoMLDataChannel,
68+
AutoMLDataChannel,
69+
AutoMLTimeSeriesForecastingConfig,
70+
AutoMLImageClassificationConfig,
71+
AutoMLTabularConfig,
72+
AutoMLTextClassificationConfig,
73+
AutoMLTextGenerationConfig,
74+
)
6475

6576
from sagemaker.debugger import ProfilerConfig, Profiler # noqa: F401
6677

src/sagemaker/automl/automlv2.py

+1,433
Large diffs are not rendered by default.

src/sagemaker/config/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@
4747
MONITORING_SCHEDULE,
4848
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
4949
AUTO_ML_ROLE_ARN_PATH,
50+
AUTO_ML_V2_ROLE_ARN_PATH,
5051
AUTO_ML_OUTPUT_CONFIG_PATH,
52+
AUTO_ML_V2_OUTPUT_CONFIG_PATH,
5153
AUTO_ML_JOB_CONFIG_PATH,
5254
AUTO_ML_JOB,
55+
AUTO_ML_JOB_V2,
5356
COMPILATION_JOB_ROLE_ARN_PATH,
5457
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
5558
COMPILATION_JOB_VPC_CONFIG_PATH,
@@ -111,9 +114,13 @@
111114
FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH,
112115
FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH,
113116
AUTO_ML_KMS_KEY_ID_PATH,
117+
AUTO_ML_V2_KMS_KEY_ID_PATH,
114118
AUTO_ML_VPC_CONFIG_PATH,
119+
AUTO_ML_V2_VPC_CONFIG_PATH,
115120
AUTO_ML_VOLUME_KMS_KEY_ID_PATH,
121+
AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH,
116122
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
123+
AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH,
117124
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH,
118125
SESSION_DEFAULT_S3_BUCKET_PATH,
119126
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,

src/sagemaker/config/config_schema.py

+38
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
ENDPOINT = "Endpoint"
8484
INFERENCE_COMPONENT = "InferenceComponent"
8585
AUTO_ML_JOB = "AutoMLJob"
86+
AUTO_ML_JOB_V2 = "AutoMLJobV2"
8687
COMPILATION_JOB = "CompilationJob"
8788
CUSTOM_PARAMETERS = "CustomParameters"
8889
PIPELINE = "Pipeline"
@@ -182,14 +183,21 @@ def _simple_path(*args: str):
182183
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID
183184
)
184185
AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG)
186+
AUTO_ML_V2_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG)
185187
AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID)
188+
AUTO_ML_V2_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG, KMS_KEY_ID)
186189
AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path(
187190
SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID
188191
)
192+
AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH = _simple_path(
193+
SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VOLUME_KMS_KEY_ID
194+
)
189195
AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, ROLE_ARN)
196+
AUTO_ML_V2_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, ROLE_ARN)
190197
AUTO_ML_VPC_CONFIG_PATH = _simple_path(
191198
SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG
192199
)
200+
AUTO_ML_V2_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VPC_CONFIG)
193201
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
194202
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
195203
SAGEMAKER,
@@ -362,6 +370,12 @@ def _simple_path(*args: str):
362370
SECURITY_CONFIG,
363371
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
364372
)
373+
AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
374+
SAGEMAKER,
375+
AUTO_ML_JOB_V2,
376+
SECURITY_CONFIG,
377+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
378+
)
365379
PROCESSING_JOB_ENVIRONMENT_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ENVIRONMENT)
366380
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
367381
SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
@@ -947,6 +961,30 @@ def _simple_path(*args: str):
947961
TAGS: {"$ref": "#/definitions/tags"},
948962
},
949963
},
964+
# Auto ML V2
965+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJobV2.html
966+
AUTO_ML_JOB_V2: {
967+
TYPE: OBJECT,
968+
ADDITIONAL_PROPERTIES: False,
969+
PROPERTIES: {
970+
SECURITY_CONFIG: {
971+
TYPE: OBJECT,
972+
ADDITIONAL_PROPERTIES: False,
973+
PROPERTIES: {
974+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"},
975+
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
976+
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
977+
},
978+
},
979+
OUTPUT_DATA_CONFIG: {
980+
TYPE: OBJECT,
981+
ADDITIONAL_PROPERTIES: False,
982+
PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
983+
},
984+
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
985+
TAGS: {"$ref": "#/definitions/tags"},
986+
},
987+
},
950988
# Transform Job
951989
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html
952990
TRANSFORM_JOB: {

src/sagemaker/session.py

+171-2
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@
6565
MONITORING_SCHEDULE,
6666
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
6767
AUTO_ML_ROLE_ARN_PATH,
68+
AUTO_ML_V2_ROLE_ARN_PATH,
6869
AUTO_ML_OUTPUT_CONFIG_PATH,
70+
AUTO_ML_V2_OUTPUT_CONFIG_PATH,
6971
AUTO_ML_JOB_CONFIG_PATH,
7072
AUTO_ML_JOB,
73+
AUTO_ML_JOB_V2,
7174
COMPILATION_JOB_ROLE_ARN_PATH,
7275
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
7376
COMPILATION_JOB_VPC_CONFIG_PATH,
@@ -2570,7 +2573,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
25702573
exceptions.UnexpectedStatusException: If waiting and auto ml job fails.
25712574
"""
25722575

2573-
description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll)
2576+
description = _wait_until(lambda: self.describe_auto_ml_job_v2(job_name), poll)
25742577

25752578
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
25762579
self.boto_session, description, job="AutoML"
@@ -2618,7 +2621,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
26182621
if state == LogState.JOB_COMPLETE:
26192622
state = LogState.COMPLETE
26202623
elif time.time() - last_describe_job_call >= 30:
2621-
description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name)
2624+
description = self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
26222625
last_describe_job_call = time.time()
26232626

26242627
status = description["AutoMLJobStatus"]
@@ -2632,6 +2635,172 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
26322635
if dot:
26332636
print()
26342637

2638+
def create_auto_ml_v2(
2639+
self,
2640+
input_config,
2641+
job_name,
2642+
problem_config,
2643+
output_config,
2644+
job_objective=None,
2645+
model_deploy_config=None,
2646+
data_split_config=None,
2647+
role=None,
2648+
security_config=None,
2649+
tags=None,
2650+
):
2651+
"""Create an Amazon SageMaker AutoMLV2 job.
2652+
2653+
Args:
2654+
input_config (list[dict]): A list of AutoMLDataChannel objects.
2655+
Each channel contains "DataSource" and other optional fields.
2656+
job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
2657+
should have a unique job name.
2658+
problem_config (object): A collection of settings specific
2659+
to the problem type used to configure an AutoML job V2.
2660+
There must be one and only one config of the following type.
2661+
Supported problem types are:
2662+
2663+
- Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig),
2664+
- Tabular (sagemaker.automl.automlv2.TabularJobConfig),
2665+
- Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig),
2666+
- Text Generation (TextGenerationJobConfig),
2667+
- Time Series Forecasting (
2668+
sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig).
2669+
2670+
output_config (dict): The S3 URI where you want to store the training results and
2671+
optional KMS key ID.
2672+
job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
2673+
"MetricName" and "Value".
2674+
model_deploy_config (dict): Specifies how to generate the endpoint name
2675+
for an automatic one-click Autopilot model deployment.
2676+
Contains "AutoGenerateEndpointName" and "EndpointName"
2677+
data_split_config (dict): This structure specifies how to split the data
2678+
into train and validation datasets.
2679+
role (str): The Amazon Resource Name (ARN) of an IAM role that
2680+
Amazon SageMaker can assume to perform tasks on your behalf.
2681+
security_config (dict): The security configuration for traffic encryption
2682+
or Amazon VPC settings.
2683+
tags (Optional[Tags]): A list of dictionaries containing key-value
2684+
pairs.
2685+
"""
2686+
2687+
role = resolve_value_from_config(role, AUTO_ML_V2_ROLE_ARN_PATH, sagemaker_session=self)
2688+
inferred_output_config = update_nested_dictionary_with_values_from_config(
2689+
output_config, AUTO_ML_V2_OUTPUT_CONFIG_PATH, sagemaker_session=self
2690+
)
2691+
2692+
auto_ml_job_v2_request = self._get_auto_ml_request_v2(
2693+
input_config=input_config,
2694+
job_name=job_name,
2695+
problem_config=problem_config,
2696+
output_config=inferred_output_config,
2697+
role=role,
2698+
job_objective=job_objective,
2699+
model_deploy_config=model_deploy_config,
2700+
data_split_config=data_split_config,
2701+
security_config=security_config,
2702+
tags=format_tags(tags),
2703+
)
2704+
2705+
def submit(request):
2706+
logger.info("Creating auto-ml-v2-job with name: %s", job_name)
2707+
logger.debug("auto ml v2 request: %s", json.dumps(request), indent=4)
2708+
print(json.dumps(request))
2709+
self.sagemaker_client.create_auto_ml_job_v2(**request)
2710+
2711+
self._intercept_create_request(
2712+
auto_ml_job_v2_request, submit, self.create_auto_ml_v2.__name__
2713+
)
2714+
2715+
def _get_auto_ml_request_v2(
2716+
self,
2717+
input_config,
2718+
output_config,
2719+
job_name,
2720+
problem_config,
2721+
role,
2722+
job_objective=None,
2723+
model_deploy_config=None,
2724+
data_split_config=None,
2725+
security_config=None,
2726+
tags=None,
2727+
):
2728+
"""Constructs a request compatible for creating an Amazon SageMaker AutoML job.
2729+
2730+
Args:
2731+
input_config (list[dict]): A list of Channel objects. Each channel contains "DataSource"
2732+
and "TargetAttributeName", "CompressionType" and "SampleWeightAttributeName" are
2733+
optional fields.
2734+
output_config (dict): The S3 URI where you want to store the training results and
2735+
optional KMS key ID.
2736+
job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
2737+
should have a unique job name.
2738+
problem_config (object): A collection of settings specific
2739+
to the problem type used to configure an AutoML job V2.
2740+
There must be one and only one config of the following type.
2741+
Supported problem types are:
2742+
2743+
- Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig),
2744+
- Tabular (sagemaker.automl.automlv2.TabularJobConfig),
2745+
- Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig),
2746+
- Text Generation (TextGenerationJobConfig),
2747+
- Time Series Forecasting (
2748+
sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig).
2749+
2750+
role (str): The Amazon Resource Name (ARN) of an IAM role that
2751+
Amazon SageMaker can assume to perform tasks on your behalf.
2752+
job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
2753+
"MetricName" and "Value".
2754+
model_deploy_config (dict): Specifies how to generate the endpoint name
2755+
for an automatic one-click Autopilot model deployment.
2756+
Contains "AutoGenerateEndpointName" and "EndpointName"
2757+
data_split_config (dict): This structure specifies how to split the data
2758+
into train and validation datasets.
2759+
security_config (dict): The security configuration for traffic encryption
2760+
or Amazon VPC settings.
2761+
tags (Optional[Tags]): A list of dictionaries containing key-value
2762+
pairs.
2763+
2764+
Returns:
2765+
Dict: a automl v2 request dict
2766+
"""
2767+
auto_ml_job_v2_request = {
2768+
"AutoMLJobName": job_name,
2769+
"AutoMLJobInputDataConfig": input_config,
2770+
"OutputDataConfig": output_config,
2771+
"AutoMLProblemTypeConfig": problem_config,
2772+
"RoleArn": role,
2773+
}
2774+
if job_objective is not None:
2775+
auto_ml_job_v2_request["AutoMLJobObjective"] = job_objective
2776+
if model_deploy_config is not None:
2777+
auto_ml_job_v2_request["ModelDeployConfig"] = model_deploy_config
2778+
if data_split_config is not None:
2779+
auto_ml_job_v2_request["DataSplitConfig"] = data_split_config
2780+
if security_config is not None:
2781+
auto_ml_job_v2_request["SecurityConfig"] = security_config
2782+
2783+
tags = _append_project_tags(format_tags(tags))
2784+
tags = self._append_sagemaker_config_tags(
2785+
tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB_V2, TAGS)
2786+
)
2787+
if tags is not None:
2788+
auto_ml_job_v2_request["Tags"] = tags
2789+
2790+
return auto_ml_job_v2_request
2791+
2792+
# Done
2793+
def describe_auto_ml_job_v2(self, job_name):
2794+
"""Calls the DescribeAutoMLJobV2 API for the given job name and returns the response.
2795+
2796+
Args:
2797+
job_name (str): The name of the AutoML job to describe.
2798+
2799+
Returns:
2800+
dict: A dictionary response with the AutoMLV2 Job description.
2801+
"""
2802+
return self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
2803+
26352804
def compile_model(
26362805
self,
26372806
input_model_config,

0 commit comments

Comments
 (0)