|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | """A class for SageMaker AutoML V2 Jobs."""
|
14 |
| -from __future__ import absolute_import |
| 14 | + |
| 15 | +from __future__ import absolute_import, annotations |
15 | 16 |
|
16 | 17 | import logging
|
17 | 18 | from dataclasses import dataclass
|
18 | 19 | from typing import Dict, List, Optional, Union
|
19 | 20 |
|
20 | 21 | from sagemaker import Model, PipelineModel, s3
|
| 22 | +from sagemaker.automl.automl import AutoML |
21 | 23 | from sagemaker.automl.candidate_estimator import CandidateEstimator
|
22 | 24 | from sagemaker.config import (
|
23 | 25 | AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
|
@@ -727,6 +729,46 @@ def __init__(
|
727 | 729 | self._auto_ml_job_desc = None
|
728 | 730 | self._best_candidate = None
|
729 | 731 |
|
| 732 | + @classmethod |
| 733 | + def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2: |
| 734 | + """Create an AutoMLV2 object from an AutoML object. |
| 735 | +
|
| 736 | + This method maps AutoML properties into an AutoMLV2 object, |
| 737 | + so you can create AutoMLV2 jobs from the existing AutoML objects. |
| 738 | +
|
| 739 | + Args: |
| 740 | + auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which |
| 741 | + an AutoMLV2 object will be created. |
| 742 | + """ |
| 743 | + auto_ml_v2 = AutoMLV2( |
| 744 | + problem_config=AutoMLTabularConfig( |
| 745 | + target_attribute_name=auto_ml.target_attribute_name, |
| 746 | + feature_specification_s3_uri=auto_ml.feature_specification_s3_uri, |
| 747 | + generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only, |
| 748 | + mode=auto_ml.mode, |
| 749 | + problem_type=auto_ml.problem_type, |
| 750 | + sample_weight_attribute_name=auto_ml.sample_weight_attribute_name, |
| 751 | + max_candidates=auto_ml.max_candidate, |
| 752 | + max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds, # noqa E501 # pylint: disable=c0301 |
| 753 | + max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds, |
| 754 | + ), |
| 755 | + base_job_name=auto_ml.base_job_name, |
| 756 | + output_path=auto_ml.output_path, |
| 757 | + output_kms_key=auto_ml.output_kms_key, |
| 758 | + job_objective=auto_ml.job_objective, |
| 759 | + validation_fraction=auto_ml.validation_fraction, |
| 760 | + auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name, |
| 761 | + endpoint_name=auto_ml.endpoint_name, |
| 762 | + role=auto_ml.role, |
| 763 | + volume_kms_key=auto_ml.volume_kms_key, |
| 764 | + encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic, |
| 765 | + vpc_config=auto_ml.vpc_config, |
| 766 | + tags=auto_ml.tags, |
| 767 | + sagemaker_session=auto_ml.sagemaker_session, |
| 768 | + ) |
| 769 | + auto_ml_v2._best_candidate = auto_ml._best_candidate |
| 770 | + return auto_ml_v2 |
| 771 | + |
730 | 772 | def fit(
|
731 | 773 | self,
|
732 | 774 | inputs: Optional[
|
|
0 commit comments