Skip to content

Commit 7bdba6e

Browse files
jinyoung-limknikure
authored andcommitted
chore: deprecation warnings for RL
1 parent 55822f7 commit 7bdba6e

File tree

4 files changed

+112
-73
lines changed

4 files changed

+112
-73
lines changed

doc/frameworks/rl/using_rl.rst

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,56 @@ Training RL models using ``RLEstimator`` is a two-step process:
1919
You should prepare your script in a separate source file than the notebook, terminal session, or source file you're
2020
using to submit the script to SageMaker via an ``RLEstimator``. This will be discussed in further detail below.
2121

22-
Suppose that you already have a training script called ``coach-train.py``.
22+
Suppose that you already have a training script called ``coach-train.py`` and have an RL image in your ECR registry
23+
called ``123123123123.dkr.ecr.us-west-2.amazonaws.com/your-rl-registry:your-cool-image-tag`` in ``us-west-2`` region.
2324
You can then create an ``RLEstimator`` with keyword arguments to point to this script and define how SageMaker runs it:
2425

2526
.. code:: python
2627
2728
from sagemaker.rl import RLEstimator, RLToolkit, RLFramework
2829
29-
rl_estimator = RLEstimator(entry_point='coach-train.py',
30-
toolkit=RLToolkit.COACH,
31-
toolkit_version='0.11.1',
32-
framework=RLFramework.TENSORFLOW,
33-
role='SageMakerRole',
34-
instance_type='ml.p3.2xlarge',
35-
instance_count=1)
30+
# Train my estimator
31+
rl_estimator = RLEstimator(
32+
entry_point='coach-train.py',
33+
image_uri='123123123123.dkr.ecr.us-west-2.amazonaws.com/your-rl-registry:your-cool-image-tag',
34+
role='SageMakerRole',
35+
instance_type='ml.c4.2xlarge',
36+
instance_count=1
37+
)
38+
39+
40+
.. tip::
41+
Refer to `SageMaker RL Docker Containers <#sagemaker-rl-docker-containers>`_ for the more information on how to
42+
build your custom RL image.
3643

3744
After that, you simply tell the estimator to start a training job:
3845

3946
.. code:: python
4047
4148
rl_estimator.fit()
4249
50+
4351
In the following sections, we'll discuss how to prepare a training script for execution on SageMaker
4452
and how to run that script on SageMaker using ``RLEstimator``.
4553

4654

4755
Preparing the RL Training Script
4856
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4957

50-
Your RL training script must be a Python 3.5 compatible source file from MXNet framework or Python 3.6 for TensorFlow.
51-
5258
The training script is very similar to a training script you might run outside of SageMaker, but you
5359
can access useful properties about the training environment through various environment variables, such as
5460

55-
* ``SM_MODEL_DIR``: A string representing the path to the directory to write model artifacts to.
56-
These artifacts are uploaded to S3 for model hosting.
57-
* ``SM_NUM_GPUS``: An integer representing the number of GPUs available to the host.
58-
* ``SM_OUTPUT_DATA_DIR``: A string representing the filesystem path to write output artifacts to. Output artifacts may
59-
include checkpoints, graphs, and other files to save, not including model artifacts. These artifacts are compressed
60-
and uploaded to S3 to the same S3 prefix as the model artifacts.
61+
* ``SM_MODEL_DIR``: A string representing the path to the directory to write model artifacts to.
62+
These artifacts are uploaded to S3 for model hosting.
63+
* ``SM_NUM_GPUS``: An integer representing the number of GPUs available to the host.
64+
* ``SM_OUTPUT_DATA_DIR``: A string representing the filesystem path to write output artifacts to. Output artifacts may
65+
include checkpoints, graphs, and other files to save, not including model artifacts. These artifacts are compressed
66+
and uploaded to S3 to the same S3 prefix as the model artifacts.
6167

