@@ -168,8 +168,10 @@ def __init__(
168
168
max_wait = max_wait ,
169
169
)
170
170
171
- self .algorithm_spec = self .sagemaker_session .sagemaker_client .describe_algorithm (
172
- AlgorithmName = algorithm_arn
171
+ self .algorithm_spec = (
172
+ self .sagemaker_session .sagemaker_client .describe_algorithm (
173
+ AlgorithmName = algorithm_arn
174
+ )
173
175
)
174
176
self .validate_train_spec ()
175
177
self .hyperparameter_definitions = self ._parse_hyperparameters ()
@@ -185,7 +187,9 @@ def validate_train_spec(self):
185
187
186
188
# Check that the input mode provided is compatible with the training input modes for the
187
189
# algorithm.
188
- input_modes = self ._algorithm_training_input_modes (train_spec ["TrainingChannels" ])
190
+ input_modes = self ._algorithm_training_input_modes (
191
+ train_spec ["TrainingChannels" ]
192
+ )
189
193
if self .input_mode not in input_modes :
190
194
raise ValueError (
191
195
"Invalid input mode: %s. %s only supports: %s"
@@ -233,7 +237,9 @@ def training_image_uri(self):
233
237
The fit() method, that does the model training, calls this method to
234
238
find the image to use for model training.
235
239
"""
236
- raise RuntimeError ("training_image_uri is never meant to be called on Algorithm Estimators" )
240
+ raise RuntimeError (
241
+ "training_image_uri is never meant to be called on Algorithm Estimators"
242
+ )
237
243
238
244
def enable_network_isolation (self ):
239
245
"""Return True if this Estimator will need network isolation to run.
@@ -377,7 +383,9 @@ def transformer(
377
383
378
384
tags = tags or self .tags
379
385
else :
380
- raise RuntimeError ("No finished training job found associated with this estimator" )
386
+ raise RuntimeError (
387
+ "No finished training job found associated with this estimator"
388
+ )
381
389
382
390
return Transformer (
383
391
model_name ,
@@ -431,21 +439,29 @@ def _validate_input_channels(self, channels):
431
439
for c in channels :
432
440
if c not in training_channels :
433
441
raise ValueError (
434
- "Unknown input channel: %s is not supported by: %s" % (c , algorithm_name )
442
+ "Unknown input channel: %s is not supported by: %s"
443
+ % (c , algorithm_name )
435
444
)
436
445
437
446
# check for required channels that were not provided
438
447
for name , channel in training_channels .items ():
439
- if name not in channels and "IsRequired" in channel and channel ["IsRequired" ]:
440
- raise ValueError ("Required input channel: %s Was not provided." % (name ))
448
+ if (
449
+ name not in channels
450
+ and "IsRequired" in channel
451
+ and channel ["IsRequired" ]
452
+ ):
453
+ raise ValueError (
454
+ "Required input channel: %s Was not provided." % (name )
455
+ )
441
456
442
457
def _validate_and_cast_hyperparameter (self , name , v ):
443
458
"""Placeholder docstring"""
444
459
algorithm_name = self .algorithm_spec ["AlgorithmName" ]
445
460
446
461
if name not in self .hyperparameter_definitions :
447
462
raise ValueError (
448
- "Invalid hyperparameter: %s is not supported by %s" % (name , algorithm_name )
463
+ "Invalid hyperparameter: %s is not supported by %s"
464
+ % (name , algorithm_name )
449
465
)
450
466
451
467
definition = self .hyperparameter_definitions [name ]
@@ -456,7 +472,9 @@ def _validate_and_cast_hyperparameter(self, name, v):
456
472
457
473
if "range" in definition and not definition ["range" ].is_valid (value ):
458
474
valid_range = definition ["range" ].as_tuning_range (name )
459
- raise ValueError ("Invalid value: %s Supported range: %s" % (value , valid_range ))
475
+ raise ValueError (
476
+ "Invalid value: %s Supported range: %s" % (value , valid_range )
477
+ )
460
478
return value
461
479
462
480
def _validate_and_set_default_hyperparameters (self ):
@@ -544,7 +562,9 @@ def _algorithm_training_input_modes(self, training_channels):
544
562
return current_input_modes
545
563
546
564
@classmethod
547
- def _prepare_init_params_from_job_description (cls , job_details , model_channel_name = None ):
565
+ def _prepare_init_params_from_job_description (
566
+ cls , job_details , model_channel_name = None
567
+ ):
548
568
"""Convert the job description to init params that can be handled by the class constructor.
549
569
550
570
Args:
@@ -556,9 +576,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
556
576
Returns:
557
577
dict: The transformed init_params
558
578
"""
559
- init_params = super (AlgorithmEstimator , cls ). _prepare_init_params_from_job_description (
560
- job_details , model_channel_name
561
- )
579
+ init_params = super (
580
+ AlgorithmEstimator , cls
581
+ ). _prepare_init_params_from_job_description ( job_details , model_channel_name )
562
582
563
583
# This hyperparameter is added by Amazon SageMaker Automatic Model Tuning.
564
584
# It cannot be set through instantiating an estimator.
0 commit comments