@@ -35,11 +35,11 @@ def __init__(
35
35
self ,
36
36
algorithm_arn ,
37
37
role ,
38
- train_instance_count ,
39
- train_instance_type ,
40
- train_volume_size = 30 ,
41
- train_volume_kms_key = None ,
42
- train_max_run = 24 * 60 * 60 ,
38
+ instance_count ,
39
+ instance_type ,
40
+ volume_size = 30 ,
41
+ volume_kms_key = None ,
42
+ max_run = 24 * 60 * 60 ,
43
43
input_mode = "File" ,
44
44
output_path = None ,
45
45
output_kms_key = None ,
@@ -65,15 +65,15 @@ def __init__(
65
65
access training data and model artifacts. After the endpoint
66
66
is created, the inference code might use the IAM role, if it
67
67
needs to access an AWS resource.
68
- train_instance_count (int): Number of Amazon EC2 instances to
69
- use for training. train_instance_type (str): Type of EC2
68
+ instance_count (int): Number of Amazon EC2 instances to
69
+ use for training. instance_type (str): Type of EC2
70
70
instance to use for training, for example, 'ml.c4.xlarge'.
71
- train_volume_size (int): Size in GB of the EBS volume to use for
71
+ volume_size (int): Size in GB of the EBS volume to use for
72
72
storing input data during training (default: 30). Must be large enough to store
73
73
training data if File Mode is used (which is the default).
74
- train_volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached
74
+ volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached
75
75
to the training instance (default: None).
76
- train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
76
+ max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
77
77
After this amount of time Amazon SageMaker terminates the
78
78
job regardless of its current status.
79
79
input_mode (str): The input mode that the algorithm supports
@@ -131,11 +131,11 @@ def __init__(
131
131
self .algorithm_arn = algorithm_arn
132
132
super (AlgorithmEstimator , self ).__init__ (
133
133
role ,
134
- train_instance_count ,
135
- train_instance_type ,
136
- train_volume_size ,
137
- train_volume_kms_key ,
138
- train_max_run ,
134
+ instance_count ,
135
+ instance_type ,
136
+ volume_size ,
137
+ volume_kms_key ,
138
+ max_run ,
139
139
input_mode ,
140
140
output_path ,
141
141
output_kms_key ,
@@ -167,30 +167,30 @@ def validate_train_spec(self):
167
167
168
168
# Check that the input mode provided is compatible with the training input modes for the
169
169
# algorithm.
170
- train_input_modes = self ._algorithm_training_input_modes (train_spec ["TrainingChannels" ])
171
- if self .input_mode not in train_input_modes :
170
+ input_modes = self ._algorithm_training_input_modes (train_spec ["TrainingChannels" ])
171
+ if self .input_mode not in input_modes :
172
172
raise ValueError (
173
173
"Invalid input mode: %s. %s only supports: %s"
174
- % (self .input_mode , algorithm_name , train_input_modes )
174
+ % (self .input_mode , algorithm_name , input_modes )
175
175
)
176
176
177
177
# Check that the training instance type is compatible with the algorithm.
178
178
supported_instances = train_spec ["SupportedTrainingInstanceTypes" ]
179
- if self .train_instance_type not in supported_instances :
179
+ if self .instance_type not in supported_instances :
180
180
raise ValueError (
181
- "Invalid train_instance_type : %s. %s supports the following instance types: %s"
182
- % (self .train_instance_type , algorithm_name , supported_instances )
181
+ "Invalid instance_type : %s. %s supports the following instance types: %s"
182
+ % (self .instance_type , algorithm_name , supported_instances )
183
183
)
184
184
185
185
# Verify if distributed training is supported by the algorithm
186
186
if (
187
- self .train_instance_count > 1
187
+ self .instance_count > 1
188
188
and "SupportsDistributedTraining" in train_spec
189
189
and not train_spec ["SupportsDistributedTraining" ]
190
190
):
191
191
raise ValueError (
192
192
"Distributed training is not supported by %s. "
193
- "Please set train_instance_count =1" % algorithm_name
193
+ "Please set instance_count =1" % algorithm_name
194
194
)
195
195
196
196
def set_hyperparameters (self , ** kwargs ):
0 commit comments