Skip to content

Commit e136eba

Browse files
committed
Codestyle
1 parent 57f3377 commit e136eba

File tree

3 files changed

+41
-32
lines changed

3 files changed

+41
-32
lines changed

src/sagemaker/image_uris.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"vw",
5555
]
5656

57+
5758
@override_pipeline_parameter_var
5859
def retrieve(
5960
framework,
@@ -199,7 +200,7 @@ def retrieve(
199200
deprecation_warn(
200201
"SageMaker-hosted RL images no longer accept new pull requests and",
201202
"April 2024",
202-
" Please pass in `image_uri` to use RLEstimator"
203+
" Please pass in `image_uri` to use RLEstimator",
203204
)
204205

205206
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)

src/sagemaker/rl/estimator.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,21 @@ def __init__(
115115
executing your model training code.
116116
117117
.. warning::
118-
This ``toolkit`` argument discontinued support for new RL users on April 2024. To use
119-
RLEstimator, please pass in ``image_uri``.
118+
This ``toolkit`` argument discontinued support for new RL users on April 2024.
119+
To use RLEstimator, pass in ``image_uri``.
120120
toolkit_version (str): RL toolkit version you want to be use for executing your
121121
model training code.
122122
123123
.. warning::
124-
This ``toolkit_version`` argument discontinued support for new RL users on April 2024.
125-
To use RLEstimator, please pass in ``image_uri``.
124+
This ``toolkit_version`` argument discontinued support for new RL users on
125+
April 2024. To use RLEstimator, pass in ``image_uri``.
126126
framework (sagemaker.rl.RLFramework): Framework (MXNet or
127127
TensorFlow) you want to be used as a toolkit backed for
128128
reinforcement learning training.
129129
130130
.. warning::
131-
This ``framework`` argument discontinued support for new RL users on April 2024. To
132-
use RLEstimator, please pass in ``image_uri``.
131+
This ``framework`` argument discontinued support for new RL users on April
132+
2024. To use RLEstimator, pass in ``image_uri``.
133133
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
134134
to a directory with any other training source code dependencies aside from
135135
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -141,11 +141,12 @@ def __init__(
141141
SageMaker. For convenience, this accepts other types for keys
142142
and values.
143143
image_uri (str or PipelineVariable): An ECR url for an image the estimator would use
144-
for training and hosting. Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
144+
for training and hosting.
145+
Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
145146
metric_definitions (list[dict[str, str] or list[dict[str, PipelineVariable]]):
146147
A list of dictionaries that defines the metric(s) used to evaluate the
147-
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric,
148-
and 'Regex' for the regular expression used to extract the
148+
training jobs. Each dictionary contains two keys: 'Name' for the name of the
149+
metric, and 'Regex' for the regular expression used to extract the
149150
metric from the logs. This should be defined only for jobs that
150151
don't use an Amazon algorithm.
151152
**kwargs: Additional kwargs passed to the
@@ -167,11 +168,23 @@ def __init__(
167168
self._validate_images_args(toolkit, toolkit_version, framework, image_uri)
168169

169170
if toolkit:
170-
deprecation_warn("The argument `toolkit`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
171+
deprecation_warn(
172+
"The argument `toolkit`",
173+
"April 2024",
174+
" Pass in `image_uri` to use RLEstimator",
175+
)
171176
if toolkit_version:
172-
deprecation_warn("The argument `toolkit_version`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
177+
deprecation_warn(
178+
"The argument `toolkit_version`",
179+
"April 2024",
180+
" Pass in `image_uri` to use RLEstimator",
181+
)
173182
if framework:
174-
deprecation_warn("The argument `framework`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
183+
deprecation_warn(
184+
"The argument `framework`",
185+
"April 2024",
186+
" Pass in `image_uri` to use RLEstimator",
187+
)
175188

176189
if not image_uri:
177190
self._validate_toolkit_support(toolkit.value, toolkit_version, framework.value)
@@ -260,7 +273,7 @@ def create_model(
260273
base_args["name"] = self._get_or_create_name(kwargs.get("name"))
261274

262275
if not entry_point and (source_dir or dependencies):
263-
raise AttributeError("Please provide an `entry_point`.")
276+
raise AttributeError("Provide an `entry_point`.")
264277

265278
entry_point = entry_point or self._model_entry_point()
266279
source_dir = source_dir or self._model_source_dir()
@@ -291,7 +304,7 @@ def create_model(
291304
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args
292305
)
293306
raise ValueError(
294-
"An unknown RLFramework enum was passed in. framework: {}".format(self.framework)
307+
f"An unknown RLFramework enum was passed in. framework: {self.framework}"
295308
)
296309

297310
def training_image_uri(self):
@@ -349,10 +362,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
349362
toolkit, toolkit_version = cls._toolkit_and_version_from_tag(tag)
350363

351364
if not cls._is_combination_supported(toolkit, toolkit_version, framework):
365+
training_job_name = job_details["TrainingJobName"]
352366
raise ValueError(
353-
"Training job: {} didn't use image for requested framework".format(
354-
job_details["TrainingJobName"]
355-
)
367+
f"Training job: {training_job_name} didn't use image for requested framework"
356368
)
357369

358370
init_params["toolkit"] = RLToolkit(toolkit)
@@ -392,17 +404,15 @@ def _validate_framework_format(cls, framework):
392404
"""Placeholder docstring."""
393405
if framework and framework not in list(RLFramework):
394406
raise ValueError(
395-
"Invalid type: {}, valid RL frameworks types are: {}".format(
396-
framework, list(RLFramework)
397-
)
407+
f"Invalid type: {framework}, valid RL frameworks types are: {list(RLFramework)}"
398408
)
399409

400410
@classmethod
401411
def _validate_toolkit_format(cls, toolkit):
402412
"""Placeholder docstring."""
403413
if toolkit and toolkit not in list(RLToolkit):
404414
raise ValueError(
405-
"Invalid type: {}, valid RL toolkits types are: {}".format(toolkit, list(RLToolkit))
415+
f"Invalid type: {toolkit}, valid RL toolkits types are: {list(RLToolkit)}"
406416
)
407417

408418
@classmethod
@@ -420,10 +430,9 @@ def _validate_images_args(cls, toolkit, toolkit_version, framework, image_uri):
420430
if not framework:
421431
not_found_args.append("framework")
422432
if not_found_args:
433+
not_found_args_joined = "`, `".join(not_found_args)
423434
raise AttributeError(
424-
"Please provide `{}` or `image_uri` parameter.".format(
425-
"`, `".join(not_found_args)
426-
)
435+
f"Provide `{not_found_args_joined}` or `image_uri` parameter."
427436
)
428437
else:
429438
found_args = []
@@ -455,9 +464,8 @@ def _validate_toolkit_support(cls, toolkit, toolkit_version, framework):
455464
"""Placeholder docstring."""
456465
if not cls._is_combination_supported(toolkit, toolkit_version, framework):
457466
raise AttributeError(
458-
"Provided `{}-{}` and `{}` combination is not supported.".format(
459-
toolkit, toolkit_version, framework
460-
)
467+
f"Provided `{toolkit}-{toolkit_version}` and `{framework}` combination is"
468+
" not supported."
461469
)
462470

463471
def _image_framework(self):
@@ -487,7 +495,7 @@ def default_metric_definitions(cls, toolkit):
487495
float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501
488496

489497
return [
490-
{"Name": "episode_reward_mean", "Regex": "episode_reward_mean: (%s)" % float_regex},
491-
{"Name": "episode_reward_max", "Regex": "episode_reward_max: (%s)" % float_regex},
498+
{"Name": "episode_reward_mean", "Regex": f"episode_reward_mean: ({float_regex})"},
499+
{"Name": "episode_reward_max", "Regex": f"episode_reward_max: ({float_regex})"},
492500
]
493-
raise ValueError("An unknown RLToolkit enum was passed in. toolkit: {}".format(toolkit))
501+
raise ValueError(f"An unknown RLToolkit enum was passed in. toolkit: {toolkit}")

tests/unit/test_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def test_missing_required_parameters(sagemaker_session):
601601
instance_type=INSTANCE_TYPE,
602602
)
603603
assert (
604-
"Please provide `toolkit`, `toolkit_version`, `framework`" + " or `image_uri` parameter."
604+
"Provide `toolkit`, `toolkit_version`, `framework`" + " or `image_uri` parameter."
605605
in str(e.value)
606606
)
607607

0 commit comments

Comments
 (0)