@@ -115,21 +115,21 @@ def __init__(
115
115
executing your model training code.
116
116
117
117
.. warning::
118
- This ``toolkit`` argument discontinued support for new RL users on April 2024. To use
119
- RLEstimator, please pass in ``image_uri``.
118
+ This ``toolkit`` argument discontinued support for new RL users on April 2024.
119
+ To use RLEstimator, pass in ``image_uri``.
120
120
toolkit_version (str): RL toolkit version you want to be use for executing your
121
121
model training code.
122
122
123
123
.. warning::
124
- This ``toolkit_version`` argument discontinued support for new RL users on April 2024.
125
- To use RLEstimator, please pass in ``image_uri``.
124
+ This ``toolkit_version`` argument discontinued support for new RL users on
125
+ April 2024. To use RLEstimator, pass in ``image_uri``.
126
126
framework (sagemaker.rl.RLFramework): Framework (MXNet or
127
127
TensorFlow) you want to be used as a toolkit backed for
128
128
reinforcement learning training.
129
129
130
130
.. warning::
131
- This ``framework`` argument discontinued support for new RL users on April 2024. To
132
- use RLEstimator, please pass in ``image_uri``.
131
+ This ``framework`` argument discontinued support for new RL users on April
132
+ 2024. To use RLEstimator, pass in ``image_uri``.
133
133
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
134
134
to a directory with any other training source code dependencies aside from
135
135
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -141,11 +141,12 @@ def __init__(
141
141
SageMaker. For convenience, this accepts other types for keys
142
142
and values.
143
143
image_uri (str or PipelineVariable): An ECR url for an image the estimator would use
144
- for training and hosting. Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
144
+ for training and hosting.
145
+ Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
145
146
metric_definitions (list[dict[str, str] or list[dict[str, PipelineVariable]]):
146
147
A list of dictionaries that defines the metric(s) used to evaluate the
147
- training jobs. Each dictionary contains two keys: 'Name' for the name of the metric,
148
- and 'Regex' for the regular expression used to extract the
148
+ training jobs. Each dictionary contains two keys: 'Name' for the name of the
149
+ metric, and 'Regex' for the regular expression used to extract the
149
150
metric from the logs. This should be defined only for jobs that
150
151
don't use an Amazon algorithm.
151
152
**kwargs: Additional kwargs passed to the
@@ -167,11 +168,23 @@ def __init__(
167
168
self ._validate_images_args (toolkit , toolkit_version , framework , image_uri )
168
169
169
170
if toolkit :
170
- deprecation_warn ("The argument `toolkit`" , "April 2024" , " Please pass in `image_uri` to use RLEstimator" )
171
+ deprecation_warn (
172
+ "The argument `toolkit`" ,
173
+ "April 2024" ,
174
+ " Pass in `image_uri` to use RLEstimator" ,
175
+ )
171
176
if toolkit_version :
172
- deprecation_warn ("The argument `toolkit_version`" , "April 2024" , " Please pass in `image_uri` to use RLEstimator" )
177
+ deprecation_warn (
178
+ "The argument `toolkit_version`" ,
179
+ "April 2024" ,
180
+ " Pass in `image_uri` to use RLEstimator" ,
181
+ )
173
182
if framework :
174
- deprecation_warn ("The argument `framework`" , "April 2024" , " Please pass in `image_uri` to use RLEstimator" )
183
+ deprecation_warn (
184
+ "The argument `framework`" ,
185
+ "April 2024" ,
186
+ " Pass in `image_uri` to use RLEstimator" ,
187
+ )
175
188
176
189
if not image_uri :
177
190
self ._validate_toolkit_support (toolkit .value , toolkit_version , framework .value )
@@ -260,7 +273,7 @@ def create_model(
260
273
base_args ["name" ] = self ._get_or_create_name (kwargs .get ("name" ))
261
274
262
275
if not entry_point and (source_dir or dependencies ):
263
- raise AttributeError ("Please provide an `entry_point`." )
276
+ raise AttributeError ("Provide an `entry_point`." )
264
277
265
278
entry_point = entry_point or self ._model_entry_point ()
266
279
source_dir = source_dir or self ._model_source_dir ()
@@ -291,7 +304,7 @@ def create_model(
291
304
framework_version = self .framework_version , py_version = PYTHON_VERSION , ** extended_args
292
305
)
293
306
raise ValueError (
294
- "An unknown RLFramework enum was passed in. framework: {}" . format ( self .framework )
307
+ f "An unknown RLFramework enum was passed in. framework: { self .framework } "
295
308
)
296
309
297
310
def training_image_uri (self ):
@@ -349,10 +362,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
349
362
toolkit , toolkit_version = cls ._toolkit_and_version_from_tag (tag )
350
363
351
364
if not cls ._is_combination_supported (toolkit , toolkit_version , framework ):
365
+ training_job_name = job_details ["TrainingJobName" ]
352
366
raise ValueError (
353
- "Training job: {} didn't use image for requested framework" .format (
354
- job_details ["TrainingJobName" ]
355
- )
367
+ f"Training job: { training_job_name } didn't use image for requested framework"
356
368
)
357
369
358
370
init_params ["toolkit" ] = RLToolkit (toolkit )
@@ -392,17 +404,15 @@ def _validate_framework_format(cls, framework):
392
404
"""Placeholder docstring."""
393
405
if framework and framework not in list (RLFramework ):
394
406
raise ValueError (
395
- "Invalid type: {}, valid RL frameworks types are: {}" .format (
396
- framework , list (RLFramework )
397
- )
407
+ f"Invalid type: { framework } , valid RL frameworks types are: { list (RLFramework )} "
398
408
)
399
409
400
410
@classmethod
401
411
def _validate_toolkit_format (cls , toolkit ):
402
412
"""Placeholder docstring."""
403
413
if toolkit and toolkit not in list (RLToolkit ):
404
414
raise ValueError (
405
- "Invalid type: {}, valid RL toolkits types are: {}" . format ( toolkit , list (RLToolkit ))
415
+ f "Invalid type: { toolkit } , valid RL toolkits types are: { list (RLToolkit )} "
406
416
)
407
417
408
418
@classmethod
@@ -420,10 +430,9 @@ def _validate_images_args(cls, toolkit, toolkit_version, framework, image_uri):
420
430
if not framework :
421
431
not_found_args .append ("framework" )
422
432
if not_found_args :
433
+ not_found_args_joined = "`, `" .join (not_found_args )
423
434
raise AttributeError (
424
- "Please provide `{}` or `image_uri` parameter." .format (
425
- "`, `" .join (not_found_args )
426
- )
435
+ f"Provide `{ not_found_args_joined } ` or `image_uri` parameter."
427
436
)
428
437
else :
429
438
found_args = []
@@ -455,9 +464,8 @@ def _validate_toolkit_support(cls, toolkit, toolkit_version, framework):
455
464
"""Placeholder docstring."""
456
465
if not cls ._is_combination_supported (toolkit , toolkit_version , framework ):
457
466
raise AttributeError (
458
- "Provided `{}-{}` and `{}` combination is not supported." .format (
459
- toolkit , toolkit_version , framework
460
- )
467
+ f"Provided `{ toolkit } -{ toolkit_version } ` and `{ framework } ` combination is"
468
+ " not supported."
461
469
)
462
470
463
471
def _image_framework (self ):
@@ -487,7 +495,7 @@ def default_metric_definitions(cls, toolkit):
487
495
float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501
488
496
489
497
return [
490
- {"Name" : "episode_reward_mean" , "Regex" : "episode_reward_mean: (%s)" % float_regex },
491
- {"Name" : "episode_reward_max" , "Regex" : "episode_reward_max: (%s)" % float_regex },
498
+ {"Name" : "episode_reward_mean" , "Regex" : f "episode_reward_mean: ({ float_regex } )" },
499
+ {"Name" : "episode_reward_max" , "Regex" : f "episode_reward_max: ({ float_regex } )" },
492
500
]
493
- raise ValueError ("An unknown RLToolkit enum was passed in. toolkit: {}" . format ( toolkit ) )
501
+ raise ValueError (f "An unknown RLToolkit enum was passed in. toolkit: { toolkit } " )
0 commit comments