Skip to content

Commit 543cdbb

Browse files
repushkoliujiaorr
authored and
root
committed
Add AutoML -> AutoMLV2 mapper (aws#4500)
Co-authored-by: liujiaor <[email protected]>
1 parent a41e50c commit 543cdbb

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

src/sagemaker/automl/automlv2.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""A class for SageMaker AutoML V2 Jobs."""
14-
from __future__ import absolute_import
14+
15+
from __future__ import absolute_import, annotations
1516

1617
import logging
1718
from dataclasses import dataclass
1819
from typing import Dict, List, Optional, Union
1920

2021
from sagemaker import Model, PipelineModel, s3
22+
from sagemaker.automl.automl import AutoML
2123
from sagemaker.automl.candidate_estimator import CandidateEstimator
2224
from sagemaker.config import (
2325
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
@@ -727,6 +729,46 @@ def __init__(
727729
self._auto_ml_job_desc = None
728730
self._best_candidate = None
729731

732+
@classmethod
733+
def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2:
734+
"""Create an AutoMLV2 object from an AutoML object.
735+
736+
This method maps AutoML properties into an AutoMLV2 object,
737+
so you can create AutoMLV2 jobs from the existing AutoML objects.
738+
739+
Args:
740+
auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which
741+
an AutoMLV2 object will be created.
742+
"""
743+
auto_ml_v2 = AutoMLV2(
744+
problem_config=AutoMLTabularConfig(
745+
target_attribute_name=auto_ml.target_attribute_name,
746+
feature_specification_s3_uri=auto_ml.feature_specification_s3_uri,
747+
generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only,
748+
mode=auto_ml.mode,
749+
problem_type=auto_ml.problem_type,
750+
sample_weight_attribute_name=auto_ml.sample_weight_attribute_name,
751+
max_candidates=auto_ml.max_candidate,
752+
max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds, # noqa E501 # pylint: disable=c0301
753+
max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds,
754+
),
755+
base_job_name=auto_ml.base_job_name,
756+
output_path=auto_ml.output_path,
757+
output_kms_key=auto_ml.output_kms_key,
758+
job_objective=auto_ml.job_objective,
759+
validation_fraction=auto_ml.validation_fraction,
760+
auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name,
761+
endpoint_name=auto_ml.endpoint_name,
762+
role=auto_ml.role,
763+
volume_kms_key=auto_ml.volume_kms_key,
764+
encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic,
765+
vpc_config=auto_ml.vpc_config,
766+
tags=auto_ml.tags,
767+
sagemaker_session=auto_ml.sagemaker_session,
768+
)
769+
auto_ml_v2._best_candidate = auto_ml._best_candidate
770+
return auto_ml_v2
771+
730772
def fit(
731773
self,
732774
inputs: Optional[

tests/unit/sagemaker/automl/test_auto_ml_v2.py

+27
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
CandidateEstimator,
2525
LocalAutoMLDataChannel,
2626
PipelineModel,
27+
AutoML,
2728
)
2829
from sagemaker.predictor import Predictor
2930
from sagemaker.session_settings import SessionSettings
@@ -1100,3 +1101,29 @@ def without_user_input(sess):
11001101
expected__with_user_input__with_default_bucket_only="s3://test",
11011102
)
11021103
assert actual == expected
1104+
1105+
1106+
def test_automl_v1_to_automl_v2_mapping():
1107+
auto_ml = AutoML(
1108+
role=ROLE,
1109+
target_attribute_name=TARGET_ATTRIBUTE_NAME,
1110+
sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME,
1111+
output_kms_key=OUTPUT_KMS_KEY,
1112+
output_path=OUTPUT_PATH,
1113+
max_candidates=MAX_CANDIDATES,
1114+
base_job_name=BASE_JOB_NAME,
1115+
)
1116+
1117+
auto_ml_v2 = AutoMLV2.from_auto_ml(auto_ml=auto_ml)
1118+
1119+
assert isinstance(auto_ml_v2.problem_config, AutoMLTabularConfig)
1120+
assert auto_ml_v2.role == auto_ml.role
1121+
assert auto_ml_v2.problem_config.target_attribute_name == auto_ml.target_attribute_name
1122+
assert (
1123+
auto_ml_v2.problem_config.sample_weight_attribute_name
1124+
== auto_ml.sample_weight_attribute_name
1125+
)
1126+
assert auto_ml_v2.output_kms_key == auto_ml.output_kms_key
1127+
assert auto_ml_v2.output_path == auto_ml.output_path
1128+
assert auto_ml_v2.problem_config.max_candidates == auto_ml.max_candidate
1129+
assert auto_ml_v2.base_job_name == auto_ml.base_job_name

0 commit comments

Comments
 (0)