@@ -429,7 +429,7 @@ class Estimator(EstimatorBase):
429
429
def __init__ (self , image_name , role , train_instance_count , train_instance_type ,
430
430
train_volume_size = 30 , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
431
431
output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ,
432
- hyperparameters = None ):
432
+ hyperparameters = None , tags = None , subnets = None , security_group_ids = None ):
433
433
"""Initialize an ``Estimator`` instance.
434
434
435
435
Args:
@@ -467,7 +467,8 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
467
467
self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
468
468
super (Estimator , self ).__init__ (role , train_instance_count , train_instance_type ,
469
469
train_volume_size , train_max_run , input_mode ,
470
- output_path , output_kms_key , base_job_name , sagemaker_session )
470
+ output_path , output_kms_key , base_job_name , sagemaker_session ,
471
+ tags , subnets , security_group_ids )
471
472
472
473
def train_image (self ):
473
474
"""
0 commit comments