19
19
20
20
from sagemaker import Model , PipelineModel
21
21
from sagemaker .automl .candidate_estimator import CandidateEstimator
22
+ from sagemaker .config import (
23
+ AUTO_ML_ROLE_ARN_PATH ,
24
+ AUTO_ML_KMS_KEY_ID_PATH ,
25
+ AUTO_ML_VPC_CONFIG_PATH ,
26
+ AUTO_ML_VOLUME_KMS_KEY_ID_PATH ,
27
+ AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH ,
28
+ )
22
29
from sagemaker .job import _Job
23
30
from sagemaker .session import Session
24
- from sagemaker .utils import name_from_base
31
+ from sagemaker .utils import name_from_base , resolve_value_from_config
25
32
from sagemaker .workflow .entities import PipelineVariable
26
33
from sagemaker .workflow .pipeline_context import runnable_by_pipeline
27
34
@@ -98,15 +105,15 @@ class AutoML(object):
98
105
99
106
def __init__ (
100
107
self ,
101
- role : str ,
102
- target_attribute_name : str ,
108
+ role : Optional [ str ] = None ,
109
+ target_attribute_name : str = None ,
103
110
output_kms_key : Optional [str ] = None ,
104
111
output_path : Optional [str ] = None ,
105
112
base_job_name : Optional [str ] = None ,
106
113
compression_type : Optional [str ] = None ,
107
114
sagemaker_session : Optional [Session ] = None ,
108
115
volume_kms_key : Optional [str ] = None ,
109
- encrypt_inter_container_traffic : Optional [bool ] = False ,
116
+ encrypt_inter_container_traffic : Optional [bool ] = None ,
110
117
vpc_config : Optional [Dict [str , List ]] = None ,
111
118
problem_type : Optional [str ] = None ,
112
119
max_candidates : Optional [int ] = None ,
@@ -176,14 +183,10 @@ def __init__(
176
183
Returns:
177
184
AutoML object.
178
185
"""
179
- self .role = role
180
- self .output_kms_key = output_kms_key
181
186
self .output_path = output_path
182
187
self .base_job_name = base_job_name
183
188
self .compression_type = compression_type
184
- self .volume_kms_key = volume_kms_key
185
189
self .encrypt_inter_container_traffic = encrypt_inter_container_traffic
186
- self .vpc_config = vpc_config
187
190
self .problem_type = problem_type
188
191
self .max_candidate = max_candidates
189
192
self .max_runtime_per_training_job_in_seconds = max_runtime_per_training_job_in_seconds
@@ -204,6 +207,31 @@ def __init__(
204
207
self ._auto_ml_job_desc = None
205
208
self ._best_candidate = None
206
209
self .sagemaker_session = sagemaker_session or Session ()
210
+ self .vpc_config = resolve_value_from_config (
211
+ vpc_config , AUTO_ML_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
212
+ )
213
+ self .volume_kms_key = resolve_value_from_config (
214
+ volume_kms_key , AUTO_ML_VOLUME_KMS_KEY_ID_PATH , sagemaker_session = self .sagemaker_session
215
+ )
216
+ self .output_kms_key = resolve_value_from_config (
217
+ output_kms_key , AUTO_ML_KMS_KEY_ID_PATH , sagemaker_session = self .sagemaker_session
218
+ )
219
+ self .role = resolve_value_from_config (
220
+ role , AUTO_ML_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
221
+ )
222
+ if not self .role :
223
+ # Originally IAM role was a required parameter.
224
+ # Now we marked that as Optional because we can fetch it from SageMakerConfig
225
+ # Because of marking that parameter as optional, we should validate if it is None, even
226
+ # after fetching the config.
227
+ raise ValueError ("An AWS IAM role is required to create an AutoML job." )
228
+
229
+ self .encrypt_inter_container_traffic = resolve_value_from_config (
230
+ direct_input = encrypt_inter_container_traffic ,
231
+ config_path = AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH ,
232
+ default_value = False ,
233
+ sagemaker_session = self .sagemaker_session ,
234
+ )
207
235
208
236
self ._check_problem_type_and_job_objective (self .problem_type , self .job_objective )
209
237
@@ -276,6 +304,8 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None):
276
304
volume_kms_key = auto_ml_job_desc .get ("AutoMLJobConfig" , {})
277
305
.get ("SecurityConfig" , {})
278
306
.get ("VolumeKmsKeyId" ),
307
+ # Do not override encrypt_inter_container_traffic from config because this info
308
+ # is pulled from an existing automl job
279
309
encrypt_inter_container_traffic = auto_ml_job_desc .get ("AutoMLJobConfig" , {})
280
310
.get ("SecurityConfig" , {})
281
311
.get ("EnableInterContainerTrafficEncryption" , False ),
0 commit comments