6268
For the exhaustive list of available environment variables, see the
6369
`SageMaker Containers documentation <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.
6470

71+
Note that your RL training script must have the Python version compatible with your custom RL Docker image.
6572

6673
RL Estimators
6774
-------------
@@ -81,28 +88,16 @@ these in the constructor, either positionally or as keyword arguments.
8188
endpoints use this role to access training data and model artifacts.
8289
After the endpoint is created, the inference code might use the IAM
8390
role, if accessing AWS resource.
84-
- ``instance_count`` Number of Amazon EC2 instances to use for
85-
training.
91+
- ``instance_count`` Number of Amazon EC2 instances to use for training.
8692
- ``instance_type`` Type of EC2 instance to use for training, for
8793
example, 'ml.m4.xlarge'.
8894

89-
You must as well include either:
90-
91-
- ``toolkit`` RL toolkit (Ray RLlib or Coach) you want to use for executing your model training code.
92-
93-
- ``toolkit_version`` RL toolkit version you want to be use for executing your model training code.
94-
95-
- ``framework`` Framework (MXNet or TensorFlow) you want to be used as
96-
a toolkit backed for reinforcement learning training.
97-
98-
or provide:
95+
You must also provide:
9996

100-
- ``image_uri`` An alternative Docker image to use for training and
101-
serving. If specified, the estimator will use this image for training and
102-
hosting, instead of selecting the appropriate SageMaker official image based on
103-
framework_version and py_version. Refer to: `SageMaker RL Docker Containers
104-
<#sagemaker-rl-docker-containers>`_ for details on what the Official images support
105-
and where to find the source code to build your custom image.
97+
- ``image_uri`` An alternative Docker image to use for training and serving.
98+
If specified, the estimator will use this image for training and
99+
hosting. Refer to: `SageMaker RL Docker Containers <#sagemaker-rl-docker-containers>`_
100+
for the source code to build your custom RL image.
106101

107102

108103
Optional arguments
@@ -140,10 +135,8 @@ Deploying RL Models
140135
After an RL Estimator has been fit, you can host the newly created model in SageMaker.
141136

142137
After calling ``fit``, you can call ``deploy`` on an ``RLEstimator`` Estimator to create a SageMaker Endpoint.
143-
The Endpoint runs one of the SageMaker-provided model server based on the ``framework`` parameter
144-
specified in the ``RLEstimator`` constructor and hosts the model produced by your training script,
145-
which was run when you called ``fit``. This was the model you saved to ``model_dir``.
146-
In case if ``image_uri`` was specified it would use provided image for the deployment.
138+
The Endpoint runs provided image specified with ``image_uri`` and hosts the model produced by your
139+
training script, which was run when you called ``fit``. This is the model you saved to ``model_dir``.
147140

148141
``deploy`` returns a ``sagemaker.mxnet.MXNetPredictor`` for MXNet or
149142
``sagemaker.tensorflow.TensorFlowPredictor`` for TensorFlow.
@@ -153,19 +146,22 @@ In case if ``image_uri`` was specified it would use provided image for the deplo
153146
.. code:: python
154147
155148
# Train my estimator
156-
rl_estimator = RLEstimator(entry_point='coach-train.py',
157-
toolkit=RLToolkit.COACH,
158-
toolkit_version='0.11.0',
159-
framework=RLFramework.MXNET,
160-
role='SageMakerRole',
161-
instance_type='ml.c4.2xlarge',
162-
instance_count=1)
149+
region = 'us-west-2' # the AWS region of your training job
150+
rl_estimator = RLEstimator(
151+
entry_point='coach-train.py',
152+
image_uri=f'123123123123.dkr.ecr.{region}.amazonaws.com/your-rl-registry:your-cool-image-tag',
153+
role='SageMakerRole',
154+
instance_type='ml.c4.2xlarge',
155+
instance_count=1
156+
)
163157
164158
rl_estimator.fit()
165159
166160
# Deploy my estimator to a SageMaker Endpoint and get a MXNetPredictor
167-
predictor = rl_estimator.deploy(instance_type='ml.m4.xlarge',
168-
initial_instance_count=1)
161+
predictor = rl_estimator.deploy(
162+
instance_type='ml.m4.xlarge',
163+
initial_instance_count=1
164+
)
169165
170166
response = predictor.predict(data)
171167
@@ -193,10 +189,8 @@ attach will block and display log messages from the training job, until the trai
193189

