Skip to content

feat: Add AutoMLV2 support #4461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/api/training/automlv2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
AutoMLV2
--------

.. automodule:: sagemaker.automl.automlv2
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions doc/api/training/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Training APIs
algorithm
analytics
automl
automlv2
debugger
estimators
tuner
Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@

from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401
from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401
from sagemaker.automl.automlv2 import ( # noqa: F401
AutoMLV2,
AutoMLJobV2,
LocalAutoMLDataChannel,
AutoMLDataChannel,
AutoMLTimeSeriesForecastingConfig,
AutoMLImageClassificationConfig,
AutoMLTabularConfig,
AutoMLTextClassificationConfig,
AutoMLTextGenerationConfig,
)

from sagemaker.debugger import ProfilerConfig, Profiler # noqa: F401

Expand Down
1,433 changes: 1,433 additions & 0 deletions src/sagemaker/automl/automlv2.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/sagemaker/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@
MONITORING_SCHEDULE,
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
AUTO_ML_ROLE_ARN_PATH,
AUTO_ML_V2_ROLE_ARN_PATH,
AUTO_ML_OUTPUT_CONFIG_PATH,
AUTO_ML_V2_OUTPUT_CONFIG_PATH,
AUTO_ML_JOB_CONFIG_PATH,
AUTO_ML_JOB,
AUTO_ML_JOB_V2,
COMPILATION_JOB_ROLE_ARN_PATH,
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
COMPILATION_JOB_VPC_CONFIG_PATH,
Expand Down Expand Up @@ -111,9 +114,13 @@
FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH,
FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH,
AUTO_ML_KMS_KEY_ID_PATH,
AUTO_ML_V2_KMS_KEY_ID_PATH,
AUTO_ML_VPC_CONFIG_PATH,
AUTO_ML_V2_VPC_CONFIG_PATH,
AUTO_ML_VOLUME_KMS_KEY_ID_PATH,
AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH,
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH,
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH,
SESSION_DEFAULT_S3_BUCKET_PATH,
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
Expand Down
38 changes: 38 additions & 0 deletions src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
ENDPOINT = "Endpoint"
INFERENCE_COMPONENT = "InferenceComponent"
AUTO_ML_JOB = "AutoMLJob"
AUTO_ML_JOB_V2 = "AutoMLJobV2"
COMPILATION_JOB = "CompilationJob"
CUSTOM_PARAMETERS = "CustomParameters"
PIPELINE = "Pipeline"
Expand Down Expand Up @@ -182,14 +183,21 @@ def _simple_path(*args: str):
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID
)
# Config paths for AutoML jobs (V1 rooted at AUTO_ML_JOB, V2 rooted at AUTO_ML_JOB_V2).
# NOTE(review): _simple_path is defined outside this view; presumably it joins its
# arguments into a single lookup path -- confirm against its definition.
AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG)
AUTO_ML_V2_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG)
AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID)
AUTO_ML_V2_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG, KMS_KEY_ID)
# V1 nests SECURITY_CONFIG under AUTO_ML_JOB_CONFIG ...
AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path(
SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID
)
# ... whereas V2 places SECURITY_CONFIG directly under AUTO_ML_JOB_V2.
AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH = _simple_path(
SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VOLUME_KMS_KEY_ID
)
AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, ROLE_ARN)
AUTO_ML_V2_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, ROLE_ARN)
# Same V1/V2 difference for the VPC settings: V2 has no AUTO_ML_JOB_CONFIG level.
AUTO_ML_VPC_CONFIG_PATH = _simple_path(
SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG
)
AUTO_ML_V2_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VPC_CONFIG)
# V1-only path to the whole AUTO_ML_JOB_CONFIG section; no V2 counterpart is defined here.
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
SAGEMAKER,
Expand Down Expand Up @@ -362,6 +370,12 @@ def _simple_path(*args: str):
SECURITY_CONFIG,
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
)
# Config path of the inter-container traffic encryption flag for AutoML V2 jobs.
# SECURITY_CONFIG sits directly under AUTO_ML_JOB_V2 in this path.
AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
SAGEMAKER,
AUTO_ML_JOB_V2,
SECURITY_CONFIG,
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
)
PROCESSING_JOB_ENVIRONMENT_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ENVIRONMENT)
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
Expand Down Expand Up @@ -947,6 +961,30 @@ def _simple_path(*args: str):
TAGS: {"$ref": "#/definitions/tags"},
},
},
# Auto ML V2
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJobV2.html
AUTO_ML_JOB_V2: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
SECURITY_CONFIG: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"},
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
},
},
OUTPUT_DATA_CONFIG: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
},
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
TAGS: {"$ref": "#/definitions/tags"},
},
},
# Transform Job
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html
TRANSFORM_JOB: {
Expand Down
173 changes: 171 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@
MONITORING_SCHEDULE,
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
AUTO_ML_ROLE_ARN_PATH,
AUTO_ML_V2_ROLE_ARN_PATH,
AUTO_ML_OUTPUT_CONFIG_PATH,
AUTO_ML_V2_OUTPUT_CONFIG_PATH,
AUTO_ML_JOB_CONFIG_PATH,
AUTO_ML_JOB,
AUTO_ML_JOB_V2,
COMPILATION_JOB_ROLE_ARN_PATH,
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
COMPILATION_JOB_VPC_CONFIG_PATH,
Expand Down Expand Up @@ -2570,7 +2573,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
exceptions.UnexpectedStatusException: If waiting and auto ml job fails.
"""

description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll)
description = _wait_until(lambda: self.describe_auto_ml_job_v2(job_name), poll)

instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
self.boto_session, description, job="AutoML"
Expand Down Expand Up @@ -2618,7 +2621,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name)
description = self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
last_describe_job_call = time.time()

status = description["AutoMLJobStatus"]
Expand All @@ -2632,6 +2635,172 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
if dot:
print()

def create_auto_ml_v2(
    self,
    input_config,
    job_name,
    problem_config,
    output_config,
    job_objective=None,
    model_deploy_config=None,
    data_split_config=None,
    role=None,
    security_config=None,
    tags=None,
):
    """Create an Amazon SageMaker AutoMLV2 job.

    ``role`` is passed through ``resolve_value_from_config`` with
    ``AUTO_ML_V2_ROLE_ARN_PATH`` and ``output_config`` through
    ``update_nested_dictionary_with_values_from_config`` with
    ``AUTO_ML_V2_OUTPUT_CONFIG_PATH`` before the request is built by
    ``_get_auto_ml_request_v2`` and submitted via ``_intercept_create_request``.

    Args:
        input_config (list[dict]): A list of AutoMLDataChannel objects.
            Each channel contains "DataSource" and other optional fields.
        job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
            should have a unique job name.
        problem_config (object): A collection of settings specific
            to the problem type used to configure an AutoML job V2.
            There must be one and only one config of the following type.
            Supported problem types are:

            - Image Classification,
            - Tabular,
            - Text Classification,
            - Text Generation,
            - Time Series Forecasting.

        output_config (dict): The S3 URI where you want to store the training results and
            optional KMS key ID.
        job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
            "MetricName" and "Value".
        model_deploy_config (dict): Specifies how to generate the endpoint name
            for an automatic one-click Autopilot model deployment.
            Contains "AutoGenerateEndpointName" and "EndpointName".
        data_split_config (dict): This structure specifies how to split the data
            into train and validation datasets.
        role (str): The Amazon Resource Name (ARN) of an IAM role that
            Amazon SageMaker can assume to perform tasks on your behalf.
        security_config (dict): The security configuration for traffic encryption
            or Amazon VPC settings.
        tags (Optional[Tags]): A list of dictionaries containing key-value
            pairs.
    """
    # NOTE(review): a reviewer flagged that Session is getting too big and asked for
    # this functionality to be moved to a separate file.

    role = resolve_value_from_config(role, AUTO_ML_V2_ROLE_ARN_PATH, sagemaker_session=self)
    inferred_output_config = update_nested_dictionary_with_values_from_config(
        output_config, AUTO_ML_V2_OUTPUT_CONFIG_PATH, sagemaker_session=self
    )

    auto_ml_job_v2_request = self._get_auto_ml_request_v2(
        input_config=input_config,
        job_name=job_name,
        problem_config=problem_config,
        output_config=inferred_output_config,
        role=role,
        job_objective=job_objective,
        model_deploy_config=model_deploy_config,
        data_split_config=data_split_config,
        security_config=security_config,
        tags=format_tags(tags),
    )

    def submit(request):
        logger.info("Creating auto-ml-v2-job with name: %s", job_name)
        # ``indent`` is a json.dumps option; handing it to logger.debug raises
        # TypeError as soon as DEBUG logging is enabled.
        logger.debug("auto ml v2 request: %s", json.dumps(request, indent=4))
        self.sagemaker_client.create_auto_ml_job_v2(**request)

    self._intercept_create_request(
        auto_ml_job_v2_request, submit, self.create_auto_ml_v2.__name__
    )

def _get_auto_ml_request_v2(
    self,
    input_config,
    output_config,
    job_name,
    problem_config,
    role,
    job_objective=None,
    model_deploy_config=None,
    data_split_config=None,
    security_config=None,
    tags=None,
):
    """Build the request dict used to create an Amazon SageMaker AutoML V2 job.

    The five required settings are always present in the request; each optional
    setting is added only when it is not ``None``.

    Args:
        input_config (list[dict]): A list of Channel objects, stored under
            "AutoMLJobInputDataConfig". Each channel contains "DataSource";
            "TargetAttributeName", "CompressionType" and
            "SampleWeightAttributeName" are optional fields.
        output_config (dict): The S3 URI where the training results are stored and
            an optional KMS key ID, stored under "OutputDataConfig".
        job_name (str): Identifier of the AutoML job, stored under "AutoMLJobName".
            Each AutoMLJob should have a unique job name.
        problem_config (object): Settings specific to the problem type, stored under
            "AutoMLProblemTypeConfig". There must be one and only one config.
            Supported problem types are image classification, tabular,
            text classification, text generation and time series forecasting.
        role (str): The Amazon Resource Name (ARN) of an IAM role that
            Amazon SageMaker can assume, stored under "RoleArn".
        job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType"
            (optional), "MetricName" and "Value".
        model_deploy_config (dict): Specifies how to generate the endpoint name
            for an automatic one-click Autopilot model deployment.
            Contains "AutoGenerateEndpointName" and "EndpointName".
        data_split_config (dict): Specifies how to split the data
            into train and validation datasets.
        security_config (dict): The security configuration for traffic encryption
            or Amazon VPC settings.
        tags (Optional[Tags]): A list of dictionaries containing key-value
            pairs.

    Returns:
        dict: An AutoML V2 request dict.
    """
    request = {
        "AutoMLJobName": job_name,
        "AutoMLJobInputDataConfig": input_config,
        "OutputDataConfig": output_config,
        "AutoMLProblemTypeConfig": problem_config,
        "RoleArn": role,
    }

    # Optional settings, in the order they are added to the request.
    optional_fields = (
        ("AutoMLJobObjective", job_objective),
        ("ModelDeployConfig", model_deploy_config),
        ("DataSplitConfig", data_split_config),
        ("SecurityConfig", security_config),
    )
    for request_key, setting in optional_fields:
        if setting is None:
            continue
        request[request_key] = setting

    # Tags pass through the project-tag helper first, then the config-tag helper.
    project_tags = _append_project_tags(format_tags(tags))
    config_tags_path = "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB_V2, TAGS)
    all_tags = self._append_sagemaker_config_tags(project_tags, config_tags_path)
    if all_tags is not None:
        request["Tags"] = all_tags

    return request

# Done
def describe_auto_ml_job_v2(self, job_name):
    """Fetch the description of an AutoML V2 job.

    Args:
        job_name (str): The name of the AutoML job to describe.

    Returns:
        dict: The response of the client's ``describe_auto_ml_job_v2`` call
            for ``job_name``.
    """
    client = self.sagemaker_client
    response = client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
    return response

def compile_model(
self,
input_model_config,
Expand Down
Loading