11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
"""Placeholder docstring"""
14
+ # commend to retrigger pipeeline
14
15
from __future__ import absolute_import
15
16
16
17
import logging
@@ -153,14 +154,10 @@ def __init__(
153
154
self ._validate_args (image_uri = image_uri )
154
155
155
156
if distribution is not None :
156
- instance_type = renamed_kwargs (
157
- "train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs
158
- )
157
+ instance_type = renamed_kwargs ("train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs )
159
158
160
159
base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
161
- base_framework_version = (
162
- tensorflow_version if tensorflow_version is not None else pytorch_version
163
- )
160
+ base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
164
161
165
162
validate_smdistributed (
166
163
instance_type = instance_type ,
@@ -171,18 +168,14 @@ def __init__(
171
168
image_uri = image_uri ,
172
169
)
173
170
174
- warn_if_parameter_server_with_multi_gpu (
175
- training_instance_type = instance_type , distribution = distribution
176
- )
171
+ warn_if_parameter_server_with_multi_gpu (training_instance_type = instance_type , distribution = distribution )
177
172
178
173
if "enable_sagemaker_metrics" not in kwargs :
179
174
kwargs ["enable_sagemaker_metrics" ] = True
180
175
181
176
kwargs ["py_version" ] = self .py_version
182
177
183
- super (HuggingFace , self ).__init__ (
184
- entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
185
- )
178
+ super (HuggingFace , self ).__init__ (entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs )
186
179
self .distribution = distribution or {}
187
180
188
181
def _validate_args (self , image_uri ):
@@ -220,9 +213,7 @@ def _validate_args(self, image_uri):
220
213
def hyperparameters (self ):
221
214
"""Return hyperparameters used by your custom PyTorch code during model training."""
222
215
hyperparameters = super (HuggingFace , self ).hyperparameters ()
223
- additional_hyperparameters = self ._distribution_configuration (
224
- distribution = self .distribution
225
- )
216
+ additional_hyperparameters = self ._distribution_configuration (distribution = self .distribution )
226
217
hyperparameters .update (Framework ._json_encode_hyperparameters (additional_hyperparameters ))
227
218
return hyperparameters
228
219
@@ -332,9 +323,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
332
323
333
324
if framework != cls ._framework_name :
334
325
raise ValueError (
335
- "Training job: {} didn't use image for requested framework" .format (
336
- job_details ["TrainingJobName" ]
337
- )
326
+ "Training job: {} didn't use image for requested framework" .format (job_details ["TrainingJobName" ])
338
327
)
339
328
340
329
return init_params
0 commit comments