194190
The ``attach`` method accepts the following arguments:
195191

196-
- ``training_job_name:`` The name of the training job to attach
197-
to.
198-
- ``sagemaker_session:`` The Session used
199-
to interact with SageMaker
192+
- ``training_job_name:`` The name of the training job to attach to.
193+
- ``sagemaker_session:`` The Session used to interact with SageMaker
200194

201195
RL Training Examples
202196
--------------------
@@ -212,4 +206,6 @@ These are also available in SageMaker Notebook Instance hosted Jupyter notebooks
212206
SageMaker RL Docker Containers
213207
------------------------------
214208

215-
For more about the Docker images themselves, visit `the SageMaker RL containers repository <https://github.com/aws/sagemaker-rl-container>`_.
209+
For more information about how build your own RL image and use script mode with your image, see
210+
`Building your image section on SageMaker RL containers repository <https://github.com/aws/sagemaker-rl-container?tab=readme-ov-file#building-your-image/>`_
211+
and `Bring your own model with Amazon SageMaker script mode <https://aws.amazon.com/blogs/machine-learning/bring-your-own-model-with-amazon-sagemaker-script-mode/>`_.

src/sagemaker/image_uris.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
3232
GRAVITON_ALLOWED_FRAMEWORKS,
3333
)
34+
from sagemaker.deprecations import deprecation_warn
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -45,7 +46,13 @@
4546
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
4647
STABILITYAI_FRAMEWORK = "stabilityai"
4748
SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver"
48-
49+
RL_FRAMEWORKS = [
50+
"coach-tensorflow",
51+
"coach-mxnet",
52+
"ray-tensorflow",
53+
"ray-pytorch",
54+
"vw",
55+
]
4956

5057
@override_pipeline_parameter_var
5158
def retrieve(
@@ -188,6 +195,12 @@ def retrieve(
188195
)
189196
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
190197
version_config = version_config.get(full_base_framework_version)
198+
elif framework in RL_FRAMEWORKS:
199+
deprecation_warn(
200+
"SageMaker-hosted RL images no longer accept new pull requests and",
201+
"April 2024",
202+
" Please pass in `image_uri` to use RLEstimator"
203+
)
191204

192205
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
193206
version_config = version_config.get(py_version) or version_config

src/sagemaker/rl/estimator.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.tensorflow.model import TensorFlowModel
2626
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2727
from sagemaker.workflow.entities import PipelineVariable
28+
from sagemaker.deprecations import removed_function, deprecation_warn
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -70,9 +71,9 @@ class RLFramework(enum.Enum):
7071
class RLEstimator(Framework):
7172
"""Handle end-to-end training and deployment of custom RLEstimator code."""
7273

73-
COACH_LATEST_VERSION_TF = "0.11.1"
74-
COACH_LATEST_VERSION_MXNET = "0.11.0"
75-
RAY_LATEST_VERSION = "1.6.0"
74+
COACH_LATEST_VERSION_TF = removed_function("COACH_LATEST_VERSION_TF")
75+
COACH_LATEST_VERSION_MXNET = removed_function("COACH_LATEST_VERSION_MXNET")
76+
RAY_LATEST_VERSION = removed_function("RAY_LATEST_VERSION")
7677

7778
def __init__(
7879
self,
@@ -112,11 +113,20 @@ def __init__(
112113
must point to a file located at the root of ``source_dir``.
113114
toolkit (sagemaker.rl.RLToolkit): RL toolkit you want to use for
114115
executing your model training code.
116+
.. warning::
117+
This ``toolkit`` argument discontinued support for new RL users on April 2024. To use
118+
RLEstimator, please pass in ``image_uri``.
115119
toolkit_version (str): RL toolkit version you want to be use for
116120
executing your model training code.
121+
.. warning::
122+
This ``toolkit_version`` argument discontinued support for new RL users on April 2024.
123+
To use RLEstimator, please pass in ``image_uri``.
117124
framework (sagemaker.rl.RLFramework): Framework (MXNet or
118125
TensorFlow) you want to be used as a toolkit backed for
119126
reinforcement learning training.
127+
.. warning::
128+
This ``framework`` argument discontinued support for new RL users on April 2024. To
129+
use RLEstimator, please pass in ``image_uri``.
120130
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
121131
to a directory with any other training source code dependencies aside from
122132
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -127,11 +137,8 @@ def __init__(
127137
accessible as a dict[str, str] to the training code on
128138
SageMaker. For convenience, this accepts other types for keys
129139
and values.
130-
image_uri (str or PipelineVariable): An ECR url. If specified, the estimator will use
131-
this image for training and hosting, instead of selecting the
132-
appropriate SageMaker official image based on framework_version
133-
and py_version. Example:
134-
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
140+
image_uri (str or PipelineVariable): An ECR url for an image the estimator would use
141+
for training and hosting. Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
135142
metric_definitions (list[dict[str, str] or list[dict[str, PipelineVariable]]):
136143
A list of dictionaries that defines the metric(s) used to evaluate the
137144
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric,
@@ -141,6 +148,13 @@ def __init__(
141148
**kwargs: Additional kwargs passed to the
142149
:class:`~sagemaker.estimator.Framework` constructor.
143150
151+
.. seealso::
152+
For more information about how build your own RL image and use script mode with
153+
your image, see `Building your image on sagemaker-rl-container
154+
<https://github.com/aws/sagemaker-rl-container?tab=readme-ov-file#building-your-image/>`_
155+
and `Bring your own model with Amazon SageMaker script mode
156+
<https://aws.amazon.com/blogs/machine-learning/bring-your-own-model-with-amazon-sagemaker-script-mode/>`_
157+
144158
.. tip::
145159
146160
You can find additional parameters for initializing this class at
@@ -149,6 +163,13 @@ def __init__(
149163
"""
150164
self._validate_images_args(toolkit, toolkit_version, framework, image_uri)
151165

166+
if toolkit:
167+
deprecation_warn("The argument `toolkit`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
168+
if toolkit_version:
169+
deprecation_warn("The argument `toolkit_version`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
170+
if framework:
171+
deprecation_warn("The argument `framework`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
172+
152173
if not image_uri:
153174
self._validate_toolkit_support(toolkit.value, toolkit_version, framework.value)
154175
self.toolkit = toolkit.value

tests/unit/sagemaker/image_uris/test_rl.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16+
from unittest.mock import patch
1617

1718
from sagemaker import image_uris
1819
from tests.unit.sagemaker.image_uris import expected_uris
@@ -45,17 +46,25 @@ def test_rl_image_uris(load_config_and_file_name):
4546
instance_type = INSTANCE_TYPES[processor]
4647
for py_version in py_versions:
4748
for region in ACCOUNTS.keys():
48-
uri = image_uris.retrieve(
49-
framework, region, version=version, instance_type=instance_type
50-
)
49+
with patch("logging.Logger.warning") as mocked_warning_log:
50+
uri = image_uris.retrieve(
51+
framework, region, version=version, instance_type=instance_type
52+
)
5153

52-
expected = expected_uris.framework_uri(
53-
repo,
54-
tag_prefix,
55-
ACCOUNTS[region],
56-
py_version=py_version,
57-
processor=processor,
58-
region=region,
59-
)
54+
expected = expected_uris.framework_uri(
55+
repo,
56+
tag_prefix,
57+
ACCOUNTS[region],
58+
py_version=py_version,
59+
processor=processor,
60+
region=region,
61+
)
6062

61-
assert uri == expected
63+
mocked_warning_log.assert_called_once_with(
64+
"SageMaker-hosted RL images no longer accept new pull requests and will be "
65+
"deprecated on April 2024."
66+
" Please pass in `image_uri` to use RLEstimator in sagemaker>=2.\n"
67+
"See: https://sagemaker.readthedocs.io/en/stable/v2.html for details."
68+
)
69+
mocked_warning_log.reset_mock()
70+
assert uri == expected

0 commit comments

Comments
 (0)